Commit a33b38c

Author: marwan37

update run and run_compare_ocr entrypoint files

1 parent 70c6f45 · commit a33b38c
File tree

2 files changed: +27 −140


omni-reader/run.py

Lines changed: 4 additions & 2 deletions
@@ -32,7 +32,9 @@
 
 def main():
     """Run the OCR comparison pipeline."""
-    parser = argparse.ArgumentParser(description="Run OCR comparison between Mistral and Gemma3 using ZenML")
+    parser = argparse.ArgumentParser(
+        description="Run OCR comparison between Mistral and Gemma3 using ZenML"
+    )
 
     # Config file options
     config_group = parser.add_argument_group("Configuration")
@@ -72,7 +74,7 @@ def main():
     gt_group.add_argument(
         "--ground-truth-dir",
         type=str,
-        default="ground_truth",
+        default="ocr_results",
         help="Directory to look for ground truth files (for --list-ground-truth-files)",
     )
 
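For reference, a standalone sketch of the parser setup this file touches. Only gt_group, the --ground-truth-dir option, and the new ocr_results default come from the diff; the argument-group title and the rest of the scaffolding are assumptions:

    import argparse

    parser = argparse.ArgumentParser(
        description="Run OCR comparison between Mistral and Gemma3 using ZenML"
    )

    # Group title is an assumption; the diff only shows the gt_group variable.
    gt_group = parser.add_argument_group("Ground Truth")
    gt_group.add_argument(
        "--ground-truth-dir",
        type=str,
        default="ocr_results",  # default changed from "ground_truth" in this commit
        help="Directory to look for ground truth files (for --list-ground-truth-files)",
    )

    args = parser.parse_args([])
    print(args.ground_truth_dir)  # -> ocr_results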

omni-reader/run_compare_ocr.py

Lines changed: 23 additions & 138 deletions
@@ -5,111 +5,50 @@
 import time
 from typing import Any, Dict, List, Optional
 
-# For faster performance in interactive mode without ZenML overhead,
-# we implement the OCR functions directly here
 import instructor
-import polars as pl
 from dotenv import load_dotenv
 from litellm import completion
 from mistralai import Mistral
 from PIL import Image
 
-from schemas.image_description import ImageDescription
 from utils.encode_image import encode_image
-from utils.metrics import compare_results
-from utils.prompt import get_prompt
+from utils.prompt import ImageDescription, get_prompt
 
 load_dotenv()
 
 
-def run_gemma3_ocr_direct(
+def run_ocr_from_ui(
     image: str | Image.Image,
+    model: str = "gemma3",
     custom_prompt: Optional[str] = None,
 ) -> Dict[str, Any]:
-    """Extract text directly using gemma3 model.
-
-    Args:
-        image: Path to image or PIL image
-        custom_prompt: Optional custom prompt
-
-    Returns:
-        Dict with extraction results
-    """
-    start_time = time.time()
-    content_type, image_base64 = encode_image(image)
-
-    client = instructor.from_litellm(completion)
-    model_name = "ollama/gemma3:27b"
-
-    prompt = custom_prompt if custom_prompt else get_prompt()
-
-    try:
-        response = client.chat.completions.create(
-            model=model_name,
-            response_model=ImageDescription,
-            messages=[
-                {
-                    "role": "user",
-                    "content": [
-                        {"type": "text", "text": prompt},
-                        {
-                            "type": "image_url",
-                            "image_url": f"data:{content_type};base64,{image_base64}",
-                        },
-                    ],
-                }
-            ],
-        )
-
-        processing_time = time.time() - start_time
-
-        result = {
-            "raw_text": response.raw_text if response.raw_text else "No text found",
-            "description": response.description if response.description else "No description found",
-            "entities": response.entities if response.entities else [],
-            "processing_time": processing_time,
-            "model": model_name,
-        }
-
-        return result
-    except Exception as e:
-        error_message = f"An unexpected error occurred: {str(e)}"
-        return {
-            "raw_text": "Error: Failed to extract text",
-            "description": "Error: Failed to extract description",
-            "entities": [],
-            "error": error_message,
-            "processing_time": time.time() - start_time,
-            "model": model_name,
-        }
+    """Extract text directly using OCR model.
 
-
-def run_mistral_ocr_direct(
-    image: str | Image.Image,
-    custom_prompt: Optional[str] = None,
-) -> Dict[str, Any]:
-    """Extract text directly using mistral model.
+    This function is designed for use in the streamlit app.
 
     Args:
         image: Path to image or PIL image
         custom_prompt: Optional custom prompt
-
+        model: Name of the model to use
     Returns:
        Dict with extraction results
     """
     start_time = time.time()
     content_type, image_base64 = encode_image(image)
 
-    mistral_client = Mistral(api_key=os.getenv("MISTRAL_API_KEY"))
-    client = instructor.from_mistral(mistral_client)
-
-    model_name = "pixtral-12b-2409"
+    if "gemma" in model.lower():
+        client = instructor.from_ollama(completion)
+    elif "mistral" in model.lower() or "pixtral" in model.lower():
+        mistral_client = Mistral(api_key=os.getenv("MISTRAL_API_KEY"))
+        client = instructor.from_mistral(mistral_client)
+    else:
+        raise ValueError(f"Unsupported model: {model}")
 
     prompt = custom_prompt if custom_prompt else get_prompt()
 
     try:
         response = client.chat.completions.create(
-            model=model_name,
+            model=model,
             response_model=ImageDescription,
             messages=[
                 {
@@ -125,63 +64,34 @@ def run_mistral_ocr_direct(
             ],
         )
 
-        print(f"Response: {response}")
-
         processing_time = time.time() - start_time
 
         result = {
             "raw_text": response.raw_text if response.raw_text else "No text found",
-            "description": response.description if response.description else "No description found",
-            "entities": response.entities if response.entities else [],
             "processing_time": processing_time,
-            "model": model_name,
+            "model": model,
         }
 
         return result
     except Exception as e:
         error_message = f"An unexpected error occurred: {str(e)}"
         return {
             "raw_text": "Error: Failed to extract text",
-            "description": "Error: Failed to extract description",
-            "entities": [],
             "error": error_message,
             "processing_time": time.time() - start_time,
-            "model": model_name,
+            "model": model,
         }
 
 
-def run_ocr(
-    image: str | Image.Image,
-    model: str = "gemma3",
-    custom_prompt: Optional[str] = None,
-) -> Dict[str, Any]:
-    """Run OCR using either Gemma3 or Mistral model.
-
-    Args:
-        image: Path to image or PIL image
-        model: Model to use ('gemma3' or 'mistral')
-        custom_prompt: Optional custom prompt
-
-    Returns:
-        Dict with extraction results
-    """
-    if model.lower() == "gemma3":
-        return run_gemma3_ocr_direct(image=image, custom_prompt=custom_prompt)
-    else:
-        return run_mistral_ocr_direct(image=image, custom_prompt=custom_prompt)
-
-
 def compare_models(
     image_paths: List[str],
     custom_prompt: Optional[str] = None,
-    ground_truth_texts: Optional[List[str]] = None,
 ) -> Dict[str, Any]:
     """Compare Gemma3 and Mistral OCR capabilities on a list of images.
 
     Args:
         image_paths: List of paths to images
         custom_prompt: Optional custom prompt to use for both models
-        ground_truth_texts: Optional list of ground truth texts
     Returns:
         Dictionary with comparison results
     """
@@ -197,14 +107,14 @@ def compare_models(
         print(f"Processing image {i + 1}/{len(image_paths)}: {image_name}")
 
         # Run both models
-        gemma_result = run_ocr(
+        gemma_result = run_ocr_from_ui(
             image=image_path,
-            model="gemma3",
+            model_name="ollama/gemma3:27b",
             custom_prompt=custom_prompt,
         )
-        mistral_result = run_ocr(
+        mistral_result = run_ocr_from_ui(
             image=image_path,
-            model="mistral",
+            model_name="pixtral-12b-2409",
             custom_prompt=custom_prompt,
         )
 
@@ -213,41 +123,19 @@ def compare_models(
             "id": i,
             "image_name": image_name,
             "gemma_text": gemma_result["raw_text"],
-            "gemma_entities": ", ".join(gemma_result.get("entities", [])),
             "gemma_processing_time": gemma_result.get("processing_time", 0),
         }
 
         mistral_entry = {
             "id": i,
             "image_name": image_name,
             "mistral_text": mistral_result["raw_text"],
-            "mistral_entities": ", ".join(mistral_result.get("entities", [])),
             "mistral_processing_time": mistral_result.get("processing_time", 0),
         }
 
         results["gemma_results"].append(gemma_entry)
         results["mistral_results"].append(mistral_entry)
 
-        # Add ground truth if available
-        if ground_truth_texts and i < len(ground_truth_texts):
-            results["ground_truth"].append(
-                {
-                    "id": i,
-                    "image_name": image_name,
-                    "ground_truth_text": ground_truth_texts[i],
-                }
-            )
-
-            # Calculate metrics
-            metrics = compare_results(
-                ground_truth_texts[i],
-                gemma_result["raw_text"],
-                mistral_result["raw_text"],
-            )
-            print(f"Metrics for {image_name}:")
-            for key, value in metrics.items():
-                print(f"  {key}: {value:.4f}")
-
     return results
 
 
@@ -266,22 +154,19 @@ def compare_models(
 
     if args.model.lower() == "both":
         start_time = time.time()
-        gemma_result = run_ocr(args.image, "gemma3", args.prompt)
-        mistral_result = run_ocr(args.image, "mistral", args.prompt)
+        gemma_result = run_ocr_from_ui(args.image, "ollama/gemma3:27b", args.prompt)
+        mistral_result = run_ocr_from_ui(args.image, "pixtral-12b-2409", args.prompt)
         print("\nGemma3 results:")
         print(f"Text: {gemma_result['raw_text']}")
-        print(f"Entities: {gemma_result.get('entities', [])}")
         print(f"Processing time: {gemma_result.get('processing_time', 0):.2f}s")
 
         print("\nMistral results:")
         print(f"Text: {mistral_result['raw_text']}")
-        print(f"Entities: {mistral_result.get('entities', [])}")
         print(f"Processing time: {mistral_result.get('processing_time', 0):.2f}s")
 
         print(f"\nTotal time: {time.time() - start_time:.2f}s")
     else:
-        result = run_ocr(args.image, args.model, args.prompt)
+        result = run_ocr_from_ui(args.image, args.model, args.prompt)
         print(f"\n{args.model} results:")
         print(f"Text: {result['raw_text']}")
-        print(f"Entities: {result.get('entities', [])}")
         print(f"Processing time: {result.get('processing_time', 0):.2f}s")

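The heart of the refactor is replacing two near-duplicate functions with one that routes on the model name and returns a structured ImageDescription. A minimal self-contained sketch of that pattern follows, using instructor.from_litellm for the local-model branch (the adapter the removed gemma path used; the committed code calls instructor.from_ollama) and a stand-in Pydantic schema:

    import os
    from typing import Optional

    import instructor
    from litellm import completion
    from mistralai import Mistral
    from pydantic import BaseModel


    class ImageDescription(BaseModel):
        # Stand-in for the schema imported from utils.prompt in the repo.
        raw_text: Optional[str] = None
        description: Optional[str] = None


    def make_client(model: str):
        # Route on the model name, mirroring the commit's dispatch logic.
        if "gemma" in model.lower():
            return instructor.from_litellm(completion)
        if "mistral" in model.lower() or "pixtral" in model.lower():
            return instructor.from_mistral(Mistral(api_key=os.getenv("MISTRAL_API_KEY")))
        raise ValueError(f"Unsupported model: {model}")


    client = make_client("ollama/gemma3:27b")
    response = client.chat.completions.create(
        model="ollama/gemma3:27b",  # assumes Ollama is running with gemma3 pulled
        response_model=ImageDescription,
        messages=[{"role": "user", "content": "Extract the text: INVOICE #123"}],
    )
    print(response.raw_text)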