Skip to content

Commit 2b45149

Browse files
author
marwan37
committed
Add standalone script for quick OCR comparison without ZenML
1 parent c6b5ce6 commit 2b45149

File tree

1 file changed

+287
-0
lines changed

1 file changed

+287
-0
lines changed

omni-reader/run_compare_ocr.py

Lines changed: 287 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,287 @@
1+
"""Module for running OCR comparison without using ZenML pipeline."""
2+
3+
import argparse
4+
import os
5+
import time
6+
from typing import Any, Dict, List, Optional
7+
8+
# For faster performance in interactive mode without ZenML overhead,
9+
# we implement the OCR functions directly here
10+
import instructor
11+
import polars as pl
12+
from dotenv import load_dotenv
13+
from litellm import completion
14+
from mistralai import Mistral
15+
from PIL import Image
16+
17+
from schemas.image_description import ImageDescription
18+
from utils.encode_image import encode_image
19+
from utils.metrics import compare_results
20+
from utils.prompt import get_prompt
21+
22+
load_dotenv()
23+
24+
25+
def run_gemma3_ocr_direct(
    image: str | Image.Image,
    custom_prompt: Optional[str] = None,
) -> Dict[str, Any]:
    """Extract text from an image with the local Gemma3 model (via litellm/ollama).

    Args:
        image: Path to an image file or an in-memory PIL image
        custom_prompt: Optional prompt override; defaults to the project prompt

    Returns:
        Dict with raw_text, description, entities, processing_time and model;
        on failure the dict carries placeholder values plus an "error" message.
    """
    started = time.time()
    content_type, image_base64 = encode_image(image)

    client = instructor.from_litellm(completion)
    model_name = "ollama/gemma3:27b"

    prompt = custom_prompt or get_prompt()

    # Single user message carrying the prompt text and the inlined image.
    user_content = [
        {"type": "text", "text": prompt},
        {
            "type": "image_url",
            "image_url": f"data:{content_type};base64,{image_base64}",
        },
    ]

    try:
        response = client.chat.completions.create(
            model=model_name,
            response_model=ImageDescription,
            messages=[{"role": "user", "content": user_content}],
        )
    except Exception as e:
        # Boundary handler: report the failure as data rather than raising.
        return {
            "raw_text": "Error: Failed to extract text",
            "description": "Error: Failed to extract description",
            "entities": [],
            "error": f"An unexpected error occurred: {str(e)}",
            "processing_time": time.time() - started,
            "model": model_name,
        }

    elapsed = time.time() - started
    return {
        "raw_text": response.raw_text or "No text found",
        "description": response.description or "No description found",
        "entities": response.entities or [],
        "processing_time": elapsed,
        "model": model_name,
    }
85+
86+
87+
def run_mistral_ocr_direct(
    image: str | Image.Image,
    custom_prompt: Optional[str] = None,
) -> Dict[str, Any]:
    """Extract text from an image with Mistral's Pixtral model.

    Args:
        image: Path to an image file or an in-memory PIL image
        custom_prompt: Optional prompt override; defaults to the project prompt

    Returns:
        Dict with raw_text, description, entities, processing_time and model;
        on failure the dict carries placeholder values plus an "error" message.
    """
    start_time = time.time()
    content_type, image_base64 = encode_image(image)

    # MISTRAL_API_KEY is read from the environment (load_dotenv() ran at import).
    mistral_client = Mistral(api_key=os.getenv("MISTRAL_API_KEY"))
    client = instructor.from_mistral(mistral_client)

    model_name = "pixtral-12b-2409"

    prompt = custom_prompt if custom_prompt else get_prompt()

    try:
        response = client.chat.completions.create(
            model=model_name,
            response_model=ImageDescription,
            messages=[
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": prompt},
                        {
                            "type": "image_url",
                            "image_url": f"data:{content_type};base64,{image_base64}",
                        },
                    ],
                }
            ],
        )

        processing_time = time.time() - start_time

        result = {
            "raw_text": response.raw_text if response.raw_text else "No text found",
            "description": response.description if response.description else "No description found",
            "entities": response.entities if response.entities else [],
            "processing_time": processing_time,
            "model": model_name,
        }

        return result
    except Exception as e:
        # Boundary handler: report the failure as data rather than raising.
        error_message = f"An unexpected error occurred: {str(e)}"
        return {
            "raw_text": "Error: Failed to extract text",
            "description": "Error: Failed to extract description",
            "entities": [],
            "error": error_message,
            "processing_time": time.time() - start_time,
            "model": model_name,
        }
