Skip to content

Commit d6cc858

Browse files
author
marwan37
committed
add loaders.py
1 parent 6ecc9b5 commit d6cc858

File tree

1 file changed

+246
-0
lines changed

1 file changed

+246
-0
lines changed

omni-reader/steps/loaders.py

Lines changed: 246 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,246 @@
1+
# Apache Software License 2.0
2+
#
3+
# Copyright (c) ZenML GmbH 2025. All rights reserved.
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
"""This module contains the ground truth and OCR results loader steps."""
17+
18+
import glob
19+
import json
20+
import os
21+
from typing import Dict, List, Optional
22+
23+
import polars as pl
24+
from typing_extensions import Annotated
25+
from zenml import log_metadata, step
26+
from zenml.logger import get_logger
27+
28+
from utils.model_configs import get_model_prefix
29+
30+
logger = get_logger(__name__)
31+
32+
33+
@step()
def load_images(
    image_paths: Optional[List[str]] = None,
    image_folder: Optional[str] = None,
) -> List[str]:
    """Load images for OCR processing.

    This step loads images from specified paths or by searching for
    common image-extension patterns in a given folder, then keeps only
    the paths that actually exist on disk.

    Args:
        image_paths: Optional list of specific image paths to load.
        image_folder: Optional folder to search for images.

    Returns:
        List of validated image file paths.
    """
    from collections import Counter

    all_images = []

    if image_paths:
        all_images.extend(image_paths)
        logger.info(f"Added {len(image_paths)} directly specified images")

    if image_folder:
        # Patterns are lowercase; files with uppercase extensions (e.g. *.JPG)
        # will not match on case-sensitive filesystems.
        patterns_to_use = ["*.jpg", "*.jpeg", "*.png", "*.webp", "*.tiff"]

        for pattern in patterns_to_use:
            matching_files = glob.glob(os.path.join(image_folder, pattern))
            if matching_files:
                all_images.extend(matching_files)
                logger.info(f"Found {len(matching_files)} images matching pattern {pattern}")

    # Validate image paths: keep only existing files, warn about the rest.
    valid_images = []
    for path in all_images:
        if os.path.isfile(path):
            valid_images.append(path)
        else:
            logger.warning(f"Image not found: {path}")

    # Log metadata about the loaded images. Counter replaces the previous
    # hand-rolled dict tally of file extensions.
    image_names = [os.path.basename(path) for path in valid_images]
    extension_counts = dict(
        Counter(os.path.splitext(path)[1].lower() for path in valid_images)
    )

    log_metadata(
        metadata={
            "loaded_images": {
                "total_count": len(valid_images),
                "extensions": extension_counts,
                "image_names": image_names,
            }
        }
    )

    logger.info(f"Successfully loaded {len(valid_images)} valid images")

    return valid_images
98+
99+
100+
@step(enable_cache=False)
def load_ground_truth_file(
    filepath: str,
) -> Annotated[pl.DataFrame, "ground_truth"]:
    """Load ground truth data from a JSON file.

    Args:
        filepath: Path to the ground truth file.

    Returns:
        pl.DataFrame containing ground truth results.

    Raises:
        FileNotFoundError: If the given file does not exist.
    """
    from utils.io_utils import load_ocr_data_from_json

    if not os.path.exists(filepath):
        raise FileNotFoundError(f"Ground truth file not found: {filepath}")

    ground_truth_df = load_ocr_data_from_json(filepath)

    log_metadata(
        metadata={
            "ground_truth_loaded": {
                "path": filepath,
                "count": len(ground_truth_df),
            }
        }
    )

    return ground_truth_df
