Skip to content

Commit 4d96ae3

Browse files
author
marwan37
committed
rename ocr_model_utils to ocr_processing
1 parent 6a0295f commit 4d96ae3

File tree

1 file changed

+328
-0
lines changed

1 file changed

+328
-0
lines changed
Lines changed: 328 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,328 @@
1+
# Apache Software License 2.0
2+
#
3+
# Copyright (c) ZenML GmbH 2025. All rights reserved.
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
"""Utility functions for OCR operations across different models."""
17+
18+
import contextlib
19+
import json
20+
import os
21+
import re
22+
import statistics
23+
import time
24+
from typing import Any, Dict, List, Optional
25+
26+
import polars as pl
27+
from dotenv import load_dotenv
28+
from zenml import log_metadata
29+
from zenml.logger import get_logger
30+
31+
from utils.encode_image import encode_image
32+
from utils.model_configs import ModelConfig
33+
from utils.prompt import ImageDescription, get_prompt
34+
35+
load_dotenv()
36+
logger = get_logger(__name__)
37+
38+
39+
def try_extract_json_from_response(response: Any) -> Dict:
    """Extract JSON from a response, handling various formats.

    Supports already-parsed dicts, OpenAI-style chat completion objects,
    plain strings, and objects exposing a ``raw_text`` attribute (the
    ImageDescription case).

    Args:
        response: The response which could be string, dict, or object

    Returns:
        Dict with extracted data. Falls back to
        ``{"raw_text": <text>, "confidence": None}`` when no JSON can be
        parsed out of the text.
    """
    # If already a dict with raw_text, return it
    if isinstance(response, dict) and "raw_text" in response:
        return response

    # Convert to string if it's an object with content
    response_text = ""
    if hasattr(response, "choices") and len(response.choices) > 0:
        if hasattr(response.choices[0], "message") and hasattr(
            response.choices[0].message, "content"
        ):
            response_text = response.choices[0].message.content
    elif isinstance(response, str):
        response_text = response
    elif hasattr(response, "raw_text"):
        # This handles the ImageDescription object case
        return {"raw_text": response.raw_text, "confidence": getattr(response, "confidence", None)}

    # Some APIs return a null message content; normalize to "" so the
    # regex searches below don't raise TypeError.
    if response_text is None:
        response_text = ""

    json_fence_pattern = re.compile(r"```json\n(.*?)```", re.DOTALL)
    # NOTE: this pattern cannot match nested objects ([^}] stops at the
    # first closing brace); the whole-text json.loads attempt below
    # covers properly nested JSON.
    direct_json_pattern = re.compile(r"\{[^}]*\}", re.DOTALL)

    try:
        # The whole text may already be valid JSON (handles nesting).
        with contextlib.suppress(json.JSONDecodeError):
            parsed = json.loads(response_text)
            if isinstance(parsed, dict):
                return parsed

        if match := json_fence_pattern.search(response_text):
            json_results = match.group(1)
            with contextlib.suppress(json.JSONDecodeError):
                return json.loads(json_results)
        if match := direct_json_pattern.search(response_text):
            json_text = match.group(0)
            with contextlib.suppress(json.JSONDecodeError):
                return json.loads(json_text)

        # If we get here, no JSON could be extracted, so use the text as raw_text
        return {"raw_text": response_text, "confidence": None}
    except Exception as e:
        # Fallback for any other errors
        return {"raw_text": f"Error: {str(e)}", "confidence": 0.0, "success": False}
84+
85+
86+
def log_image_metadata(
    prefix: str,
    index: int,
    image_name: str,
    processing_time: float,
    text_length: int,
    confidence: float,
):
    """Record per-image OCR metadata under a model-prefixed key.

    Args:
        prefix: The model prefix (openai, mistral, etc.)
        index: Image index
        image_name: Name of the image file
        processing_time: Processing time in seconds
        text_length: Length of extracted text
        confidence: Confidence score
    """
    # Build the payload first, then log it under a key unique to this
    # model/image combination.
    entry = {
        "image_name": image_name,
        "processing_time_seconds": processing_time,
        "text_length": text_length,
        "confidence": confidence,
    }
    log_metadata(metadata={f"{prefix}_image_{index}": entry})
114+
115+
116+
def log_error_metadata(
    prefix: str,
    index: int,
    image_name: str,
    error: str,
):
    """Record error metadata for an image that failed OCR processing.

    Args:
        prefix: The model prefix (openai, mistral, etc.)
        index: Image index
        image_name: Name of the image file
        error: Error message
    """
    # Keyed per model and image index so repeated failures don't
    # overwrite each other.
    entry = {"image_name": image_name, "error": error}
    log_metadata(metadata={f"{prefix}_error_image_{index}": entry})
138+
139+
140+
def log_summary_metadata(
    prefix: str,
    model_name: str,
    images_count: int,
    processing_times: List[float],
    confidence_scores: List[float],
):
    """Log summary metadata for all processed images.

    Args:
        prefix: The model prefix (openai, mistral, etc.)
        model_name: Name of the model
        images_count: Number of images processed
        processing_times: List of processing times
        confidence_scores: List of confidence scores
    """
    # Guard against an empty list: statistics.mean raises
    # StatisticsError and max()/min() raise ValueError on empty input.
    # This mirrors the existing guard on confidence_scores.
    avg_time = 0.0
    max_time = 0.0
    min_time = 0.0
    if processing_times:
        avg_time = statistics.mean(processing_times)
        max_time = max(processing_times)
        min_time = min(processing_times)

    avg_confidence = 0.0
    if confidence_scores:
        avg_confidence = statistics.mean(confidence_scores)

    log_metadata(
        metadata={
            f"{prefix}_ocr_summary": {
                "model": model_name,
                "images_processed": images_count,
                "avg_processing_time": avg_time,
                "min_processing_time": min_time,
                "max_processing_time": max_time,
                "avg_confidence": avg_confidence,
                "total_processing_time": sum(processing_times),
            }
        }
    )
