Skip to content

Commit b65894b

Browse files
author
marwan37
committed
refactor: revert to use litellm+instructor for ollama models
1 parent 16f3b77 commit b65894b

File tree

4 files changed

+33
-118
lines changed

4 files changed

+33
-118
lines changed

omni-reader/app.py

Lines changed: 5 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -176,13 +176,8 @@ def has_no_text(result):
176176
col1, col2 = st.columns(2)
177177

178178
start = time.time()
179-
# gemma_result = run_ocr_from_ui(
180-
# image=image, model="ollama/gemma3:27b", custom_prompt=prompt_param
181-
# )
182-
gemma_result = run_ollama_ocr_from_ui(
183-
image,
184-
model="gemma3:27b",
185-
custom_prompt=prompt_param,
179+
gemma_result = run_ocr_from_ui(
180+
image=image, model="ollama/gemma3:27b", custom_prompt=prompt_param
186181
)
187182
gemma_time = time.time() - start
188183

@@ -266,14 +261,9 @@ def has_no_text(result):
266261
with st.spinner(f"Processing image with {model_choice}..."):
267262
try:
268263
start = time.time()
269-
if "gemma" in model_param.lower():
270-
response = run_ollama_ocr_from_ui(
271-
image, model="gemma3:27b", custom_prompt=prompt_param
272-
)
273-
else:
274-
response = run_ocr_from_ui(
275-
image=image, model=model_param, custom_prompt=prompt_param
276-
)
264+
response = run_ocr_from_ui(
265+
image=image, model=model_param, custom_prompt=prompt_param
266+
)
277267
proc_time = time.time() - start
278268
st.session_state["ocr_result"] = response
279269

omni-reader/configs/ocr_config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
# Image input configuration
44
input:
55
image_paths: [] # List of specific image paths to process
6-
image_folder: "assets/handwriting" # Folder containing images to process
6+
image_folder: "assets/street_signs" # Folder containing images to process
77

88
models:
99
custom_prompt: null # Optional custom prompt to use for both models

omni-reader/run_compare_ocr.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ def run_ocr_from_ui(
8585

8686
def run_ollama_ocr_from_ui(
8787
image: str | Image.Image,
88-
model: str = "gemma3:27b",
88+
model: str = "ollama/gemma3:27b",
8989
custom_prompt: Optional[str] = None,
9090
) -> Dict[str, Any]:
9191
"""Run OCR using Ollama.

omni-reader/utils/ocr_model_utils.py

Lines changed: 26 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
import time
2424
from typing import Dict, List, Optional
2525

26-
import ollama
2726
import polars as pl
2827
from dotenv import load_dotenv
2928
from zenml import log_metadata
@@ -157,100 +156,6 @@ def log_summary_metadata(
157156
)
158157

159158

160-
def process_with_ollama(
161-
model_name: str,
162-
image_path: str,
163-
prompt: str,
164-
model_config: ModelConfig,
165-
) -> Dict:
166-
"""Process an image with Ollama.
167-
168-
Args:
169-
model_name: Name of the Ollama model
170-
image_path: Path to the image file
171-
prompt: Prompt text
172-
model_config: Model configuration
173-
174-
Returns:
175-
Dict with OCR results
176-
"""
177-
_, image_base64 = encode_image(image_path)
178-
179-
ollama_params = {
180-
"model": model_name,
181-
"messages": [
182-
{
183-
"role": "user",
184-
"content": prompt,
185-
"images": [image_base64],
186-
"format": ImageDescription.model_json_schema(),
187-
}
188-
],
189-
}
190-
191-
if model_config.additional_params:
192-
ollama_params.update(model_config.additional_params)
193-
194-
try:
195-
response = ollama.chat(**ollama_params)
196-
result = try_extract_json_from_response(response.message.content)
197-
return result
198-
except Exception as e:
199-
error_msg = f"Error with Ollama OCR: {str(e)}"
200-
logger.error(error_msg)
201-
return {
202-
"raw_text": f"Error: {error_msg}",
203-
"confidence": 0.0,
204-
}
205-
206-
207-
def process_with_client(
208-
client,
209-
model_name: str,
210-
image_path: str,
211-
prompt: str,
212-
model_config: ModelConfig,
213-
) -> ImageDescription | Dict:
214-
"""Process images with an API client (OpenAI, Mistral, etc.).
215-
216-
Args:
217-
client: API client
218-
model_name: Name of the model
219-
image_path: Path to the image file
220-
prompt: Prompt text
221-
model_config: Model configuration
222-
223-
Returns:
224-
API response processed into ImageDescription or Dict
225-
"""
226-
content_type, image_base64 = encode_image(image_path)
227-
228-
params = {
229-
"model": model_name,
230-
"response_model": ImageDescription,
231-
**({"max_tokens": model_config.max_tokens} if model_config.max_tokens else {}),
232-
**(model_config.additional_params or {}),
233-
}
234-
235-
return client.chat.completions.create(
236-
**params,
237-
messages=[
238-
{
239-
"role": "user",
240-
"content": [
241-
{"type": "text", "text": prompt},
242-
{
243-
"type": "image_url",
244-
"image_url": {
245-
"url": f"data:{content_type};base64,{image_base64}",
246-
},
247-
},
248-
],
249-
}
250-
],
251-
)
252-
253-
254159
def process_images_with_model(
255160
model_config: ModelConfig,
256161
images: List[str],
@@ -277,18 +182,38 @@ def process_images_with_model(
277182
processing_times = []
278183
confidence_scores = []
279184

280-
if "ollama" not in model_name:
281-
client = model_config.client_factory()
185+
client = model_config.client_factory()
282186

283187
for i, image_path in enumerate(images):
284188
start_time = time.time()
285189
image_name = os.path.basename(image_path)
286190

191+
content_type, image_base64 = encode_image(image_path)
192+
params = {
193+
"model": model_name,
194+
"response_model": ImageDescription,
195+
**({"max_tokens": model_config.max_tokens} if model_config.max_tokens else {}),
196+
**(model_config.additional_params or {}),
197+
}
198+
287199
try:
288-
if "gemma" in model_name:
289-
response = process_with_ollama(model_name, image_path, prompt, model_config)
290-
else:
291-
response = process_with_client(client, model_name, image_path, prompt, model_config)
200+
response = client.chat.completions.create(
201+
**params,
202+
messages=[
203+
{
204+
"role": "user",
205+
"content": [
206+
{"type": "text", "text": prompt},
207+
{
208+
"type": "image_url",
209+
"image_url": {
210+
"url": f"data:{content_type};base64,{image_base64}",
211+
},
212+
},
213+
],
214+
}
215+
],
216+
)
292217

293218
processing_time = time.time() - start_time
294219
processing_times.append(processing_time)

0 commit comments

Comments
 (0)