122+
123+
124+
@step(enable_cache=False)
def load_ground_truth_texts(
    model_results: Dict[str, pl.DataFrame],
    ground_truth_folder: Optional[str] = None,
    ground_truth_files: Optional[List[str]] = None,
) -> Annotated[pl.DataFrame, "ground_truth"]:
    """Load ground truth texts using image names found in model results.

    Args:
        model_results: Per-model OCR result DataFrames; any one of them
            supplies the image names to match ground-truth files against.
        ground_truth_folder: Optional folder of ``.txt`` ground-truth files.
        ground_truth_files: Optional explicit list of ground-truth files.

    Returns:
        pl.DataFrame with one row per image that has a readable ground truth.

    Raises:
        ValueError: If neither source argument is given, or no ground-truth
            file could be loaded at all.
    """
    if not ground_truth_folder and not ground_truth_files:
        raise ValueError("Either ground_truth_folder or ground_truth_files must be provided")

    # Any model's results carry the image names we need.
    sample_model_df = next(iter(model_results.values()))
    image_names = sample_model_df.select("image_name").to_series().to_list()

    # Map base filename (no extension) -> ground-truth text file path.
    # NOTE: when both sources are supplied, the folder takes precedence.
    if ground_truth_folder:
        gt_path_map = {
            os.path.splitext(entry)[0]: os.path.join(ground_truth_folder, entry)
            for entry in os.listdir(ground_truth_folder)
            if entry.endswith(".txt")
        }
    else:
        gt_path_map = {
            os.path.splitext(os.path.basename(path))[0]: path
            for path in ground_truth_files
        }

    rows = []
    missing = []

    for idx, img_name in enumerate(image_names):
        gt_path = gt_path_map.get(os.path.splitext(img_name)[0])

        if not gt_path or not os.path.exists(gt_path):
            missing.append(img_name)
            continue

        try:
            with open(gt_path, "r", encoding="utf-8") as handle:
                text = handle.read().strip()
        except Exception as exc:
            logger.warning(f"Failed to read ground truth for {img_name}: {exc}")
            continue

        rows.append(
            {
                "id": idx,
                "image_name": img_name,
                "raw_text": text,
                "processing_time": 0,
                "confidence": 1.0,
            }
        )

    if missing:
        logger.warning(
            f"Missing ground truth files for {len(missing)} images: {missing[:5]}{'...' if len(missing) > 5 else ''}"
        )

    if not rows:
        raise ValueError("No ground truth files could be loaded.")

    return pl.DataFrame(rows)
185+
186+
187+
def _read_results_dataframe(file_path: str) -> pl.DataFrame:
    """Read one saved OCR results JSON file into a polars DataFrame.

    Saved files either wrap the records under an ``"ocr_data"`` key or are
    the records themselves; this helper handles both layouts. Previously
    this logic was duplicated verbatim in both branches of the step below.
    """
    with open(file_path, "r") as f:
        data = json.load(f)
    if "ocr_data" in data:
        return pl.DataFrame(data["ocr_data"])
    return pl.DataFrame(data)


@step(enable_cache=False)
def load_ocr_results(
    model_names: List[str],
    results_dir: str = "ocr_results",
    result_files: Optional[List[str]] = None,
) -> Dict[str, pl.DataFrame]:
    """Load OCR results from previously saved JSON files.

    Args:
        model_names: Models whose results should be loaded.
        results_dir: Root directory holding one sub-folder per model prefix
            (only used when ``result_files`` is not given).
        result_files: Optional explicit list of result JSON files; each is
            matched to a model by an exact ``<prefix>_`` filename prefix.

    Returns:
        Mapping of model name to its results DataFrame.

    Raises:
        ValueError: If no results could be loaded for any model.
    """
    results: Dict[str, pl.DataFrame] = {}
    model_to_prefix = {model: get_model_prefix(model) for model in model_names}

    if result_files:
        for file_path in result_files:
            if not os.path.exists(file_path):
                logger.warning(f"Result file not found: {file_path}")
                continue

            file_name = os.path.basename(file_path)
            for model, prefix in model_to_prefix.items():
                # Check for exact prefix match at start of filename
                if file_name.startswith(f"{prefix}_"):
                    try:
                        results[model] = _read_results_dataframe(file_path)
                        break
                    except Exception as e:
                        logger.error(f"Error loading {model} results: {str(e)}")
    else:
        for model, prefix in model_to_prefix.items():
            model_dir = os.path.join(results_dir, prefix)
            if not os.path.exists(model_dir):
                logger.warning(f"No results directory found for model: {model}")
                continue

            # Find files matching the exact prefix pattern
            json_files = glob.glob(os.path.join(model_dir, f"{prefix}_*.json"))
            if not json_files:
                logger.warning(f"No result files found for model: {model}")
                continue

            # max() finds the most recently modified file in O(n) instead of
            # sorting the whole list; ties resolve to the first maximum,
            # matching the previous sorted(...)[0] behavior.
            latest_file = max(json_files, key=os.path.getmtime)
            logger.info(f"Loading results for {model} from {latest_file}")

            try:
                results[model] = _read_results_dataframe(latest_file)
            except Exception as e:
                logger.error(f"Error loading results for {model}: {str(e)}")

    if not results:
        raise ValueError("No model results could be loaded. Run the batch pipeline first.")

    return results

0 commit comments

Comments
 (0)