Skip to content

Commit 92e3fa0

Browse files
author
marwan37
committed
update utils
1 parent d6cc858 commit 92e3fa0

File tree

6 files changed

+673
-447
lines changed

6 files changed

+673
-447
lines changed

omni-reader/utils/__init__.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,12 +30,14 @@
3030
from .io_utils import (
3131
save_ocr_data_to_json,
3232
load_ocr_data_from_json,
33-
load_ground_truth_from_json,
3433
list_available_ground_truth_files,
3534
)
3635
from .model_configs import (
3736
MODEL_CONFIGS,
38-
ModelConfig,
37+
DEFAULT_MODEL,
3938
get_model_info,
39+
model_registry,
40+
ModelConfig,
41+
get_model_prefix,
4042
)
4143
from .extract_json import try_extract_json_from_response

omni-reader/utils/config.py

Lines changed: 133 additions & 167 deletions
Original file line numberDiff line numberDiff line change
@@ -15,25 +15,15 @@
1515
# limitations under the License.
1616
"""Utilities for handling configuration."""
1717

18+
import glob
1819
import os
19-
from typing import Any, Dict
20+
from typing import Any, Dict, List, Optional
2021

2122
import yaml
2223

2324

2425
def load_config(config_path: str) -> Dict[str, Any]:
25-
"""Load configuration from YAML file.
26-
27-
Args:
28-
config_path: Path to YAML configuration file
29-
30-
Returns:
31-
Dictionary containing configuration
32-
33-
Raises:
34-
FileNotFoundError: If the configuration file does not exist
35-
ValueError: If the configuration file is not valid YAML
36-
"""
26+
"""Load configuration from YAML file."""
3727
if not os.path.isfile(config_path):
3828
raise FileNotFoundError(f"Configuration file does not exist: {config_path}")
3929

@@ -46,185 +36,161 @@ def load_config(config_path: str) -> Dict[str, Any]:
4636

4737

4838
def validate_config(config: Dict[str, Any]) -> None:
    """Validate a ZenML pipeline configuration dictionary.

    Args:
        config: Parsed configuration. Must contain a ``parameters`` section
            and may contain a ``steps`` section.

    Raises:
        ValueError: If ``parameters`` is missing, if neither
            ``input_image_paths`` nor ``input_image_folder`` is provided, if
            ``input_image_folder`` is not an existing directory, or if a
            configured ``ocr_processor`` / ``result_saver`` step has no
            ``parameters`` section.
    """
    # Validate required sections
    if "parameters" not in config:
        raise ValueError("Missing required 'parameters' section in configuration")

    # A section can be present but None (e.g. an empty key in the YAML file);
    # treat that like an empty mapping so the caller gets a ValueError below
    # instead of an AttributeError/TypeError.
    params = config.get("parameters") or {}
    steps = config.get("steps") or {}

    # Validate input parameters
    if not params.get("input_image_paths") and not params.get("input_image_folder"):
        raise ValueError(
            "Either parameters.input_image_paths or parameters.input_image_folder must be provided"
        )

    # Validate input folder exists
    image_folder = params.get("input_image_folder")
    if image_folder and not os.path.isdir(image_folder):
        raise ValueError(f"Image folder does not exist: {image_folder}")

    # Validate model configuration
    if "ocr_processor" in steps:
        if "parameters" not in (steps["ocr_processor"] or {}):
            raise ValueError("Missing parameters section in steps.ocr_processor")

        # NOTE(review): selected_models is only required when an ocr_processor
        # step is configured -- confirm this nesting against the repository.
        if "selected_models" not in params:
            raise ValueError("Missing selected_models in parameters")

    # Validate result saving configuration
    if "result_saver" in steps:
        if "parameters" not in (steps["result_saver"] or {}):
            raise ValueError("Missing parameters section in steps.result_saver")
8971

9072

9173
def override_config_with_cli_args(
    config: Dict[str, Any], cli_args: Dict[str, Any]
) -> Dict[str, Any]:
    """Return a copy of ``config`` with command-line overrides applied.

    Only truthy CLI values override the configuration; missing, ``None``,
    empty or ``False`` values leave the existing setting untouched.

    Args:
        config: Base configuration dictionary. It is not modified.
        cli_args: Command-line arguments keyed by option name.

    Returns:
        A new configuration dictionary that always has ``parameters`` and
        ``steps`` sections.
    """
    import copy

    # A real deep copy: the overrides below write into nested dicts, so a
    # shallow ``{**config}`` would leak those writes into the caller's config.
    config = copy.deepcopy(config)

    # Ensure the top-level sections exist (also replaces sections that are None)
    if config.get("parameters") is None:
        config["parameters"] = {}
    if config.get("steps") is None:
        config["steps"] = {}

    params = config["parameters"]
    steps = config["steps"]

    # Override input configuration: (CLI option, pipeline parameter)
    for cli_key, param_key in (
        ("image_paths", "input_image_paths"),
        ("image_folder", "input_image_folder"),
    ):
        if cli_args.get(cli_key):
            params[param_key] = cli_args[cli_key]

    # Override step configuration: (CLI option, step name); the step
    # parameter carries the same name as the CLI option.
    for cli_key, step_name in (
        ("custom_prompt", "ocr_processor"),
        ("save_results", "result_saver"),
        ("results_directory", "result_saver"),
        ("save_visualizations", "visualizer"),
        ("visualization_directory", "visualizer"),
    ):
        if not cli_args.get(cli_key):
            continue

        # Create the step and its parameters section on demand; a step that
        # exists without "parameters" would otherwise raise KeyError.
        step_cfg = steps.get(step_name)
        if step_cfg is None:
            step_cfg = steps[step_name] = {}
        step_params = step_cfg.get("parameters")
        if step_params is None:
            step_params = step_cfg["parameters"] = {}

        step_params[cli_key] = cli_args[cli_key]

    return config
