Skip to content

Commit ba7588b

Browse files
add SAM box processing script
1 parent a13dee4 commit ba7588b

File tree

3 files changed

+327
-0
lines changed

3 files changed

+327
-0
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ docs = [
102102

103103
[project.scripts]
104104
deepforest = "deepforest.scripts.cli:main"
105+
deepforest-sam = "deepforest.scripts.sam:main"
105106

106107
[build-system]
107108
requires = ["setuptools>=61.0", "wheel"]

src/deepforest/scripts/sam.py

Lines changed: 299 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,299 @@
1+
"""Convert DeepForest bounding box predictions to polygons using SAM2."""
2+
3+
import argparse
4+
import logging
5+
import os
6+
from pathlib import Path
7+
8+
import numpy as np
9+
import pandas as pd
10+
import torch
11+
from PIL import Image
12+
from shapely import wkt
13+
from shapely.geometry import Polygon
14+
from tqdm import tqdm
15+
from transformers import Sam2Model, Sam2Processor
16+
17+
from deepforest.utilities import mask_to_polygon
18+
from deepforest.visualize import plot_results
19+
20+
logger = logging.getLogger(__name__)
21+
22+
23+
def load_sam2_model(model_name: str, device: str):
24+
"""Load SAM2 model and processor from HuggingFace.
25+
26+
Args:
27+
model_name: Name of the SAM2 model on HuggingFace
28+
device: Device to load model on ('cuda', 'mps', or 'cpu')
29+
30+
Returns:
31+
Tuple of (model, processor)
32+
"""
33+
processor = Sam2Processor.from_pretrained(model_name)
34+
model = Sam2Model.from_pretrained(model_name)
35+
model = model.to(device)
36+
return model, processor
37+
38+
39+
def process_image_group(
40+
image_path: str,
41+
detections: pd.DataFrame,
42+
model,
43+
processor,
44+
device: str,
45+
image_root: str = "",
46+
box_batch_size: int = 32,
47+
mask_threshold: float = 0.5,
48+
iou_threshold: float = 0.5,
49+
viz_output_dir: str = None,
50+
) -> list:
51+
"""Process all detections for a single image.
52+
53+
Args:
54+
image_path: Path to the image file
55+
detections: DataFrame of detections for this image
56+
model: SAM2 model
57+
processor: SAM2 processor
58+
device: Device to run inference on
59+
image_root: Root directory to prepend to image_path if needed
60+
box_batch_size: Maximum number of boxes to process per forward pass
61+
mask_threshold: Threshold for binarizing SAM2 mask outputs
62+
iou_threshold: Minimum IoU score to accept a polygon
63+
viz_output_dir: Directory to save visualizations (if not None)
64+
65+
Returns:
66+
List of WKT polygon strings
67+
"""
68+
full_path = os.path.join(image_root, image_path) if image_root else image_path
69+
image = Image.open(full_path).convert("RGB")
70+
71+
boxes = detections[["xmin", "ymin", "xmax", "ymax"]].values.tolist()
72+
73+
all_polygons = []
74+
for i in range(0, len(boxes), box_batch_size):
75+
box_chunk = boxes[i : i + box_batch_size]
76+
input_boxes = [box_chunk]
77+
78+
inputs = processor(images=image, input_boxes=input_boxes, return_tensors="pt").to(
79+
device
80+
)
81+
82+
with torch.no_grad():
83+
outputs = model(**inputs)
84+
85+
masks = processor.post_process_masks(
86+
outputs.pred_masks.cpu(),
87+
inputs["original_sizes"],
88+
binarize=False,
89+
mask_interpolation_mode="nearest",
90+
)[0]
91+
92+
iou_scores = outputs.iou_scores.cpu()
93+
94+
for i, mask_set in enumerate(masks):
95+
best_idx = iou_scores[0, i].argmax().item()
96+
best_iou = iou_scores[0, i, best_idx].item()
97+
98+
if best_iou < iou_threshold:
99+
all_polygons.append(Polygon().wkt)
100+
continue
101+
102+
best_mask = mask_set[best_idx]
103+
mask_np = (best_mask.numpy() > mask_threshold).astype(np.uint8)
104+
polygon = mask_to_polygon(mask_np)
105+
all_polygons.append(polygon.wkt)
106+
107+
if viz_output_dir is not None:
108+
viz_df = detections.copy()
109+
viz_df["polygon_geometry"] = all_polygons
110+
viz_df["geometry"] = viz_df["polygon_geometry"].apply(
111+
lambda x: wkt.loads(x) if pd.notna(x) else None
112+
)
113+
viz_df = viz_df[
114+
viz_df["geometry"].apply(lambda x: x is not None and not x.is_empty)
115+
]
116+
117+
if len(viz_df) > 0:
118+
if "label" not in viz_df.columns:
119+
viz_df["label"] = "Tree"
120+
if "score" not in viz_df.columns:
121+
viz_df["score"] = 1.0
122+
123+
full_path = os.path.join(image_root, image_path) if image_root else image_path
124+
with Image.open(full_path) as img:
125+
width, height = img.size
126+
127+
image_name = Path(image_path).stem
128+
viz_path = os.path.join(viz_output_dir, f"{image_name}_polygons.png")
129+
plot_results(
130+
results=viz_df,
131+
image=full_path,
132+
savedir=os.path.dirname(viz_path),
133+
basename=os.path.splitext(os.path.basename(viz_path))[0],
134+
height=height,
135+
width=width,
136+
show=False,
137+
)
138+
139+
return all_polygons
140+
141+
142+
def convert_boxes_to_polygons(
143+
input_csv: str,
144+
output_csv: str,
145+
model_name: str = "facebook/sam2.1-hiera-small",
146+
box_batch_size: int = 32,
147+
image_root: str = "",
148+
visualize: bool = False,
149+
viz_output_dir: str = ".",
150+
mask_threshold: float = 0.5,
151+
iou_threshold: float = 0.5,
152+
device: str = None,
153+
) -> None:
154+
"""Convert DeepForest bounding boxes to polygons using SAM2.
155+
156+
Args:
157+
input_csv: Path to input CSV with DeepForest predictions
158+
output_csv: Path to save output CSV with polygons
159+
model_name: HuggingFace model name for SAM2
160+
box_batch_size: Maximum number of boxes to process per forward pass
161+
image_root: Root directory to prepend to image paths in CSV
162+
visualize: Whether to create visualization images
163+
viz_output_dir: Directory to save visualization images
164+
mask_threshold: Threshold for binarizing SAM2 mask outputs
165+
iou_threshold: Minimum IoU score to accept a polygon
166+
device: Device to use ('cuda', 'mps', or 'cpu'). Auto-detects if None.
167+
"""
168+
df = pd.read_csv(input_csv)
169+
170+
required_cols = ["xmin", "ymin", "xmax", "ymax", "image_path"]
171+
missing_cols = [col for col in required_cols if col not in df.columns]
172+
if missing_cols:
173+
raise ValueError(f"Missing required columns: {missing_cols}")
174+
175+
if device is None:
176+
if torch.backends.mps.is_available():
177+
device = "mps"
178+
elif torch.cuda.is_available():
179+
device = "cuda"
180+
else:
181+
device = "cpu"
182+
183+
logger.info("Using device: %s", device)
184+
logger.info("Loading SAM2 model: %s", model_name)
185+
model, processor = load_sam2_model(model_name, device)
186+
187+
grouped = df.groupby("image_path")
188+
total_images = len(grouped)
189+
190+
all_polygons = []
191+
192+
if visualize:
193+
os.makedirs(viz_output_dir, exist_ok=True)
194+
195+
for image_path, group in tqdm(grouped, desc="Processing images", total=total_images):
196+
polygons = process_image_group(
197+
image_path,
198+
group,
199+
model,
200+
processor,
201+
device,
202+
image_root,
203+
box_batch_size,
204+
mask_threshold,
205+
iou_threshold,
206+
viz_output_dir=viz_output_dir if visualize else None,
207+
)
208+
all_polygons.extend(polygons)
209+
210+
df["polygon_geometry"] = all_polygons
211+
212+
os.makedirs(os.path.dirname(output_csv), exist_ok=True)
213+
df.to_csv(output_csv, index=False)
214+
logger.info("Saved results to %s", output_csv)
215+
216+
217+
def main():
218+
"""CLI entrypoint for polygon conversion."""
219+
parser = argparse.ArgumentParser(
220+
description="Convert DeepForest bounding boxes to polygons using SAM2"
221+
)
222+
parser.add_argument("input", help="Path to input CSV with DeepForest predictions")
223+
parser.add_argument(
224+
"-o",
225+
"--output",
226+
help="Path to output CSV (default: input with '_polygons' suffix)",
227+
)
228+
parser.add_argument(
229+
"--model",
230+
default="facebook/sam2.1-hiera-small",
231+
help="SAM2 model name from HuggingFace (default: facebook/sam2.1-hiera-small)",
232+
)
233+
parser.add_argument(
234+
"--box-batch",
235+
type=int,
236+
default=32,
237+
help="Maximum number of boxes to process per forward pass (default: 32)",
238+
)
239+
parser.add_argument(
240+
"--image-root",
241+
default="",
242+
help="Root directory to prepend to image paths in CSV",
243+
)
244+
parser.add_argument(
245+
"--visualize",
246+
action="store_true",
247+
help="Create visualization images with polygons overlaid",
248+
)
249+
parser.add_argument(
250+
"--viz-output-dir",
251+
default=".",
252+
help="Directory to save visualization images (default: current directory)",
253+
)
254+
parser.add_argument(
255+
"--mask-threshold",
256+
type=float,
257+
default=0.5,
258+
help="Threshold for binarizing SAM2 mask outputs (default: 0.5)",
259+
)
260+
parser.add_argument(
261+
"--iou-threshold",
262+
type=float,
263+
default=0.5,
264+
help="Minimum IoU score to accept a polygon (default: 0.5)",
265+
)
266+
parser.add_argument(
267+
"--device",
268+
choices=["cuda", "mps", "cpu"],
269+
default=None,
270+
help="Device to use for inference (default: auto-detect mps > cuda > cpu)",
271+
)
272+
273+
args = parser.parse_args()
274+
275+
logging.basicConfig(
276+
level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
277+
)
278+
279+
if args.output is None:
280+
input_path = Path(args.input)
281+
output_path = input_path.parent / f"{input_path.stem}_polygons{input_path.suffix}"
282+
args.output = str(output_path)
283+
284+
convert_boxes_to_polygons(
285+
args.input,
286+
args.output,
287+
model_name=args.model,
288+
box_batch_size=args.box_batch,
289+
image_root=args.image_root,
290+
visualize=args.visualize,
291+
viz_output_dir=args.viz_output_dir,
292+
mask_threshold=args.mask_threshold,
293+
iou_threshold=args.iou_threshold,
294+
device=args.device,
295+
)
296+
297+
298+
if __name__ == "__main__":
299+
main()

src/deepforest/utilities.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import os
33
import warnings
44

5+
import cv2
56
import geopandas as gpd
67
import numpy as np
78
import pandas as pd
@@ -10,12 +11,38 @@
1011
import xmltodict
1112
from omegaconf import DictConfig, OmegaConf
1213
from PIL import Image
14+
from shapely.geometry import Polygon
1315
from tqdm import tqdm
1416

1517
from deepforest import _ROOT
1618
from deepforest.conf.schema import Config as StructuredConfig
1719

1820

21+
def mask_to_polygon(mask: np.ndarray) -> Polygon:
22+
"""Convert a binary mask to a shapely Polygon.
23+
24+
Args:
25+
mask: Binary mask array (H, W)
26+
27+
Returns:
28+
Shapely Polygon representing the mask contour
29+
"""
30+
contours, _ = cv2.findContours(
31+
mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
32+
)
33+
34+
if len(contours) == 0:
35+
return Polygon()
36+
37+
largest_contour = max(contours, key=lambda x: x.shape[0])
38+
coords = largest_contour.squeeze()
39+
40+
if len(coords.shape) == 1 or coords.shape[0] < 3:
41+
return Polygon()
42+
43+
return Polygon(coords)
44+
45+
1946
def load_config(
2047
config_name: str = "config.yaml",
2148
overrides: DictConfig | dict = None,

0 commit comments

Comments
 (0)