|
| 1 | +"""Convert DeepForest bounding box predictions to polygons using SAM2.""" |
| 2 | + |
| 3 | +import argparse |
| 4 | +import logging |
| 5 | +import os |
| 6 | +from pathlib import Path |
| 7 | + |
| 8 | +import numpy as np |
| 9 | +import pandas as pd |
| 10 | +import torch |
| 11 | +from PIL import Image |
| 12 | +from shapely import wkt |
| 13 | +from shapely.geometry import Polygon |
| 14 | +from tqdm import tqdm |
| 15 | +from transformers import Sam2Model, Sam2Processor |
| 16 | + |
| 17 | +from deepforest.utilities import mask_to_polygon |
| 18 | +from deepforest.visualize import plot_results |
| 19 | + |
| 20 | +logger = logging.getLogger(__name__) |
| 21 | + |
| 22 | + |
def load_sam2_model(model_name: str, device: str):
    """Load a SAM2 checkpoint and its matching processor from HuggingFace.

    Args:
        model_name: HuggingFace identifier of the SAM2 checkpoint.
        device: Target device for the model ('cuda', 'mps', or 'cpu').

    Returns:
        Tuple of (model, processor), with the model already moved to ``device``.
    """
    sam_model = Sam2Model.from_pretrained(model_name).to(device)
    sam_processor = Sam2Processor.from_pretrained(model_name)
    return sam_model, sam_processor
| 37 | + |
| 38 | + |
def process_image_group(
    image_path: str,
    detections: pd.DataFrame,
    model,
    processor,
    device: str,
    image_root: str = "",
    box_batch_size: int = 32,
    mask_threshold: float = 0.5,
    iou_threshold: float = 0.5,
    viz_output_dir: str = None,
) -> list:
    """Run SAM2 on every detection box of a single image and return polygons.

    Args:
        image_path: Path to the image file (relative to ``image_root`` if given).
        detections: DataFrame of detections for this image; must contain
            ``xmin``, ``ymin``, ``xmax`` and ``ymax`` columns.
        model: SAM2 model, already placed on ``device``.
        processor: SAM2 processor matching ``model``.
        device: Device to run inference on ('cuda', 'mps', or 'cpu').
        image_root: Root directory to prepend to ``image_path`` if needed.
        box_batch_size: Maximum number of boxes to process per forward pass.
        mask_threshold: Threshold for binarizing SAM2 mask outputs.
        iou_threshold: Minimum predicted IoU score to accept a polygon;
            rejected boxes produce an empty-polygon WKT placeholder.
        viz_output_dir: Directory to save visualizations (skipped if None).

    Returns:
        List of WKT polygon strings, one per row of ``detections``, in order.
    """
    full_path = os.path.join(image_root, image_path) if image_root else image_path
    image = Image.open(full_path).convert("RGB")
    # Remember the dimensions now so the visualization step below does not
    # have to reopen the image file (the original code opened it twice).
    width, height = image.size

    boxes = detections[["xmin", "ymin", "xmax", "ymax"]].values.tolist()

    all_polygons = []
    for chunk_start in range(0, len(boxes), box_batch_size):
        box_chunk = boxes[chunk_start : chunk_start + box_batch_size]
        # SAM2 expects boxes nested per image: [[box, box, ...]].
        input_boxes = [box_chunk]

        inputs = processor(images=image, input_boxes=input_boxes, return_tensors="pt").to(
            device
        )

        with torch.no_grad():
            outputs = model(**inputs)

        # Request raw (non-binarized) masks so mask_threshold is applied here.
        masks = processor.post_process_masks(
            outputs.pred_masks.cpu(),
            inputs["original_sizes"],
            binarize=False,
            mask_interpolation_mode="nearest",
        )[0]

        iou_scores = outputs.iou_scores.cpu()

        # Use a distinct loop variable: the original reused ``i`` here,
        # shadowing the outer chunk index.
        for box_idx, mask_set in enumerate(masks):
            # SAM2 proposes several candidate masks per box; keep the one
            # with the highest predicted IoU.
            best_idx = iou_scores[0, box_idx].argmax().item()
            best_iou = iou_scores[0, box_idx, best_idx].item()

            if best_iou < iou_threshold:
                # Empty-polygon placeholder keeps output aligned with input rows.
                all_polygons.append(Polygon().wkt)
                continue

            best_mask = mask_set[best_idx]
            mask_np = (best_mask.numpy() > mask_threshold).astype(np.uint8)
            polygon = mask_to_polygon(mask_np)
            all_polygons.append(polygon.wkt)

    if viz_output_dir is not None:
        viz_df = detections.copy()
        viz_df["polygon_geometry"] = all_polygons
        viz_df["geometry"] = viz_df["polygon_geometry"].apply(
            lambda x: wkt.loads(x) if pd.notna(x) else None
        )
        # Drop rows whose polygon was rejected (empty placeholder geometry).
        viz_df = viz_df[
            viz_df["geometry"].apply(lambda x: x is not None and not x.is_empty)
        ]

        if len(viz_df) > 0:
            # plot_results needs label/score columns; fill defaults if absent.
            if "label" not in viz_df.columns:
                viz_df["label"] = "Tree"
            if "score" not in viz_df.columns:
                viz_df["score"] = 1.0

            image_name = Path(image_path).stem
            viz_path = os.path.join(viz_output_dir, f"{image_name}_polygons.png")
            plot_results(
                results=viz_df,
                image=full_path,
                savedir=os.path.dirname(viz_path),
                basename=os.path.splitext(os.path.basename(viz_path))[0],
                height=height,
                width=width,
                show=False,
            )

    return all_polygons