151+
152+
153+
def run_ocr(
    image: str | Image.Image,
    model: str = "gemma3",
    custom_prompt: Optional[str] = None,
) -> Dict[str, Any]:
    """Dispatch OCR to the requested backend.

    Args:
        image: Path to an image file or an in-memory PIL image
        model: 'gemma3' selects the local Gemma3 backend; any other value
            (e.g. 'mistral') selects the Mistral backend
        custom_prompt: Optional prompt override passed through unchanged

    Returns:
        The selected backend's result dict.
    """
    backend = (
        run_gemma3_ocr_direct
        if model.lower() == "gemma3"
        else run_mistral_ocr_direct
    )
    return backend(image=image, custom_prompt=custom_prompt)
172+
173+
174+
def compare_models(
    image_paths: List[str],
    custom_prompt: Optional[str] = None,
    ground_truth_texts: Optional[List[str]] = None,
) -> Dict[str, Any]:
    """Run both OCR backends over a list of images and collect the outputs.

    Args:
        image_paths: Paths of the images to process
        custom_prompt: Optional prompt override applied to both models
        ground_truth_texts: Optional reference texts aligned by index with
            image_paths; when present, comparison metrics are computed and
            printed per image

    Returns:
        Dictionary with 'gemma_results', 'mistral_results' and 'ground_truth'
        lists of per-image entries.
    """
    results: Dict[str, Any] = {
        "gemma_results": [],
        "mistral_results": [],
        "ground_truth": [],
    }
    total = len(image_paths)

    for idx, path in enumerate(image_paths):
        name = os.path.basename(path)

        print(f"Processing image {idx + 1}/{total}: {name}")

        # Run each backend on the same image with the same prompt.
        gemma = run_ocr(image=path, model="gemma3", custom_prompt=custom_prompt)
        mistral = run_ocr(image=path, model="mistral", custom_prompt=custom_prompt)

        results["gemma_results"].append(
            {
                "id": idx,
                "image_name": name,
                "gemma_text": gemma["raw_text"],
                "gemma_entities": ", ".join(gemma.get("entities", [])),
                "gemma_processing_time": gemma.get("processing_time", 0),
            }
        )
        results["mistral_results"].append(
            {
                "id": idx,
                "image_name": name,
                "mistral_text": mistral["raw_text"],
                "mistral_entities": ", ".join(mistral.get("entities", [])),
                "mistral_processing_time": mistral.get("processing_time", 0),
            }
        )

        # Only images with a matching reference text get metrics.
        if ground_truth_texts and idx < len(ground_truth_texts):
            results["ground_truth"].append(
                {
                    "id": idx,
                    "image_name": name,
                    "ground_truth_text": ground_truth_texts[idx],
                }
            )

            # NOTE(review): formatting assumes every metric value is numeric —
            # confirm against utils.metrics.compare_results.
            metrics = compare_results(
                ground_truth_texts[idx],
                gemma["raw_text"],
                mistral["raw_text"],
            )
            print(f"Metrics for {name}:")
            for key, value in metrics.items():
                print(f" {key}: {value:.4f}")

    return results
252+
253+
254+
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Compare OCR models")
    parser.add_argument("--image", type=str, required=True, help="Path to image file")
    parser.add_argument(
        "--model",
        # str.lower keeps the original case-insensitive matching while
        # `choices` rejects unknown models up front instead of silently
        # routing them to the Mistral backend.
        type=str.lower,
        default="both",
        choices=["gemma3", "mistral", "both"],
        help="Model to use: 'gemma3', 'mistral', or 'both'",
    )
    parser.add_argument("--prompt", type=str, help="Custom prompt to use")

    args = parser.parse_args()

    def _print_result(label: str, result: Dict[str, Any]) -> None:
        """Print one model's OCR output in the shared report format."""
        print(f"\n{label} results:")
        print(f"Text: {result['raw_text']}")
        print(f"Entities: {result.get('entities', [])}")
        print(f"Processing time: {result.get('processing_time', 0):.2f}s")

    if args.model == "both":
        start_time = time.time()
        _print_result("Gemma3", run_ocr(args.image, "gemma3", args.prompt))
        _print_result("Mistral", run_ocr(args.image, "mistral", args.prompt))
        print(f"\nTotal time: {time.time() - start_time:.2f}s")
    else:
        _print_result(args.model, run_ocr(args.image, args.model, args.prompt))

0 commit comments

Comments
 (0)