Skip to content

Commit a042d45

Browse files
committed
hcast model in pipeline
1 parent 9e7c243 commit a042d45

File tree

6 files changed

+82
-37
lines changed

6 files changed

+82
-37
lines changed

boem_conf/boem_config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
defaults:
22
- server: serenity
3-
- classification_model: finetune.yaml # or hierarchical.yaml
3+
- classification_model: finetune.yaml
44
- annotation: label_studio
55

66
hydra:

boem_conf/classification_model/hierarchical.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
backend: hierarchical
2+
enabled: False
23
force_train: False
34
# Optional: explicitly set checkpoint path; if empty we auto-discover under tamu_hcast/
45
checkpoint: /home/b.weinstein/BOEM/tamu_hcast/output/usgs_hcast_300_b256/best_checkpoint.pth
6+
# Optional path to CSV with species/genus/family labels
7+
label_csv: null
58
# Batch size for classifying crops after detection
69
batch_size: 64
710
workers: 4

pyproject.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ dependencies = [
2626
# H-CAST requirements
2727
"timm",
2828
"opencv-contrib-python",
29-
#"deepforest",
29+
"deepforest<2.0",
3030
# DGL requires CUDA-specific wheels - will be resolved from find-links configured below
3131
"dgl",
3232
"packaging",
@@ -42,6 +42,7 @@ find-links = [
4242
]
4343
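# Exclude conflicting OpenCV builds: the never-true marker (python_version < '0') keeps them from being installed alongside opencv-contrib-python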
4444
override-dependencies = [
45-
"opencv-python; python_version < '0'"
45+
"opencv-python; python_version < '0'",
46+
"opencv-python-headless; python_version < '0'",
4647
]
4748

src/active_learning.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import random
22
from src import detection
3+
from src import hierarchical
34

45
def human_review(predictions, min_detection_score=0.6, min_classification_score=0.5, confident_threshold=0.5):
56
"""
@@ -27,7 +28,7 @@ def human_review(predictions, min_detection_score=0.6, min_classification_score=
2728

2829
return confident_predictions, uncertain_predictions
2930

30-
def generate_pool_predictions(pool, patch_size=512, patch_overlap=0.1, min_score=0.1, model=None, batch_size=16, pool_limit=1000, crop_model=None):
31+
def generate_pool_predictions(pool, patch_size=512, patch_overlap=0.1, min_score=0.1, model=None, batch_size=16, pool_limit=1000, crop_model=None, hcast_model=None, image_dir=None, hcast_batch_size=None, hcast_workers=None):
3132
"""
3233
Generate predictions for the flight pool.
3334
@@ -41,9 +42,13 @@ def generate_pool_predictions(pool, patch_size=512, patch_overlap=0.1, min_score
4142
comet_logger (CometLogger, optional): A CometLogger object. Defaults to None.
4243
crop_model (CropModel, optional): A deepforest.model.CropModel object. Defaults to None.
4344
pool_limit (int, optional): The maximum number of images to consider. Defaults to 1000.
45+
hcast_model (optional): H-CAST hierarchical model wrapper. Defaults to None.
46+
image_dir (str, optional): Root directory where images are located. Required if hcast_model is provided.
47+
hcast_batch_size (int, optional): Batch size for H-CAST classification. Defaults to None, which is forwarded unchanged to hierarchical.classify_dataframe.
48+
hcast_workers (int, optional): Number of workers for the H-CAST DataLoader. Defaults to None, which is forwarded unchanged to hierarchical.classify_dataframe.
4449
4550
Returns:
46-
pd.DataFrame: A DataFrame of predictions.
51+
pd.DataFrame: A DataFrame of predictions with both cropmodel and hcast columns (if hcast_model provided).
4752
"""
4853

4954
#subsample
@@ -63,6 +68,18 @@ def generate_pool_predictions(pool, patch_size=512, patch_overlap=0.1, min_score
6368

6469
preannotations = preannotations[preannotations["score"] >= min_score]
6570

71+
# Apply hierarchical classification if hcast_model is provided
72+
if hcast_model is not None:
73+
if image_dir is None:
74+
raise ValueError("image_dir is required when hcast_model is provided")
75+
preannotations = hierarchical.classify_dataframe(
76+
predictions=preannotations,
77+
image_dir=image_dir,
78+
model=hcast_model,
79+
batch_size=hcast_batch_size,
80+
num_workers=hcast_workers,
81+
)
82+
6683
return preannotations
6784

6885
def select_images(preannotations, strategy, n=10, target_labels=None, min_score=0.3):

src/hierarchical.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,15 @@
22
from typing import Optional, Tuple, List, Dict, Callable
33

44
import torch
5-
from torch.utils.data import Dataset, DataLoader
6-
from PIL import Image
7-
from src.hcast.cast_models import cast_deit_hier
85
import pandas as pd
96
import numpy as np
7+
from PIL import Image
108
import cv2
9+
10+
from torch.utils.data import Dataset, DataLoader
1111
from torchvision import transforms
1212
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
13-
13+
from timm.models import create_model
1414

1515
def _infer_head_sizes_from_checkpoint(ckpt: Dict[str, torch.Tensor]) -> Tuple[int, Optional[int], Optional[int]]:
1616
species = None
@@ -167,9 +167,6 @@ def load_hcast_model(
167167
model_state_dict = checkpoint["state_dict"]
168168
else:
169169
model_state_dict = checkpoint
170-
171-
# Get the training arguments if available
172-
from timm.models import create_model
173170

174171
if 'args' in checkpoint:
175172
args = checkpoint['args']

src/pipeline.py

Lines changed: 52 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,10 @@ def run(self):
139139
val_crop_image_dir = os.path.join(self.config.classification_model.val_crop_image_dir, self.comet_logger.experiment.id)
140140
os.makedirs(val_crop_image_dir, exist_ok=True)
141141

142-
if classification_backend == "deepforest":
142+
# Load the finetune/cropmodel whenever the classification config provides a non-empty checkpoint (independent of whether the H-CAST model is also loaded below)
143+
trained_classification_model = None
144+
# Load cropmodel if finetune config is available (which it should be by default)
145+
if hasattr(self.config.classification_model, "checkpoint") and self.config.classification_model.checkpoint:
143146
# If there are no train annotations, turn off force training
144147
if all_training.xmin[all_training.xmin != 0].empty:
145148
self.config.classification_model.force_train = False
@@ -170,12 +173,23 @@ def run(self):
170173
workers=self.config.classification_model.workers,
171174
comet_logger=self.comet_logger)
172175
else:
173-
trained_classification_model = CropModel.load_from_checkpoint(self.config.classification_model.checkpoint )
174-
else:
175-
# Hierarchical backend (H-CAST). Load wrapper and classify crops post-detection.
176-
repo_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
177-
hcast_checkpoint = getattr(self.config.classification_model, "checkpoint", None)
178-
hcast_wrapper = hierarchical.load_hcast_model(repo_root=repo_root, checkpoint_path=hcast_checkpoint)
176+
trained_classification_model = CropModel.load_from_checkpoint(self.config.classification_model.checkpoint)
177+
178+
# Load hierarchical model (H-CAST) if enabled
179+
hcast_model = None
180+
if classification_backend == "hierarchical":
181+
hcast_enabled = getattr(self.config.classification_model, "enabled", False)
182+
if hcast_enabled:
183+
hcast_checkpoint = getattr(self.config.classification_model, "checkpoint", None)
184+
if hcast_checkpoint:
185+
label_csv = getattr(self.config.classification_model, "label_csv", None)
186+
hcast_model = hierarchical.load_hcast_model(
187+
checkpoint_path=hcast_checkpoint,
188+
label_csv=label_csv
189+
)
190+
print(f"Loaded H-CAST model from {hcast_checkpoint}")
191+
else:
192+
print("H-CAST enabled but no checkpoint path provided, skipping hierarchical model")
179193

180194
pool = glob.glob(os.path.join(self.config.image_dir, "*.jpg")) # Get all images in the data directory
181195
pool = [image for image in pool if not image.endswith('.csv')]
@@ -189,6 +203,13 @@ def run(self):
189203
else:
190204
pool = random.sample(pool, 10)
191205

206+
# Get hierarchical model parameters from config (hierarchical.yaml)
207+
hcast_batch_size = None
208+
hcast_workers = None
209+
if classification_backend == "hierarchical" and getattr(self.config.classification_model, "enabled", False):
210+
hcast_batch_size = self.config.classification_model.batch_size
211+
hcast_workers = self.config.classification_model.workers
212+
192213
flightline_predictions = generate_pool_predictions(
193214
pool=pool,
194215
pool_limit=self.config.active_learning.pool_limit,
@@ -198,6 +219,10 @@ def run(self):
198219
model=trained_detection_model,
199220
batch_size=self.config.predict.batch_size,
200221
crop_model=trained_classification_model,
222+
hcast_model=hcast_model,
223+
image_dir=self.config.image_dir if hcast_model is not None else None,
224+
hcast_batch_size=hcast_batch_size,
225+
hcast_workers=hcast_workers,
201226
)
202227

203228
if flightline_predictions is None:
@@ -209,24 +234,26 @@ def run(self):
209234
if self.existing_validation is None:
210235
print("No validation annotations, skipping evaluation")
211236
else:
212-
evaluation_annotations = self.existing_validation.copy(deep=True)
213-
evaluation_predictions = flightline_predictions[flightline_predictions.image_path.isin(self.existing_validation.image_path)]
214-
215-
216-
label_dict = trained_classification_model.label_dict
217-
218-
pipeline_monitor = PipelineEvaluation(
219-
predictions=evaluation_predictions,
220-
annotations=evaluation_annotations,
221-
classification_label_dict=label_dict,
222-
**self.config.pipeline_evaluation)
223-
224-
performance = pipeline_monitor.evaluate()
225-
self.comet_logger.experiment.log_metrics(performance)
226-
227-
if pipeline_monitor.check_success():
228-
print("Pipeline performance is satisfactory, exiting")
229-
return None
237+
if trained_classification_model is None:
238+
print("No classification model available, skipping evaluation")
239+
else:
240+
evaluation_annotations = self.existing_validation.copy(deep=True)
241+
evaluation_predictions = flightline_predictions[flightline_predictions.image_path.isin(self.existing_validation.image_path)]
242+
243+
label_dict = trained_classification_model.label_dict
244+
245+
pipeline_monitor = PipelineEvaluation(
246+
predictions=evaluation_predictions,
247+
annotations=evaluation_annotations,
248+
classification_label_dict=label_dict,
249+
**self.config.pipeline_evaluation)
250+
251+
performance = pipeline_monitor.evaluate()
252+
self.comet_logger.experiment.log_metrics(performance)
253+
254+
if pipeline_monitor.check_success():
255+
print("Pipeline performance is satisfactory, exiting")
256+
return None
230257

231258
test_preannotations = flightline_predictions[~flightline_predictions.image_path.isin(self.existing_images)]
232259
test_images_to_annotate, preannotations = select_images(

0 commit comments

Comments
 (0)