| 140 | + |
| 141 | + |
def convert_boxes_to_polygons(
    input_csv: str,
    output_csv: str,
    model_name: str = "facebook/sam2.1-hiera-small",
    box_batch_size: int = 32,
    image_root: str = "",
    visualize: bool = False,
    viz_output_dir: str = ".",
    mask_threshold: float = 0.5,
    iou_threshold: float = 0.5,
    device: str = None,
) -> None:
    """Convert DeepForest bounding boxes to polygons using SAM2.

    Args:
        input_csv: Path to input CSV with DeepForest predictions.
        output_csv: Path to save output CSV with polygons.
        model_name: HuggingFace model name for SAM2.
        box_batch_size: Maximum number of boxes to process per forward pass.
        image_root: Root directory to prepend to image paths in CSV.
        visualize: Whether to create visualization images.
        viz_output_dir: Directory to save visualization images.
        mask_threshold: Threshold for binarizing SAM2 mask outputs.
        iou_threshold: Minimum IoU score to accept a polygon.
        device: Device to use ('cuda', 'mps', or 'cpu'). Auto-detects if None.

    Raises:
        ValueError: If the input CSV lacks any required column.
    """
    df = pd.read_csv(input_csv)

    required_cols = ["xmin", "ymin", "xmax", "ymax", "image_path"]
    missing_cols = [col for col in required_cols if col not in df.columns]
    if missing_cols:
        raise ValueError(f"Missing required columns: {missing_cols}")

    # Auto-detect device, preferring mps > cuda > cpu.
    if device is None:
        if torch.backends.mps.is_available():
            device = "mps"
        elif torch.cuda.is_available():
            device = "cuda"
        else:
            device = "cpu"

    logger.info("Using device: %s", device)
    logger.info("Loading SAM2 model: %s", model_name)
    model, processor = load_sam2_model(model_name, device)

    grouped = df.groupby("image_path")
    total_images = len(grouped)

    if visualize:
        os.makedirs(viz_output_dir, exist_ok=True)

    # BUG FIX: groupby iterates groups sorted by image_path, which generally
    # does not match the CSV row order. Assign polygons back via each group's
    # row index instead of concatenating positionally, so every polygon lands
    # on the detection it belongs to.
    df["polygon_geometry"] = None
    for image_path, group in tqdm(grouped, desc="Processing images", total=total_images):
        polygons = process_image_group(
            image_path,
            group,
            model,
            processor,
            device,
            image_root,
            box_batch_size,
            mask_threshold,
            iou_threshold,
            viz_output_dir=viz_output_dir if visualize else None,
        )
        df.loc[group.index, "polygon_geometry"] = polygons

    # BUG FIX: os.makedirs("") raises FileNotFoundError when output_csv has
    # no directory component; only create the directory if one is present.
    out_dir = os.path.dirname(output_csv)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    df.to_csv(output_csv, index=False)
    logger.info("Saved results to %s", output_csv)
| 215 | + |
| 216 | + |
def main():
    """CLI entrypoint for polygon conversion."""
    cli = argparse.ArgumentParser(
        description="Convert DeepForest bounding boxes to polygons using SAM2"
    )
    cli.add_argument("input", help="Path to input CSV with DeepForest predictions")
    cli.add_argument(
        "-o",
        "--output",
        help="Path to output CSV (default: input with '_polygons' suffix)",
    )
    cli.add_argument(
        "--model",
        default="facebook/sam2.1-hiera-small",
        help="SAM2 model name from HuggingFace (default: facebook/sam2.1-hiera-small)",
    )
    cli.add_argument(
        "--box-batch",
        type=int,
        default=32,
        help="Maximum number of boxes to process per forward pass (default: 32)",
    )
    cli.add_argument(
        "--image-root",
        default="",
        help="Root directory to prepend to image paths in CSV",
    )
    cli.add_argument(
        "--visualize",
        action="store_true",
        help="Create visualization images with polygons overlaid",
    )
    cli.add_argument(
        "--viz-output-dir",
        default=".",
        help="Directory to save visualization images (default: current directory)",
    )
    cli.add_argument(
        "--mask-threshold",
        type=float,
        default=0.5,
        help="Threshold for binarizing SAM2 mask outputs (default: 0.5)",
    )
    cli.add_argument(
        "--iou-threshold",
        type=float,
        default=0.5,
        help="Minimum IoU score to accept a polygon (default: 0.5)",
    )
    cli.add_argument(
        "--device",
        choices=["cuda", "mps", "cpu"],
        default=None,
        help="Device to use for inference (default: auto-detect mps > cuda > cpu)",
    )

    opts = cli.parse_args()

    logging.basicConfig(
        level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
    )

    # When no -o is given, write "<input stem>_polygons<ext>" next to the input.
    if opts.output is None:
        src = Path(opts.input)
        opts.output = str(src.with_name(f"{src.stem}_polygons{src.suffix}"))

    convert_boxes_to_polygons(
        opts.input,
        opts.output,
        model_name=opts.model,
        box_batch_size=opts.box_batch,
        image_root=opts.image_root,
        visualize=opts.visualize,
        viz_output_dir=opts.viz_output_dir,
        mask_threshold=opts.mask_threshold,
        iou_threshold=opts.iou_threshold,
        device=opts.device,
    )
| 296 | + |
| 297 | + |
# Run the converter only when executed as a script, not on import.
if __name__ == "__main__":
    main()
0 commit comments