137129

138130

139131
def print_config_summary(config: Dict[str, Any]) -> None:
    """Print a human-readable summary of the ZenML configuration to stdout.

    Args:
        config: Configuration dictionary. The ``parameters`` and ``steps``
            sections, individual steps and their ``parameters`` may be
            missing or None; they are then treated as empty.
    """
    # Sections may be present but None (e.g. an empty YAML key), so fall back
    # to empty mappings rather than calling .get() on None.
    params = config.get("parameters") or {}
    steps = config.get("steps") or {}

    def step_parameters(step_name: str) -> Dict[str, Any]:
        """Return the parameters mapping of a step, or an empty dict."""
        step = steps.get(step_name) or {}
        return step.get("parameters") or {}

    print("\n===== OCR Pipeline Configuration =====")

    # Print pipeline mode
    mode = params.get("mode", "evaluation")
    print(f"Pipeline mode: {mode}")

    # Print model information
    selected_models = params.get("selected_models", [])
    if selected_models:
        print(f"Selected models: {', '.join(selected_models)}")

    # Print input information
    image_paths = params.get("input_image_paths", [])
    if image_paths:
        print(f"Input images: {len(image_paths)} specified")

    image_folder = params.get("input_image_folder")
    if image_folder:
        print(f"Input folder: {image_folder}")

    # Print ground truth information if available
    gt_folder = step_parameters("result_evaluator").get("ground_truth_folder")
    if gt_folder:
        print(f"Ground truth folder: {gt_folder}")
        gt_files = list_available_ground_truth_files(directory=gt_folder)
        print(f"Found {len(gt_files)} ground truth text files")

    # Print output information
    result_saver = step_parameters("result_saver")
    if result_saver.get("save_results", False):
        print(f"Results will be saved to: {result_saver.get('results_directory', 'ocr_results')}")

    visualizer = step_parameters("visualizer")
    if visualizer.get("save_visualizations", False):
        print(
            f"Visualizations will be saved to: {visualizer.get('visualization_directory', 'visualizations')}"
        )

    print("=" * 40 + "\n")
177+
178+
179+
def get_image_paths(directory: str) -> List[str]:
    """Get all image paths from a directory.

    Matches ``*.jpg``, ``*.jpeg``, ``*.png`` and ``*.webp`` files directly
    inside ``directory`` (no recursion into sub-directories).

    Args:
        directory: Directory to search.

    Returns:
        Sorted list of matching file paths; empty if nothing matches.
    """
    image_extensions = ["*.jpg", "*.jpeg", "*.png", "*.webp"]

    # Escape the directory so glob metacharacters in its name ("[", "*", "?")
    # are matched literally; only the extension part is meant to be a pattern.
    base = glob.escape(directory)

    image_paths = []
    for ext in image_extensions:
        image_paths.extend(glob.glob(os.path.join(base, ext)))

    return sorted(image_paths)
188+
189+
190+
def list_available_ground_truth_files(directory: Optional[str] = None) -> List[str]:
    """List available ground truth text files in the given directory.

    Args:
        directory: Directory to search for ``*.txt`` files.

    Returns:
        Sorted list of ``.txt`` file paths, or an empty list when
        ``directory`` is empty, None, or not an existing directory.
    """
    if not directory or not os.path.isdir(directory):
        return []

    # Escape the directory so glob metacharacters in its name are matched
    # literally instead of being interpreted as part of the pattern.
    pattern = os.path.join(glob.escape(directory), "*.txt")
    return sorted(glob.glob(pattern))

omni-reader/utils/io_utils.py

Lines changed: 4 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -92,26 +92,14 @@ def load_ocr_data_from_json(filepath: str) -> pl.DataFrame:
9292
return pl.DataFrame(ocr_data)
9393

9494

95-
def load_ground_truth_from_json(filepath: str) -> pl.DataFrame:
96-
"""Load ground truth data from a JSON file.
97-
98-
Args:
99-
filepath: Path to the ground truth JSON file
100-
101-
Returns:
102-
DataFrame containing ground truth data
103-
"""
104-
return load_ocr_data_from_json(filepath)
105-
106-
10795
def list_available_ground_truth_files(
108-
directory: str = "ocr_results/ground_truth", pattern: str = "gt_*.json"
96+
directory: str = "ground_truth_texts", pattern: str = "*.txt"
10997
) -> List[str]:
110-
"""List available ground truth files.
98+
"""List available ground truth text files.
11199
112100
Args:
113101
directory: Directory containing ground truth files
114-
pattern: Glob pattern to match files
102+
pattern: Glob pattern to match files (defaults to all text files)
115103
116104
Returns:
117105
List of paths to ground truth files
@@ -124,4 +112,4 @@ def list_available_ground_truth_files(
124112
# Find matching files
125113
files = glob.glob(path_pattern)
126114

127-
return sorted(files, reverse=True) # Sort by newest first
115+
return sorted(files) # Sort alphabetically

0 commit comments

Comments
 (0)