177+
178+
179+
def process_images_with_model(
    model_config: ModelConfig,
    images: List[str],
    custom_prompt: Optional[str] = None,
    batch_size: int = 5,
) -> pl.DataFrame:
    """Process images with a specific model configuration.

    Each image is base64-encoded, sent to the model in parallel batches
    via a thread pool, and logged (per image and in summary) as ZenML
    metadata.

    Args:
        model_config: Model configuration
        images: List of image paths
        custom_prompt: Optional custom prompt
        batch_size: Number of images to process in parallel (default: 5)

    Returns:
        DataFrame with OCR results, one row per image. Failed images
        still get a row with confidence 0.0 and an error message in
        ``raw_text``.
    """
    from concurrent.futures import ThreadPoolExecutor

    from tqdm import tqdm

    model_name = model_config.name
    prefix = model_config.prefix
    display = model_config.display
    prompt = custom_prompt if custom_prompt else get_prompt()

    logger.info(f"Running {display} OCR with model: {model_name}")
    logger.info(f"Processing {len(images)} images with batch size: {batch_size}")

    # Guard: an empty image list would otherwise crash below --
    # ThreadPoolExecutor(max_workers=0) and range(..., step=0) both
    # raise ValueError. Return an empty, correctly-typed DataFrame.
    if not images:
        logger.warning(f"No images to process for {display}; returning empty DataFrame")
        return pl.DataFrame(
            schema={
                "id": pl.Int64,
                "image_name": pl.Utf8,
                "raw_text": pl.Utf8,
                "processing_time": pl.Float64,
                "confidence": pl.Float64,
            }
        )

    results_list = []
    processing_times = []
    confidence_scores = []

    def process_single_image(args):
        """Encode and OCR one image; return its result row plus timing/confidence."""
        i, image_path = args
        start_time = time.time()
        image_name = os.path.basename(image_path)

        try:
            content_type, image_base64 = encode_image(image_path)

            result_json = model_config.process_image(prompt, image_base64, content_type)

            raw_text = result_json.get("raw_text", "No text found")
            # A present-but-null confidence also falls back to the
            # model's default (dict.get only covers the missing-key case).
            confidence = result_json.get("confidence", model_config.default_confidence)
            if confidence is None:
                confidence = model_config.default_confidence

            processing_time = time.time() - start_time

            result = {
                "id": i,
                "image_name": image_name,
                "raw_text": raw_text,
                "processing_time": processing_time,
                "confidence": confidence,
            }

            log_image_metadata(
                prefix=prefix,
                index=i,
                image_name=image_name,
                processing_time=processing_time,
                text_length=len(result["raw_text"]),
                confidence=confidence,
            )

            logger.info(
                f"{display} OCR [{i + 1}/{len(images)}]: {image_name} - "
                f"{len(result['raw_text'])} chars, "
                f"confidence: {confidence:.2f}, "
                f"{processing_time:.2f} seconds"
            )

            return {
                "result": result,
                "processing_time": processing_time,
                "confidence": confidence,
                "success": True,
            }

        except Exception as e:
            error_message = f"An unexpected error occurred on image {image_name}: {str(e)}"
            logger.error(error_message)
            processing_time = time.time() - start_time

            log_error_metadata(
                prefix=prefix,
                index=i,
                image_name=image_name,
                error=str(e),
            )

            # Failed images still produce a row so the output stays
            # aligned with the input list.
            # NOTE(review): this row carries an extra "error" key that
            # success rows lack -- confirm polars accepts the ragged
            # schema in the version pinned by this project.
            return {
                "result": {
                    "id": i,
                    "image_name": image_name,
                    "raw_text": f"Error: Failed to extract text - {str(e)}",
                    "processing_time": processing_time,
                    "confidence": 0.0,
                    "error": error_message,
                },
                "processing_time": processing_time,
                "confidence": 0.0,
                "success": False,
            }

    # Clamp to at least 1 so a caller passing batch_size <= 0 doesn't
    # crash the executor or the batching range() below; cap threads at 10.
    effective_batch_size = max(1, min(batch_size, len(images)))
    max_workers = min(effective_batch_size, 10)

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        with tqdm(total=len(images), desc=f"Processing with {display}") as pbar:
            image_batches = [
                images[i : i + effective_batch_size]
                for i in range(0, len(images), effective_batch_size)
            ]

            for batch_index, batch in enumerate(image_batches):
                logger.info(
                    f"Processing batch {batch_index + 1}/{len(image_batches)} with {len(batch)} images"
                )

                # Global (whole-run) indices for the images in this batch.
                batch_indices = range(
                    batch_index * effective_batch_size,
                    batch_index * effective_batch_size + len(batch),
                )

                batch_futures = list(executor.map(process_single_image, zip(batch_indices, batch)))

                for result_dict in batch_futures:
                    results_list.append(result_dict["result"])
                    processing_times.append(result_dict["processing_time"])

                    # Only successful extractions count toward the
                    # confidence average.
                    if result_dict["success"]:
                        confidence_scores.append(result_dict["confidence"])

                    pbar.update(1)

    # Log summary statistics
    log_summary_metadata(
        prefix=prefix,
        model_name=model_name,
        images_count=len(images),
        processing_times=processing_times,
        confidence_scores=confidence_scores,
    )

    # Convert to polars dataframe
    results_df = pl.DataFrame(results_list)
    return results_df

0 commit comments

Comments
 (0)