Skip to content

Commit e1b10ce

Browse files
committed
add charts galore
1 parent 3e39a93 commit e1b10ce

File tree

1 file changed

+189
-11
lines changed

1 file changed

+189
-11
lines changed

llm-complete-guide/steps/populate_index.py

Lines changed: 189 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,9 @@
2222
import json
2323
import logging
2424
import math
25-
from typing import Annotated
25+
from typing import Annotated, Dict, List, Tuple
2626

27+
from PIL import Image, ImageDraw, ImageFont
2728
from constants import (
2829
CHUNK_OVERLAP,
2930
CHUNK_SIZE,
@@ -40,12 +41,184 @@
4041
logger = logging.getLogger(__name__)
4142

4243

44+
def extract_docs_stats(total_documents: int, split_docs: List[Document]) -> Dict[str, Dict[str, int]]:
45+
"""Extracts statistics about the document chunks.
46+
47+
Args:
48+
total_documents (int): The total number of original documents before splitting.
49+
split_docs (List[Document]): The list of document chunks after splitting.
50+
51+
Returns:
52+
Dict[str, Dict[str, int]]: A dictionary containing two sub-dictionaries:
53+
- document_stats: Contains statistics about the chunks including:
54+
- total_documents: Number of original documents
55+
- total_chunks: Number of chunks after splitting
56+
- avg_chunk_size: Average size of chunks in characters
57+
- min_chunk_size: Size of smallest chunk in characters
58+
- max_chunk_size: Size of largest chunk in characters
59+
- chunks_per_section: Maps each document section to number of chunks it contains
60+
"""
61+
total_documents = total_documents
62+
total_chunks = len(split_docs)
63+
chunk_sizes = [len(doc.page_content) for doc in split_docs]
64+
avg_chunk_size = sum(chunk_sizes) / len(chunk_sizes)
65+
min_chunk_size = min(chunk_sizes)
66+
max_chunk_size = max(chunk_sizes)
67+
chunks_per_section = {}
68+
for doc in split_docs:
69+
section = doc.parent_section
70+
if section not in chunks_per_section:
71+
chunks_per_section[section] = 0
72+
chunks_per_section[section] += 1
73+
74+
return {
75+
"document_stats": {
76+
"total_documents": total_documents,
77+
"total_chunks": total_chunks,
78+
"avg_chunk_size": avg_chunk_size,
79+
"min_chunk_size": min_chunk_size,
80+
"max_chunk_size": max_chunk_size
81+
},
82+
"chunks_per_section": chunks_per_section
83+
}
84+
85+
86+
87+
def create_charts(stats: Dict[str, Dict[str, int]]) -> Image.Image:
88+
"""Creates a combined image containing a histogram of chunk sizes and a bar chart of chunk counts per section.
89+
90+
Args:
91+
stats (Dict[str, Dict[str, int]]): A dictionary containing the extracted statistics.
92+
93+
Returns:
94+
Image.Image: A combined image containing the histogram and bar chart.
95+
"""
96+
document_stats = stats["document_stats"]
97+
chunks_per_section = stats["chunks_per_section"]
98+
99+
# Create a new image with a white background
100+
image_width = 800
101+
image_height = 600
102+
image = Image.new("RGB", (image_width, image_height), color="white")
103+
draw = ImageDraw.Draw(image)
104+
105+
# Draw the histogram of chunk sizes
106+
histogram_width = 400
107+
histogram_height = 300
108+
histogram_data = [document_stats["min_chunk_size"], document_stats["avg_chunk_size"], document_stats["max_chunk_size"]]
109+
histogram_labels = ["Min", "Avg", "Max"]
110+
histogram_x = 50
111+
histogram_y = 50
112+
draw_histogram(draw, histogram_x, histogram_y, histogram_width, histogram_height, histogram_data, histogram_labels)
113+
114+
# Draw the bar chart of chunk counts per section
115+
bar_chart_width = 400
116+
bar_chart_height = 300
117+
bar_chart_data = list(chunks_per_section.values())
118+
bar_chart_labels = list(chunks_per_section.keys())
119+
bar_chart_x = 450
120+
bar_chart_y = 50
121+
draw_bar_chart(draw, bar_chart_x, bar_chart_y, bar_chart_width, bar_chart_height, bar_chart_data, bar_chart_labels)
122+
123+
# Add a title to the combined image
124+
title_text = "Document Chunk Statistics"
125+
title_font = ImageFont.truetype("arial.ttf", 24)
126+
title_width, title_height = draw.textsize(title_text, font=title_font)
127+
title_x = (image_width - title_width) // 2
128+
title_y = 10
129+
draw.text((title_x, title_y), title_text, font=title_font, fill="black")
130+
131+
return image
132+
133+
def draw_histogram(draw: ImageDraw.Draw, x: int, y: int, width: int, height: int, data: List[int], labels: List[str]) -> None:
134+
"""Draws a histogram chart showing the distribution of chunk sizes.
135+
136+
Args:
137+
draw (ImageDraw.Draw): The ImageDraw object to draw on
138+
x (int): The x coordinate of the top-left corner of the histogram
139+
y (int): The y coordinate of the top-left corner of the histogram
140+
width (int): The width of the histogram in pixels
141+
height (int): The height of the histogram in pixels
142+
data (List[int]): The values to plot in the histogram
143+
labels (List[str]): The labels for each bar in the histogram
144+
"""
145+
# Calculate the maximum value in the data
146+
max_value = max(data)
147+
148+
# Calculate the bar width and spacing
149+
bar_width = width // len(data)
150+
bar_spacing = 10
151+
152+
# Draw the bars
153+
for i, value in enumerate(data):
154+
bar_height = (value / max_value) * height
155+
bar_x = x + i * (bar_width + bar_spacing)
156+
bar_y = y + height - bar_height
157+
draw.rectangle([(bar_x, bar_y), (bar_x + bar_width, y + height)], fill="blue")
158+
159+
# Draw the label below the bar
160+
label_text = labels[i]
161+
label_font = ImageFont.truetype("arial.ttf", 12)
162+
label_width, label_height = draw.textsize(label_text, font=label_font)
163+
label_x = bar_x + (bar_width - label_width) // 2
164+
label_y = y + height + 5
165+
draw.text((label_x, label_y), label_text, font=label_font, fill="black")
166+
167+
# Draw the title above the histogram
168+
title_text = "Chunk Size Distribution"
169+
title_font = ImageFont.truetype("arial.ttf", 16)
170+
title_width, title_height = draw.textsize(title_text, font=title_font)
171+
title_x = x + (width - title_width) // 2
172+
title_y = y - title_height - 10
173+
draw.text((title_x, title_y), title_text, font=title_font, fill="black")
174+
175+
def draw_bar_chart(draw: ImageDraw.Draw, x: int, y: int, width: int, height: int, data: List[int], labels: List[str]) -> None:
176+
"""Draws a bar chart showing the number of chunks per section.
177+
178+
Args:
179+
draw (ImageDraw.Draw): The ImageDraw object to draw on
180+
x (int): The x coordinate of the top-left corner of the bar chart
181+
y (int): The y coordinate of the top-left corner of the bar chart
182+
width (int): The width of the bar chart in pixels
183+
height (int): The height of the bar chart in pixels
184+
data (List[int]): The values to plot in the bar chart
185+
labels (List[str]): The labels for each bar in the chart
186+
"""
187+
# Calculate the maximum value in the data
188+
max_value = max(data)
189+
190+
# Calculate the bar width and spacing
191+
bar_width = width // len(data)
192+
bar_spacing = 10
193+
194+
# Draw the bars
195+
for i, value in enumerate(data):
196+
bar_height = (value / max_value) * height
197+
bar_x = x + i * (bar_width + bar_spacing)
198+
bar_y = y + height - bar_height
199+
draw.rectangle([(bar_x, bar_y), (bar_x + bar_width, y + height)], fill="green")
200+
201+
# Draw the label below the bar
202+
label_text = labels[i]
203+
label_font = ImageFont.truetype("arial.ttf", 12)
204+
label_width, label_height = draw.textsize(label_text, font=label_font)
205+
label_x = bar_x + (bar_width - label_width) // 2
206+
label_y = y + height + 5
207+
draw.text((label_x, label_y), label_text, font=label_font, fill="black")
208+
209+
# Draw the title above the bar chart
210+
title_text = "Chunk Counts per Section"
211+
title_font = ImageFont.truetype("arial.ttf", 16)
212+
title_width, title_height = draw.textsize(title_text, font=title_font)
213+
title_x = x + (width - title_width) // 2
214+
title_y = y - title_height - 10
215+
draw.text((title_x, title_y), title_text, font=title_font, fill="black")
216+
43217
@step
44218
def preprocess_documents(
45219
documents: str,
46-
) -> Annotated[str, ArtifactConfig(name="split_chunks")]:
47-
"""
48-
Preprocesses a JSON string of documents by splitting them into chunks.
220+
) -> Tuple[Annotated[str, ArtifactConfig(name="split_chunks")], Annotated[Image.Image, ArtifactConfig(name="doc_stats_chart")]]:
221+
"""Preprocesses a JSON string of documents by splitting them into chunks.
49222
50223
Args:
51224
documents (str): A JSON string containing a list of documents to be preprocessed.
@@ -65,17 +238,22 @@ def preprocess_documents(
65238
},
66239
)
67240

68-
# Parse the JSON string into a list of Document objects
69-
document_list = [Document(**doc) for doc in json.loads(documents)]
70-
71-
split_docs = split_documents(
241+
document_list: List[Document] = [Document(**doc) for doc in json.loads(documents)]
242+
split_docs: List[Document] = split_documents(
72243
document_list, chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP
73244
)
74245

75-
# Convert the list of Document objects back to a JSON string
76-
split_docs_json = json.dumps([doc.__dict__ for doc in split_docs])
246+
stats: Dict[str, Dict[str, int]] = extract_docs_stats(len(document_list), split_docs)
247+
chart: Image.Image = create_charts(stats)
248+
249+
log_artifact_metadata(
250+
artifact_name="split_chunks",
251+
metadata=stats,
252+
)
253+
254+
split_docs_json: str = json.dumps([doc.__dict__ for doc in split_docs])
77255

78-
return split_docs_json
256+
return split_docs_json, chart
79257
except Exception as e:
80258
logger.error(f"Error in preprocess_documents: {e}")
81259
raise

0 commit comments

Comments
 (0)