import os
import time
from typing import Any, Dict, List, Optional

import instructor
from dotenv import load_dotenv
from litellm import completion
from mistralai import Mistral
from PIL import Image

from utils.encode_image import encode_image
from utils.prompt import ImageDescription, get_prompt

load_dotenv()

2419
def run_ocr_from_ui(
    image: str | Image.Image,
    model: str = "gemma3",
    custom_prompt: Optional[str] = None,
) -> Dict[str, Any]:
    """Extract text from an image using the selected OCR model.

    This function is designed for use in the streamlit app.

    Args:
        image: Path to image or PIL image.
        model: Name of the model to use. Gemma models are served through
            litellm/Ollama (e.g. "ollama/gemma3:27b"); Mistral/Pixtral models
            go through the Mistral API (e.g. "pixtral-12b-2409").
        custom_prompt: Optional custom prompt; falls back to ``get_prompt()``.

    Returns:
        Dict with keys "raw_text", "processing_time" and "model"; on failure
        the dict additionally carries an "error" message.

    Raises:
        ValueError: If the model name matches neither a Gemma nor a
            Mistral/Pixtral model.
    """
    start_time = time.time()
    content_type, image_base64 = encode_image(image)

    # Route to the right structured-output client for the requested model.
    if "gemma" in model.lower():
        # BUG FIX: instructor exposes no `from_ollama` factory, and
        # `completion` is litellm's entry point — Ollama-hosted Gemma is
        # reached via litellm, exactly as the previous
        # run_gemma3_ocr_direct implementation did.
        client = instructor.from_litellm(completion)
        # NOTE(review): the default model "gemma3" is not a valid litellm
        # model id — callers should pass e.g. "ollama/gemma3:27b"; confirm.
    elif "mistral" in model.lower() or "pixtral" in model.lower():
        mistral_client = Mistral(api_key=os.getenv("MISTRAL_API_KEY"))
        client = instructor.from_mistral(mistral_client)
    else:
        raise ValueError(f"Unsupported model: {model}")

    prompt = custom_prompt if custom_prompt else get_prompt()

    try:
        response = client.chat.completions.create(
            model=model,
            response_model=ImageDescription,
            messages=[
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": prompt},
                        {
                            "type": "image_url",
                            "image_url": f"data:{content_type};base64,{image_base64}",
                        },
                    ],
                }
            ],
        )

        processing_time = time.time() - start_time

        result = {
            "raw_text": response.raw_text if response.raw_text else "No text found",
            "processing_time": processing_time,
            "model": model,
        }

        return result
    except Exception as e:
        # Best-effort: surface the failure in the result dict so the UI can
        # render it instead of crashing the app.
        error_message = f"An unexpected error occurred: {str(e)}"
        return {
            "raw_text": "Error: Failed to extract text",
            "error": error_message,
            "processing_time": time.time() - start_time,
            "model": model,
        }

def compare_models(
    image_paths: List[str],
    custom_prompt: Optional[str] = None,
) -> Dict[str, Any]:
    """Compare Gemma3 and Mistral OCR capabilities on a list of images.

    Args:
        image_paths: List of paths to images.
        custom_prompt: Optional custom prompt to use for both models.

    Returns:
        Dictionary with "gemma_results" and "mistral_results" lists, one
        entry per image.
    """
    # NOTE(review): the results init / loop header were elided in the diff
    # view this was reviewed from; reconstructed from surrounding context.
    results: Dict[str, Any] = {
        "gemma_results": [],
        "mistral_results": [],
    }

    for i, image_path in enumerate(image_paths):
        image_name = os.path.basename(image_path)
        print(f"Processing image {i + 1}/{len(image_paths)}: {image_name}")

        # Run both models.
        # BUG FIX: run_ocr_from_ui takes a `model` keyword, not `model_name`;
        # the previous calls raised TypeError.
        gemma_result = run_ocr_from_ui(
            image=image_path,
            model="ollama/gemma3:27b",
            custom_prompt=custom_prompt,
        )
        mistral_result = run_ocr_from_ui(
            image=image_path,
            model="pixtral-12b-2409",
            custom_prompt=custom_prompt,
        )

        # Store per-model results keyed by image index.
        gemma_entry = {
            "id": i,
            "image_name": image_name,
            "gemma_text": gemma_result["raw_text"],
            "gemma_processing_time": gemma_result.get("processing_time", 0),
        }

        mistral_entry = {
            "id": i,
            "image_name": image_name,
            "mistral_text": mistral_result["raw_text"],
            "mistral_processing_time": mistral_result.get("processing_time", 0),
        }

        results["gemma_results"].append(gemma_entry)
        results["mistral_results"].append(mistral_entry)

    return results
252140
253141
@@ -266,22 +154,19 @@ def compare_models(
266154
267155 if args .model .lower () == "both" :
268156 start_time = time .time ()
269- gemma_result = run_ocr (args .image , "gemma3" , args .prompt )
270- mistral_result = run_ocr (args .image , "mistral " , args .prompt )
157+ gemma_result = run_ocr_from_ui (args .image , "ollama/ gemma3:27b " , args .prompt )
158+ mistral_result = run_ocr_from_ui (args .image , "pixtral-12b-2409 " , args .prompt )
271159 print ("\n Gemma3 results:" )
272160 print (f"Text: { gemma_result ['raw_text' ]} " )
273- print (f"Entities: { gemma_result .get ('entities' , [])} " )
274161 print (f"Processing time: { gemma_result .get ('processing_time' , 0 ):.2f} s" )
275162
276163 print ("\n Mistral results:" )
277164 print (f"Text: { mistral_result ['raw_text' ]} " )
278- print (f"Entities: { mistral_result .get ('entities' , [])} " )
279165 print (f"Processing time: { mistral_result .get ('processing_time' , 0 ):.2f} s" )
280166
281167 print (f"\n Total time: { time .time () - start_time :.2f} s" )
282168 else :
283- result = run_ocr (args .image , args .model , args .prompt )
169+ result = run_ocr_from_ui (args .image , args .model , args .prompt )
284170 print (f"\n { args .model } results:" )
285171 print (f"Text: { result ['raw_text' ]} " )
286- print (f"Entities: { result .get ('entities' , [])} " )
287172 print (f"Processing time: { result .get ('processing_time' , 0 ):.2f} s" )
0 commit comments