@@ -139,7 +139,10 @@ def run(self):
139139 val_crop_image_dir = os .path .join (self .config .classification_model .val_crop_image_dir , self .comet_logger .experiment .id )
140140 os .makedirs (val_crop_image_dir , exist_ok = True )
141141
142- if classification_backend == "deepforest" :
142+ # Always load the finetune/cropmodel (load both models when both configs are present)
143+ trained_classification_model = None
144+ # Load cropmodel if finetune config is available (which it should be by default)
145+ if hasattr (self .config .classification_model , "checkpoint" ) and self .config .classification_model .checkpoint :
143146 # If there are no train annotations, turn off force training
144147 if all_training .xmin [all_training .xmin != 0 ].empty :
145148 self .config .classification_model .force_train = False
@@ -170,12 +173,23 @@ def run(self):
170173 workers = self .config .classification_model .workers ,
171174 comet_logger = self .comet_logger )
172175 else :
173- trained_classification_model = CropModel .load_from_checkpoint (self .config .classification_model .checkpoint )
174- else :
175- # Hierarchical backend (H-CAST). Load wrapper and classify crops post-detection.
176- repo_root = os .path .abspath (os .path .join (os .path .dirname (__file__ ), ".." ))
177- hcast_checkpoint = getattr (self .config .classification_model , "checkpoint" , None )
178- hcast_wrapper = hierarchical .load_hcast_model (repo_root = repo_root , checkpoint_path = hcast_checkpoint )
176+ trained_classification_model = CropModel .load_from_checkpoint (self .config .classification_model .checkpoint )
177+
178+ # Load hierarchical model (H-CAST) if enabled
179+ hcast_model = None
180+ if classification_backend == "hierarchical" :
181+ hcast_enabled = getattr (self .config .classification_model , "enabled" , False )
182+ if hcast_enabled :
183+ hcast_checkpoint = getattr (self .config .classification_model , "checkpoint" , None )
184+ if hcast_checkpoint :
185+ label_csv = getattr (self .config .classification_model , "label_csv" , None )
186+ hcast_model = hierarchical .load_hcast_model (
187+ checkpoint_path = hcast_checkpoint ,
188+ label_csv = label_csv
189+ )
190+ print (f"Loaded H-CAST model from { hcast_checkpoint } " )
191+ else :
192+ print ("H-CAST enabled but no checkpoint path provided, skipping hierarchical model" )
179193
180194 pool = glob .glob (os .path .join (self .config .image_dir , "*.jpg" )) # Get all images in the data directory
181195 pool = [image for image in pool if not image .endswith ('.csv' )]
@@ -189,6 +203,13 @@ def run(self):
189203 else :
190204 pool = random .sample (pool , 10 )
191205
206+ # Get hierarchical model parameters from config (hierarchical.yaml)
207+ hcast_batch_size = None
208+ hcast_workers = None
209+ if classification_backend == "hierarchical" and getattr (self .config .classification_model , "enabled" , False ):
210+ hcast_batch_size = self .config .classification_model .batch_size
211+ hcast_workers = self .config .classification_model .workers
212+
192213 flightline_predictions = generate_pool_predictions (
193214 pool = pool ,
194215 pool_limit = self .config .active_learning .pool_limit ,
@@ -198,6 +219,10 @@ def run(self):
198219 model = trained_detection_model ,
199220 batch_size = self .config .predict .batch_size ,
200221 crop_model = trained_classification_model ,
222+ hcast_model = hcast_model ,
223+ image_dir = self .config .image_dir if hcast_model is not None else None ,
224+ hcast_batch_size = hcast_batch_size ,
225+ hcast_workers = hcast_workers ,
201226 )
202227
203228 if flightline_predictions is None :
@@ -209,24 +234,26 @@ def run(self):
209234 if self .existing_validation is None :
210235 print ("No validation annotations, skipping evaluation" )
211236 else :
212- evaluation_annotations = self .existing_validation .copy (deep = True )
213- evaluation_predictions = flightline_predictions [flightline_predictions .image_path .isin (self .existing_validation .image_path )]
214-
215-
216- label_dict = trained_classification_model .label_dict
217-
218- pipeline_monitor = PipelineEvaluation (
219- predictions = evaluation_predictions ,
220- annotations = evaluation_annotations ,
221- classification_label_dict = label_dict ,
222- ** self .config .pipeline_evaluation )
223-
224- performance = pipeline_monitor .evaluate ()
225- self .comet_logger .experiment .log_metrics (performance )
226-
227- if pipeline_monitor .check_success ():
228- print ("Pipeline performance is satisfactory, exiting" )
229- return None
237+ if trained_classification_model is None :
238+ print ("No classification model available, skipping evaluation" )
239+ else :
240+ evaluation_annotations = self .existing_validation .copy (deep = True )
241+ evaluation_predictions = flightline_predictions [flightline_predictions .image_path .isin (self .existing_validation .image_path )]
242+
243+ label_dict = trained_classification_model .label_dict
244+
245+ pipeline_monitor = PipelineEvaluation (
246+ predictions = evaluation_predictions ,
247+ annotations = evaluation_annotations ,
248+ classification_label_dict = label_dict ,
249+ ** self .config .pipeline_evaluation )
250+
251+ performance = pipeline_monitor .evaluate ()
252+ self .comet_logger .experiment .log_metrics (performance )
253+
254+ if pipeline_monitor .check_success ():
255+ print ("Pipeline performance is satisfactory, exiting" )
256+ return None
230257
231258 test_preannotations = flightline_predictions [~ flightline_predictions .image_path .isin (self .existing_images )]
232259 test_images_to_annotate , preannotations = select_images (
0 commit comments