2222import json
2323import logging
2424import math
25- from typing import Annotated
25+ from typing import Annotated , Dict , List , Tuple
2626
27+ from PIL import Image , ImageDraw , ImageFont
2728from constants import (
2829 CHUNK_OVERLAP ,
2930 CHUNK_SIZE ,
4041logger = logging .getLogger (__name__ )
4142
4243
44+ def extract_docs_stats (total_documents : int , split_docs : List [Document ]) -> Dict [str , Dict [str , int ]]:
45+ """Extracts statistics about the document chunks.
46+
47+ Args:
48+ total_documents (int): The total number of original documents before splitting.
49+ split_docs (List[Document]): The list of document chunks after splitting.
50+
51+ Returns:
52+ Dict[str, Dict[str, int]]: A dictionary containing two sub-dictionaries:
53+ - document_stats: Contains statistics about the chunks including:
54+ - total_documents: Number of original documents
55+ - total_chunks: Number of chunks after splitting
56+ - avg_chunk_size: Average size of chunks in characters
57+ - min_chunk_size: Size of smallest chunk in characters
58+ - max_chunk_size: Size of largest chunk in characters
59+ - chunks_per_section: Maps each document section to number of chunks it contains
60+ """
61+ total_documents = total_documents
62+ total_chunks = len (split_docs )
63+ chunk_sizes = [len (doc .page_content ) for doc in split_docs ]
64+ avg_chunk_size = sum (chunk_sizes ) / len (chunk_sizes )
65+ min_chunk_size = min (chunk_sizes )
66+ max_chunk_size = max (chunk_sizes )
67+ chunks_per_section = {}
68+ for doc in split_docs :
69+ section = doc .parent_section
70+ if section not in chunks_per_section :
71+ chunks_per_section [section ] = 0
72+ chunks_per_section [section ] += 1
73+
74+ return {
75+ "document_stats" : {
76+ "total_documents" : total_documents ,
77+ "total_chunks" : total_chunks ,
78+ "avg_chunk_size" : avg_chunk_size ,
79+ "min_chunk_size" : min_chunk_size ,
80+ "max_chunk_size" : max_chunk_size
81+ },
82+ "chunks_per_section" : chunks_per_section
83+ }
84+
85+
86+
87+ def create_charts (stats : Dict [str , Dict [str , int ]]) -> Image .Image :
88+ """Creates a combined image containing a histogram of chunk sizes and a bar chart of chunk counts per section.
89+
90+ Args:
91+ stats (Dict[str, Dict[str, int]]): A dictionary containing the extracted statistics.
92+
93+ Returns:
94+ Image.Image: A combined image containing the histogram and bar chart.
95+ """
96+ document_stats = stats ["document_stats" ]
97+ chunks_per_section = stats ["chunks_per_section" ]
98+
99+ # Create a new image with a white background
100+ image_width = 800
101+ image_height = 600
102+ image = Image .new ("RGB" , (image_width , image_height ), color = "white" )
103+ draw = ImageDraw .Draw (image )
104+
105+ # Draw the histogram of chunk sizes
106+ histogram_width = 400
107+ histogram_height = 300
108+ histogram_data = [document_stats ["min_chunk_size" ], document_stats ["avg_chunk_size" ], document_stats ["max_chunk_size" ]]
109+ histogram_labels = ["Min" , "Avg" , "Max" ]
110+ histogram_x = 50
111+ histogram_y = 50
112+ draw_histogram (draw , histogram_x , histogram_y , histogram_width , histogram_height , histogram_data , histogram_labels )
113+
114+ # Draw the bar chart of chunk counts per section
115+ bar_chart_width = 400
116+ bar_chart_height = 300
117+ bar_chart_data = list (chunks_per_section .values ())
118+ bar_chart_labels = list (chunks_per_section .keys ())
119+ bar_chart_x = 450
120+ bar_chart_y = 50
121+ draw_bar_chart (draw , bar_chart_x , bar_chart_y , bar_chart_width , bar_chart_height , bar_chart_data , bar_chart_labels )
122+
123+ # Add a title to the combined image
124+ title_text = "Document Chunk Statistics"
125+ title_font = ImageFont .truetype ("arial.ttf" , 24 )
126+ title_width , title_height = draw .textsize (title_text , font = title_font )
127+ title_x = (image_width - title_width ) // 2
128+ title_y = 10
129+ draw .text ((title_x , title_y ), title_text , font = title_font , fill = "black" )
130+
131+ return image
132+
133+ def draw_histogram (draw : ImageDraw .Draw , x : int , y : int , width : int , height : int , data : List [int ], labels : List [str ]) -> None :
134+ """Draws a histogram chart showing the distribution of chunk sizes.
135+
136+ Args:
137+ draw (ImageDraw.Draw): The ImageDraw object to draw on
138+ x (int): The x coordinate of the top-left corner of the histogram
139+ y (int): The y coordinate of the top-left corner of the histogram
140+ width (int): The width of the histogram in pixels
141+ height (int): The height of the histogram in pixels
142+ data (List[int]): The values to plot in the histogram
143+ labels (List[str]): The labels for each bar in the histogram
144+ """
145+ # Calculate the maximum value in the data
146+ max_value = max (data )
147+
148+ # Calculate the bar width and spacing
149+ bar_width = width // len (data )
150+ bar_spacing = 10
151+
152+ # Draw the bars
153+ for i , value in enumerate (data ):
154+ bar_height = (value / max_value ) * height
155+ bar_x = x + i * (bar_width + bar_spacing )
156+ bar_y = y + height - bar_height
157+ draw .rectangle ([(bar_x , bar_y ), (bar_x + bar_width , y + height )], fill = "blue" )
158+
159+ # Draw the label below the bar
160+ label_text = labels [i ]
161+ label_font = ImageFont .truetype ("arial.ttf" , 12 )
162+ label_width , label_height = draw .textsize (label_text , font = label_font )
163+ label_x = bar_x + (bar_width - label_width ) // 2
164+ label_y = y + height + 5
165+ draw .text ((label_x , label_y ), label_text , font = label_font , fill = "black" )
166+
167+ # Draw the title above the histogram
168+ title_text = "Chunk Size Distribution"
169+ title_font = ImageFont .truetype ("arial.ttf" , 16 )
170+ title_width , title_height = draw .textsize (title_text , font = title_font )
171+ title_x = x + (width - title_width ) // 2
172+ title_y = y - title_height - 10
173+ draw .text ((title_x , title_y ), title_text , font = title_font , fill = "black" )
174+
175+ def draw_bar_chart (draw : ImageDraw .Draw , x : int , y : int , width : int , height : int , data : List [int ], labels : List [str ]) -> None :
176+ """Draws a bar chart showing the number of chunks per section.
177+
178+ Args:
179+ draw (ImageDraw.Draw): The ImageDraw object to draw on
180+ x (int): The x coordinate of the top-left corner of the bar chart
181+ y (int): The y coordinate of the top-left corner of the bar chart
182+ width (int): The width of the bar chart in pixels
183+ height (int): The height of the bar chart in pixels
184+ data (List[int]): The values to plot in the bar chart
185+ labels (List[str]): The labels for each bar in the chart
186+ """
187+ # Calculate the maximum value in the data
188+ max_value = max (data )
189+
190+ # Calculate the bar width and spacing
191+ bar_width = width // len (data )
192+ bar_spacing = 10
193+
194+ # Draw the bars
195+ for i , value in enumerate (data ):
196+ bar_height = (value / max_value ) * height
197+ bar_x = x + i * (bar_width + bar_spacing )
198+ bar_y = y + height - bar_height
199+ draw .rectangle ([(bar_x , bar_y ), (bar_x + bar_width , y + height )], fill = "green" )
200+
201+ # Draw the label below the bar
202+ label_text = labels [i ]
203+ label_font = ImageFont .truetype ("arial.ttf" , 12 )
204+ label_width , label_height = draw .textsize (label_text , font = label_font )
205+ label_x = bar_x + (bar_width - label_width ) // 2
206+ label_y = y + height + 5
207+ draw .text ((label_x , label_y ), label_text , font = label_font , fill = "black" )
208+
209+ # Draw the title above the bar chart
210+ title_text = "Chunk Counts per Section"
211+ title_font = ImageFont .truetype ("arial.ttf" , 16 )
212+ title_width , title_height = draw .textsize (title_text , font = title_font )
213+ title_x = x + (width - title_width ) // 2
214+ title_y = y - title_height - 10
215+ draw .text ((title_x , title_y ), title_text , font = title_font , fill = "black" )
216+
43217@step
44218def preprocess_documents (
45219 documents : str ,
46- ) -> Annotated [str , ArtifactConfig (name = "split_chunks" )]:
47- """
48- Preprocesses a JSON string of documents by splitting them into chunks.
220+ ) -> Tuple [Annotated [str , ArtifactConfig (name = "split_chunks" )], Annotated [Image .Image , ArtifactConfig (name = "doc_stats_chart" )]]:
221+ """Preprocesses a JSON string of documents by splitting them into chunks.
49222
50223 Args:
51224 documents (str): A JSON string containing a list of documents to be preprocessed.
@@ -65,17 +238,22 @@ def preprocess_documents(
65238 },
66239 )
67240
68- # Parse the JSON string into a list of Document objects
69- document_list = [Document (** doc ) for doc in json .loads (documents )]
70-
71- split_docs = split_documents (
241+ document_list : List [Document ] = [Document (** doc ) for doc in json .loads (documents )]
242+ split_docs : List [Document ] = split_documents (
72243 document_list , chunk_size = CHUNK_SIZE , chunk_overlap = CHUNK_OVERLAP
73244 )
74245
75- # Convert the list of Document objects back to a JSON string
76- split_docs_json = json .dumps ([doc .__dict__ for doc in split_docs ])
246+ stats : Dict [str , Dict [str , int ]] = extract_docs_stats (len (document_list ), split_docs )
247+ chart : Image .Image = create_charts (stats )
248+
249+ log_artifact_metadata (
250+ artifact_name = "split_chunks" ,
251+ metadata = stats ,
252+ )
253+
254+ split_docs_json : str = json .dumps ([doc .__dict__ for doc in split_docs ])
77255
78- return split_docs_json
256+ return split_docs_json , chart
79257 except Exception as e :
80258 logger .error (f"Error in preprocess_documents: { e } " )
81259 raise
0 commit comments