1515# limitations under the License.
1616"""Utilities for handling configuration."""
1717
18+ import glob
1819import os
19- from typing import Any , Dict
20+ from typing import Any , Dict , List , Optional
2021
2122import yaml
2223
2324
2425def load_config (config_path : str ) -> Dict [str , Any ]:
25- """Load configuration from YAML file.
26-
27- Args:
28- config_path: Path to YAML configuration file
29-
30- Returns:
31- Dictionary containing configuration
32-
33- Raises:
34- FileNotFoundError: If the configuration file does not exist
35- ValueError: If the configuration file is not valid YAML
36- """
26+ """Load configuration from YAML file."""
3727 if not os .path .isfile (config_path ):
3828 raise FileNotFoundError (f"Configuration file does not exist: { config_path } " )
3929
@@ -46,185 +36,161 @@ def load_config(config_path: str) -> Dict[str, Any]:
4636
4737
4838def validate_config (config : Dict [str , Any ]) -> None :
49- """Validate configuration.
50-
51- Args:
52- config: Dictionary containing configuration
53-
54- Raises:
55- ValueError: If the configuration is invalid
56- """
57- # Validate top-level sections
58- required_sections = ["input" , "models" , "ground_truth" , "output" ]
59- for section in required_sections :
60- if section not in config :
61- raise ValueError (f"Missing required section '{ section } ' in configuration" )
62-
63- # Validate input section
64- if not config ["input" ].get ("image_paths" ) and not config ["input" ].get ("image_folder" ):
65- raise ValueError ("Either input.image_paths or input.image_folder must be provided" )
66-
67- if config ["input" ].get ("image_folder" ) and not os .path .isdir (
68- config ["input" ].get ("image_folder" )
69- ):
70- raise ValueError (f"Image folder does not exist: { config ['input' ].get ('image_folder' )} " )
71-
72- # Validate ground truth configuration
73- gt_source = config ["ground_truth" ].get ("source" , "none" )
74- if gt_source not in ["openai" , "manual" , "file" , "none" ]:
39+ """Validate ZenML configuration."""
40+ # Validate required sections
41+ if "parameters" not in config :
42+ raise ValueError ("Missing required 'parameters' section in configuration" )
43+
44+ # Validate input parameters
45+ params = config .get ("parameters" , {})
46+ if not params .get ("input_image_paths" ) and not params .get ("input_image_folder" ):
7547 raise ValueError (
76- f"Invalid ground_truth.source: { gt_source } . Must be one of: openai, manual, file, none "
48+ "Either parameters.input_image_paths or parameters.input_image_folder must be provided "
7749 )
7850
79- if gt_source == "manual" and not config ["ground_truth" ].get ("texts" ):
80- raise ValueError (
81- "When using ground_truth.source=manual, you must provide ground_truth.texts"
82- )
51+ # Validate input folder exists
52+ image_folder = params .get ("input_image_folder" )
53+ if image_folder and not os .path .isdir (image_folder ):
54+ raise ValueError (f"Image folder does not exist: { image_folder } " )
55+
56+ # Validate steps configuration if present
57+ steps = config .get ("steps" , {})
8358
84- if gt_source == "file" and not config ["ground_truth" ].get ("file" ):
85- raise ValueError ("When using ground_truth.source=file, you must provide ground_truth.file" )
59+ # Validate model configuration
60+ if "ocr_processor" in steps :
61+ if "parameters" not in steps ["ocr_processor" ]:
62+ raise ValueError ("Missing parameters section in steps.ocr_processor" )
8663
87- if gt_source == "file" and not os .path .isfile (config ["ground_truth" ].get ("file" )):
88- raise ValueError (f"Ground truth file does not exist: { config ['ground_truth' ].get ('file' )} " )
64+ if "selected_models" not in params :
65+ raise ValueError ("Missing selected_models in parameters" )
66+
67+ # Validate result saving configuration
68+ if "result_saver" in steps :
69+ if "parameters" not in steps ["result_saver" ]:
70+ raise ValueError ("Missing parameters section in steps.result_saver" )
8971
9072
9173def override_config_with_cli_args (
9274 config : Dict [str , Any ], cli_args : Dict [str , Any ]
9375) -> Dict [str , Any ]:
94- """Override configuration with command-line arguments.
95-
96- Args:
97- config: Dictionary containing configuration
98- cli_args: Dictionary containing command-line arguments
99-
100- Returns:
101- Updated configuration dictionary
102- """
76+ """Override configuration with command-line arguments."""
10377 # Deep copy the config to avoid modifying the original
10478 config = {** config }
10579
80+ # Ensure parameters section exists
81+ if "parameters" not in config :
82+ config ["parameters" ] = {}
83+
84+ # Ensure steps section exists
85+ if "steps" not in config :
86+ config ["steps" ] = {}
87+
10688 # Override input configuration
10789 if cli_args .get ("image_paths" ):
108- config ["input " ]["image_paths " ] = cli_args ["image_paths" ]
90+ config ["parameters " ]["input_image_paths " ] = cli_args ["image_paths" ]
10991 if cli_args .get ("image_folder" ):
110- config ["input " ]["image_folder " ] = cli_args ["image_folder" ]
92+ config ["parameters " ]["input_image_folder " ] = cli_args ["image_folder" ]
11193
11294 # Override model configuration
113- if cli_args .get ("custom_prompt" ):
114- config ["models " ]["custom_prompt " ] = cli_args [ "custom_prompt" ]
95+ if cli_args .get ("custom_prompt" ) and "ocr_processor" not in config [ "steps" ] :
96+ config ["steps " ]["ocr_processor " ] = { "parameters" : {}}
11597
116- # Override ground truth configuration
117- if cli_args .get ("ground_truth" ):
118- config ["ground_truth" ]["source" ] = cli_args ["ground_truth" ]
119- if cli_args .get ("ground_truth_file" ):
120- config ["ground_truth" ]["file" ] = cli_args ["ground_truth_file" ]
98+ if cli_args .get ("custom_prompt" ):
99+ config ["steps" ]["ocr_processor" ]["parameters" ]["custom_prompt" ] = cli_args ["custom_prompt" ]
121100
122101 # Override output configuration
123- if cli_args .get ("save_ground_truth" ):
124- config ["output" ]["ground_truth" ]["save" ] = cli_args ["save_ground_truth" ]
125- if cli_args .get ("ground_truth_dir" ):
126- config ["output" ]["ground_truth" ]["directory" ] = cli_args ["ground_truth_dir" ]
127- if cli_args .get ("save_ocr_results" ):
128- config ["output" ]["ocr_results" ]["save" ] = cli_args ["save_ocr_results" ]
129- if cli_args .get ("ocr_results_dir" ):
130- config ["output" ]["ocr_results" ]["directory" ] = cli_args ["ocr_results_dir" ]
131- if cli_args .get ("save_visualization" ):
132- config ["output" ]["visualization" ]["save" ] = cli_args ["save_visualization" ]
133- if cli_args .get ("visualization_dir" ):
134- config ["output" ]["visualization" ]["directory" ] = cli_args ["visualization_dir" ]
102+ if cli_args .get ("save_results" ):
103+ if "result_saver" not in config ["steps" ]:
104+ config ["steps" ]["result_saver" ] = {"parameters" : {}}
105+ config ["steps" ]["result_saver" ]["parameters" ]["save_results" ] = cli_args ["save_results" ]
106+
107+ if cli_args .get ("results_directory" ):
108+ if "result_saver" not in config ["steps" ]:
109+ config ["steps" ]["result_saver" ] = {"parameters" : {}}
110+ config ["steps" ]["result_saver" ]["parameters" ]["results_directory" ] = cli_args [
111+ "results_directory"
112+ ]
113+
114+ if cli_args .get ("save_visualizations" ):
115+ if "visualizer" not in config ["steps" ]:
116+ config ["steps" ]["visualizer" ] = {"parameters" : {}}
117+ config ["steps" ]["visualizer" ]["parameters" ]["save_visualizations" ] = cli_args [
118+ "save_visualizations"
119+ ]
120+
121+ if cli_args .get ("visualization_directory" ):
122+ if "visualizer" not in config ["steps" ]:
123+ config ["steps" ]["visualizer" ] = {"parameters" : {}}
124+ config ["steps" ]["visualizer" ]["parameters" ]["visualization_directory" ] = cli_args [
125+ "visualization_directory"
126+ ]
135127
136128 return config
137129
138130
139131def print_config_summary (config : Dict [str , Any ]) -> None :
140- """Print a summary of the configuration.
141-
142- Args:
143- config: Dictionary containing configuration
144- """
145- print ("\n ===== OCR Comparison Pipeline Configuration =====" )
146-
147- # Input configuration
148- print ("\n Input:" )
149- if config ["input" ].get ("image_paths" ):
150- print (f" • Using { len (config ['input' ].get ('image_paths' ))} specified image paths" )
151- if config ["input" ].get ("image_folder" ):
152- print (f" • Searching for images in folder: { config ['input' ].get ('image_folder' )} " )
153-
154- # Model configuration
155- print ("\n Models:" )
156- if config ["models" ].get ("custom_prompt" ):
157- print (f" • Using custom prompt: { config ['models' ].get ('custom_prompt' )[:50 ]} ..." )
158- else :
159- print (" • Using default prompts" )
160-
161- # Ground truth configuration
162- print ("\n Ground Truth:" )
163- gt_source = config ["ground_truth" ].get ("source" , "none" )
164- print (f" • Source: { gt_source } " )
165- if gt_source == "file" :
166- print (f" • File: { config ['ground_truth' ].get ('file' )} " )
167- elif gt_source == "manual" :
168- print (f" • Manual texts: { len (config ['ground_truth' ].get ('texts' , []))} provided" )
169-
170- # Output configuration
171- print ("\n Output:" )
172- if config ["output" ]["ground_truth" ].get ("save" , False ):
132+ """Print a summary of the ZenML configuration."""
133+ print ("\n ===== OCR Pipeline Configuration =====" )
134+
135+ # Get parameters
136+ params = config .get ("parameters" , {})
137+
138+ # Print pipeline mode
139+ mode = params .get ("mode" , "evaluation" )
140+ print (f"Pipeline mode: { mode } " )
141+
142+ # Print model information
143+ selected_models = params .get ("selected_models" , [])
144+ if selected_models :
145+ print (f"Selected models: { ', ' .join (selected_models )} " )
146+
147+ # Print input information
148+ image_paths = params .get ("input_image_paths" , [])
149+ if image_paths :
150+ print (f"Input images: { len (image_paths )} specified" )
151+
152+ image_folder = params .get ("input_image_folder" )
153+ if image_folder :
154+ print (f"Input folder: { image_folder } " )
155+
156+ # Print ground truth information if available
157+ steps = config .get ("steps" , {})
158+ evaluator = steps .get ("result_evaluator" , {}).get ("parameters" , {})
159+ gt_folder = evaluator .get ("ground_truth_folder" )
160+ if gt_folder :
161+ print (f"Ground truth folder: { gt_folder } " )
162+ gt_files = list_available_ground_truth_files (directory = gt_folder )
163+ print (f"Found { len (gt_files )} ground truth text files" )
164+
165+ # Print output information
166+ result_saver = steps .get ("result_saver" , {}).get ("parameters" , {})
167+ if result_saver .get ("save_results" , False ):
168+ print (f"Results will be saved to: { result_saver .get ('results_directory' , 'ocr_results' )} " )
169+
170+ visualizer = steps .get ("visualizer" , {}).get ("parameters" , {})
171+ if visualizer .get ("save_visualizations" , False ):
173172 print (
174- f" • Saving ground truth data to: { config [ 'output' ][ 'ground_truth' ] .get ('directory ' )} "
173+ f"Visualizations will be saved to: { visualizer .get ('visualization_directory' , 'visualizations ' )} "
175174 )
176- if config ["output" ]["ocr_results" ].get ("save" , False ):
177- print (f" • Saving OCR results to: { config ['output' ]['ocr_results' ].get ('directory' )} " )
178- if config ["output" ]["visualization" ].get ("save" , False ):
179- print (f" • Saving visualization to: { config ['output' ]['visualization' ].get ('directory' )} " )
180-
181- print ("\n ================================================\n " )
182-
183-
184- def create_default_config () -> Dict [str , Any ]:
185- """Create a default configuration.
186-
187- Returns:
188- Dictionary containing default configuration
189- """
190- return {
191- "input" : {
192- "image_paths" : [],
193- "image_folder" : None ,
194- },
195- "models" : {
196- "custom_prompt" : None ,
197- },
198- "ground_truth" : {
199- "source" : "none" ,
200- "texts" : [],
201- "file" : None ,
202- },
203- "output" : {
204- "ground_truth" : {
205- "save" : False ,
206- "directory" : "ground_truth" ,
207- },
208- "ocr_results" : {
209- "save" : False ,
210- "directory" : "ocr_results" ,
211- },
212- "visualization" : {
213- "save" : False ,
214- "directory" : "visualizations" ,
215- },
216- },
217- }
218-
219-
220- def save_config (config : Dict [str , Any ], config_path : str ) -> None :
221- """Save configuration to YAML file.
222-
223- Args:
224- config: Dictionary containing configuration
225- config_path: Path to save the configuration file
226- """
227- os .makedirs (os .path .dirname (os .path .abspath (config_path )), exist_ok = True )
228-
229- with open (config_path , "w" ) as f :
230- yaml .dump (config , f , default_flow_style = False , sort_keys = False )
175+
176+ print ("=" * 40 + "\n " )
177+
178+
179+ def get_image_paths (directory : str ) -> List [str ]:
180+ """Get all image paths from a directory."""
181+ image_extensions = ["*.jpg" , "*.jpeg" , "*.png" , "*.webp" ]
182+ image_paths = []
183+
184+ for ext in image_extensions :
185+ image_paths .extend (glob .glob (os .path .join (directory , ext )))
186+
187+ return sorted (image_paths )
188+
189+
190+ def list_available_ground_truth_files (directory : Optional [str ] = None ) -> List [str ]:
191+ """List available ground truth text files in the given directory."""
192+ if not directory or not os .path .isdir (directory ):
193+ return []
194+
195+ text_files = glob .glob (os .path .join (directory , "*.txt" ))
196+ return sorted (text_files )
0 commit comments