Skip to content

Commit 70c6f45

Browse files
author
marwan37
committed
update ocr pipeline
1 parent cb22d7d commit 70c6f45

File tree

1 file changed

+34
-50
lines changed

1 file changed

+34
-50
lines changed

omni-reader/pipelines/ocr_pipeline.py

Lines changed: 34 additions & 50 deletions
Original file line numberDiff line numberDiff line change
def ocr_comparison_pipeline(
    image_paths: Optional[List[str]] = None,
    image_folder: Optional[str] = None,
    custom_prompt: Optional[str] = None,
    model1: str = "ollama/gemma3:27b",
    model2: str = "pixtral-12b-2409",
    ground_truth_model: str = "gpt-4o-mini",
    ground_truth_source: Literal["openai", "manual", "file", "none"] = "none",
    ground_truth_file: Optional[str] = None,
    save_ground_truth_data: bool = False,
    ground_truth_output_dir: str = "ocr_results",
    save_ocr_results_data: bool = False,
    ocr_results_output_dir: str = "ocr_results",
    save_visualization_data: bool = False,
    visualization_output_dir: str = "visualizations",
) -> None:
    """Run OCR comparison pipeline between two configurable models.

    Args:
        image_paths: Optional list of specific image paths to process
        image_folder: Optional folder to search for images
        custom_prompt: Optional custom prompt to use for both models
        model1: Name of the first model to use (default: ollama/gemma3:27b)
        model2: Name of the second model to use (default: pixtral-12b-2409)
        ground_truth_model: Name of the model to use for ground truth when source is "openai"
        ground_truth_source: Source of ground truth - "openai" to use the configured
            ground-truth model, "file" to load from a saved JSON file, or "none" to skip
            ground truth evaluation. NOTE: "manual" is still accepted by the type for
            backward compatibility, but its handler (user-supplied texts) was removed;
            passing "manual" currently behaves the same as "none".
        ground_truth_file: Path to ground truth JSON file (used when ground_truth_source="file")
        save_ground_truth_data: Whether to save generated ground truth data for future use
        ground_truth_output_dir: Directory to save ground truth data
        save_ocr_results_data: Whether to save OCR results from both models
        ocr_results_output_dir: Directory to save OCR results
        save_visualization_data: Whether to save HTML visualization to local file
        visualization_output_dir: Directory to save HTML visualization

    Returns:
        None
    """
    images = load_images(image_paths=image_paths, image_folder=image_folder)
    # Track every model that actually ran, so saved artifacts can be labeled.
    model_names = []

    model1_results = run_ocr(images=images, model_name=model1, custom_prompt=custom_prompt)
    model_names.append(model1)

    model2_results = run_ocr(images=images, model_name=model2, custom_prompt=custom_prompt)
    model_names.append(model2)

    # Handle ground truth based on the selected source.
    ground_truth = None
    openai_results = None

    if ground_truth_source == "openai":
        # Use a third OCR pass with the dedicated ground-truth model as reference.
        openai_results = run_ocr(
            images=images, model_name=ground_truth_model, custom_prompt=custom_prompt
        )
        ground_truth = openai_results
        model_names.append(ground_truth_model)
    elif ground_truth_source == "file" and ground_truth_file:
        ground_truth = load_ground_truth_file(filepath=ground_truth_file)
    # "manual" and "none" (and "file" without a path) fall through: ground_truth stays None.

    # Evaluate models; ground_truth_df=None means evaluation runs without a reference.
    visualization = evaluate_models(
        model1_df=model1_results,
        model2_df=model2_results,
        ground_truth_df=ground_truth,
        model1_name=model1,
        model2_name=model2,
    )

    # Save OCR results if requested (either flag triggers the combined save step).
    if save_ocr_results_data or save_ground_truth_data:
        save_ocr_results(
            model1_results=model1_results,
            model2_results=model2_results,
            ground_truth_results=ground_truth,
            model_names=model_names,
            output_dir=ocr_results_output_dir,
            ground_truth_output_dir=ground_truth_output_dir,
        )

    # Save HTML visualization if requested.
    if save_visualization_data:
        save_visualization(visualization, output_dir=visualization_output_dir)
141122

142123

143124
def run_ocr_pipeline(config: Dict[str, Any]) -> None:
@@ -152,15 +133,18 @@ def run_ocr_pipeline(config: Dict[str, Any]) -> None:
152133
ocr_comparison_pipeline(
153134
image_paths=config["input"].get("image_paths"),
154135
image_folder=config["input"].get("image_folder"),
155-
image_patterns=config["input"].get("image_patterns"),
156136
custom_prompt=config["models"].get("custom_prompt"),
157-
ground_truth_texts=config["ground_truth"].get("texts"),
137+
model1=config["models"].get("model1", "ollama/gemma3:27b"),
138+
model2=config["models"].get("model2", "pixtral-12b-2409"),
139+
ground_truth_model=config["models"].get("ground_truth_model", "gpt-4o-mini"),
158140
ground_truth_source=config["ground_truth"].get("source", "none"),
159141
ground_truth_file=config["ground_truth"].get("file"),
160142
save_ground_truth_data=config["output"]["ground_truth"].get("save", False),
161143
ground_truth_output_dir=config["output"]["ground_truth"].get("directory", "ocr_results"),
162144
save_ocr_results_data=config["output"]["ocr_results"].get("save", False),
163145
ocr_results_output_dir=config["output"]["ocr_results"].get("directory", "ocr_results"),
164146
save_visualization_data=config["output"]["visualization"].get("save", False),
165-
visualization_output_dir=config["output"]["visualization"].get("directory", "visualizations"),
147+
visualization_output_dir=config["output"]["visualization"].get(
148+
"directory", "visualizations"
149+
),
166150
)

0 commit comments

Comments
 (0)