
Commit 6a0295f

marwan37 committed: add model_configs util for centralized model configuration and client fetching
1 parent 58ac407 commit 6a0295f

File tree

1 file changed: +248 -0 lines changed

omni-reader/utils/model_configs.py

Lines changed: 248 additions & 0 deletions

@@ -0,0 +1,248 @@
# Apache Software License 2.0
#
# Copyright (c) ZenML GmbH 2025. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Model configuration utilities for OCR operations."""

import os
from dataclasses import dataclass, field
from typing import Any, Dict, Optional, Tuple

import instructor
import requests
from mistralai import Mistral
from openai import OpenAI
from zenml.logger import get_logger

logger = get_logger(__name__)


@dataclass
class ModelConfig:
    """Configuration for OCR models."""

    name: str
    display: str
    provider: str
    prefix: str
    logo: Optional[str] = None
    base_url: Optional[str] = None
    additional_params: Dict[str, Any] = field(default_factory=dict)
    default_confidence: float = 0.5

    def get_client(self):
        """Get the appropriate client for this model configuration."""
        if self.provider == "openai":
            return get_openai_client()
        elif self.provider == "mistral":
            return get_mistral_client()
        else:
            raise ValueError(f"Unsupported provider: {self.provider}")

    def process_image(self, prompt, image_base64, content_type="image/jpeg"):
        """Process an image with this model."""
        if self.provider == "ollama":
            return self._process_ollama(prompt, image_base64)
        else:
            return self._process_api_based(prompt, image_base64, content_type)
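
    # Dispatch note: Ollama models are called over a local HTTP endpoint and
    # never touch get_client(); OpenAI and Mistral requests go through the
    # instructor-patched clients below, which parse responses into a schema.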
    def _process_ollama(self, prompt, image_base64):
        """Process an image with an Ollama model."""
        from utils.ocr_processing import try_extract_json_from_response

        base_url = self.base_url or DOCKER_BASE_URL

        payload = {
            "model": self.name,
            "prompt": prompt,
            "stream": False,
            "images": [image_base64],
        }

        try:
            response = requests.post(
                base_url,
                json=payload,
                timeout=120,  # Increase timeout for larger images
            )
            response.raise_for_status()
            res = response.json().get("response", "")
            result_json = try_extract_json_from_response(res)

            return result_json
        except Exception as e:
            logger.error(f"Error processing with Ollama model {self.name}: {str(e)}")
            return {"raw_text": f"Error: {str(e)}", "confidence": 0.0}

    def _process_api_based(self, prompt, image_base64, content_type):
        """Process an image with an API-based model (OpenAI, Mistral)."""
        from utils.ocr_processing import try_extract_json_from_response
        from utils.prompt import ImageDescription

        client = self.get_client()

        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": prompt},
                    {
                        "type": "image_url",
                        "image_url": {"url": f"data:{content_type};base64,{image_base64}"},
                    },
                ],
            }
        ]

        try:
            response = client.chat.completions.create(
                model=self.name,
                messages=messages,
                response_model=ImageDescription,
                **self.additional_params,
            )

            result_json = try_extract_json_from_response(response)
            return result_json
        except Exception as e:
            logger.error(f"Error processing with {self.provider} model {self.name}: {str(e)}")
            return {"raw_text": f"Error: {str(e)}", "confidence": 0.0}
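

# Usage sketch (hypothetical caller; assumes an image already encoded as
# plain base64 and an OCR prompt string):
#
#     config = MODEL_CONFIGS["gpt-4o-mini"]
#     result = config.process_image(prompt, image_base64)
#     # -> parsed JSON result, or {"raw_text": "Error: ...", "confidence": 0.0}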


# --------- Ollama models ---------
DOCKER_BASE_URL = "http://host.docker.internal:11434/api/generate"
BASE_URL = "http://localhost:11434/api/generate"


def get_openai_client():
    """Get an OpenAI client with instructor integration."""
    client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
    return instructor.from_openai(client)


def get_mistral_client():
    """Get a Mistral client with instructor integration."""
    client = Mistral(api_key=os.getenv("MISTRAL_API_KEY"))
    return instructor.from_mistral(client)
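
# Note: both helpers read their API keys from the environment, so
# OPENAI_API_KEY / MISTRAL_API_KEY must be set before a client is requested.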


def get_model_info(model_name: str) -> Tuple[str, str]:
    """Returns a tuple (display, prefix) for a given model name.

    Args:
        model_name: The name of the model

    Returns:
        A tuple (display, prefix)
    """
    if model_name in MODEL_CONFIGS:
        config = MODEL_CONFIGS[model_name]
        return config.display, config.prefix

    # Fallback: Generate display name and prefix from model name
    if "/" in model_name:
        model_part = model_name.split("/")[-1]

        if ":" in model_part:
            display = model_part.split(":")[0]
        else:
            display = model_part

        display = display.replace("-", " ").title()
    else:
        display = model_name.split("-")[0]
        if ":" in display:
            display = display.split(":")[0]
        display = display.title()

    prefix = display.lower().replace(" ", "_").replace("-", "_")

    return display, prefix
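
# Fallback example (hypothetical name not registered in MODEL_CONFIGS):
# get_model_info("meta/llama3.2-vision:11b") drops the namespace and tag,
# yielding ("Llama3.2 Vision", "llama3.2_vision").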


# --------- models ---------
MODEL_CONFIGS = {
    "pixtral-12b-2409": ModelConfig(
        name="pixtral-12b-2409",
        display="Mistral Pixtral 12B",
        provider="mistral",
        prefix="pixtral_12b_2409",
        logo="mistral.svg",
    ),
    "gpt-4o-mini": ModelConfig(
        name="gpt-4o-mini",
        display="GPT-4o-mini",
        provider="openai",
        prefix="openai_gpt_4o_mini",
        logo="openai.svg",
    ),
    "gemma3:12b": ModelConfig(
        name="gemma3:12b",
        display="Gemma 3 12B",
        provider="ollama",
        prefix="gemma3_12b",
        logo="gemma.svg",
        base_url=BASE_URL,
    ),
    "llama3.2-vision:11b": ModelConfig(
        name="llama3.2-vision:11b",
        display="Llama 3.2 Vision 11B",
        provider="ollama",
        prefix="llama3_2_vision_11b",
        logo="ollama.svg",
        base_url=BASE_URL,
    ),
    "granite3.2-vision": ModelConfig(
        name="granite3.2-vision",
        display="Granite 3.2 Vision",
        provider="ollama",
        prefix="granite3_2_vision",
        logo="granite.svg",
        base_url=BASE_URL,
    ),
    "llava:7b": ModelConfig(
        name="llava:7b",
        display="Llava 7B",
        provider="ollama",
        prefix="llava_7b",
        logo="llava.svg",
        base_url=BASE_URL,
    ),
    "moondream": ModelConfig(
        name="moondream",
        display="Moondream",
        provider="ollama",
        prefix="moondream_v",
        logo="moondream.svg",
        base_url=BASE_URL,
    ),
    "minicpm-v": ModelConfig(
        name="minicpm-v",
        display="MiniCPM-V",
        provider="ollama",
        prefix="minicpm_v",
        logo="ollama.svg",
        base_url=BASE_URL,
    ),
    "qwen2:latest": ModelConfig(
        name="qwen2:latest",
        display="Qwen2",
        provider="ollama",
        prefix="qwen2_latest",
        logo="qwen.svg",
        base_url=BASE_URL,
    ),
}

DEFAULT_MODEL = MODEL_CONFIGS["llama3.2-vision:11b"]
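
A minimal sketch of how the new util could be consumed downstream, assuming the import path utils.model_configs; run_ocr, the prompt string, and the file handling are hypothetical illustrations, not part of this commit:

import base64

from utils.model_configs import DEFAULT_MODEL, MODEL_CONFIGS, get_model_info


def run_ocr(image_path, model_name=None):
    """Run OCR on one image with a registered model (or the default)."""
    config = MODEL_CONFIGS.get(model_name) or DEFAULT_MODEL
    display, prefix = get_model_info(config.name)

    # Encode the image the way process_image expects: plain base64, no data-URI prefix.
    with open(image_path, "rb") as f:
        image_base64 = base64.b64encode(f.read()).decode("utf-8")

    result = config.process_image(
        prompt="Extract all visible text from this image.",
        image_base64=image_base64,
    )
    # Failures come back as {"raw_text": "Error: ...", "confidence": 0.0}.
    return {"model": display, "prefix": prefix, "result": result}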
