Skip to content

Commit f275688

Browse files
add tests
1 parent ba7588b commit f275688

File tree

3 files changed

+200
-7
lines changed

3 files changed

+200
-7
lines changed

src/deepforest/scripts/sam.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
logger = logging.getLogger(__name__)
2121

2222

23-
def load_sam2_model(model_name: str, device: str):
23+
def load_sam2_model(model_name: str, device: str) -> tuple[Sam2Model, Sam2Processor]:
2424
"""Load SAM2 model and processor from HuggingFace.
2525
2626
Args:
@@ -39,15 +39,15 @@ def load_sam2_model(model_name: str, device: str):
3939
def process_image_group(
4040
image_path: str,
4141
detections: pd.DataFrame,
42-
model,
43-
processor,
42+
model: Sam2Model,
43+
processor: Sam2Processor,
4444
device: str,
4545
image_root: str = "",
4646
box_batch_size: int = 32,
4747
mask_threshold: float = 0.5,
4848
iou_threshold: float = 0.5,
49-
viz_output_dir: str = None,
50-
) -> list:
49+
viz_output_dir: str | None = None,
50+
) -> list[str]:
5151
"""Process all detections for a single image.
5252
5353
Args:
@@ -149,7 +149,7 @@ def convert_boxes_to_polygons(
149149
viz_output_dir: str = ".",
150150
mask_threshold: float = 0.5,
151151
iou_threshold: float = 0.5,
152-
device: str = None,
152+
device: str | None = None,
153153
) -> None:
154154
"""Convert DeepForest bounding boxes to polygons using SAM2.
155155
@@ -214,7 +214,7 @@ def convert_boxes_to_polygons(
214214
logger.info("Saved results to %s", output_csv)
215215

216216

217-
def main():
217+
def main() -> None:
218218
"""CLI entrypoint for polygon conversion."""
219219
parser = argparse.ArgumentParser(
220220
description="Convert DeepForest bounding boxes to polygons using SAM2"

tests/test_sam.py

Lines changed: 183 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,183 @@
1+
"""Tests for SAM polygon generation CLI tool."""
2+
import os
3+
import subprocess
4+
import sys
5+
from importlib.resources import files
6+
7+
import pandas as pd
8+
from shapely import wkt
9+
from shapely.geometry import box
10+
11+
from deepforest import get_data
12+
from deepforest.scripts.sam import (
13+
convert_boxes_to_polygons,
14+
load_sam2_model,
15+
process_image_group,
16+
)
17+
18+
SAM_SCRIPT = files("deepforest.scripts").joinpath("sam.py")
19+
20+
21+
def test_load_sam2_model():
22+
"""Test that the SAM2 model loads successfully."""
23+
from transformers import Sam2Model, Sam2Processor
24+
25+
model, processor = load_sam2_model("facebook/sam2.1-hiera-small", device="cpu")
26+
27+
assert isinstance(model, Sam2Model)
28+
assert isinstance(processor, Sam2Processor)
29+
30+
31+
def test_process_image_group():
32+
"""Test processing a single image with detections."""
33+
test_csv = get_data("OSBS_029.csv")
34+
test_image_dir = os.path.dirname(get_data("OSBS_029.tif"))
35+
36+
df = pd.read_csv(test_csv)
37+
38+
model, processor = load_sam2_model("facebook/sam2.1-hiera-small", device="cpu")
39+
40+
polygons = process_image_group(
41+
image_path="OSBS_029.tif",
42+
detections=df,
43+
model=model,
44+
processor=processor,
45+
device="cpu",
46+
image_root=test_image_dir,
47+
box_batch_size=2
48+
)
49+
50+
assert len(polygons) == len(df)
51+
52+
# Verify all are valid WKT strings
53+
for poly_wkt in polygons:
54+
poly = wkt.loads(poly_wkt)
55+
assert poly is not None
56+
57+
58+
def test_convert_boxes_to_polygons(tmp_path):
59+
"""Test the main conversion function directly."""
60+
test_csv = get_data("OSBS_029.csv")
61+
test_image_dir = os.path.dirname(get_data("OSBS_029.tif"))
62+
output_csv = tmp_path / "polygons.csv"
63+
viz_dir = tmp_path / "viz"
64+
65+
input_df = pd.read_csv(test_csv)
66+
67+
convert_boxes_to_polygons(
68+
input_csv=test_csv,
69+
output_csv=str(output_csv),
70+
image_root=test_image_dir,
71+
box_batch_size=2,
72+
visualize=True,
73+
viz_output_dir=viz_dir
74+
)
75+
76+
assert output_csv.exists()
77+
78+
result_df = pd.read_csv(output_csv)
79+
assert "polygon_geometry" in result_df.columns
80+
assert len(result_df) == len(input_df)
81+
82+
# Validate all polygons are valid WKT
83+
for poly_wkt in result_df["polygon_geometry"]:
84+
poly = wkt.loads(poly_wkt)
85+
assert poly is not None
86+
87+
assert viz_dir.exists()
88+
viz_files = list(viz_dir.glob("*.png"))
89+
assert len(viz_files) > 0, "No visualization files were created"
90+
91+
92+
def test_polygon_box_overlap():
93+
"""Test that output polygons overlap with input bounding boxes."""
94+
test_csv = get_data("OSBS_029.csv")
95+
test_image_dir = os.path.dirname(get_data("OSBS_029.tif"))
96+
97+
df = pd.read_csv(test_csv)
98+
99+
model, processor = load_sam2_model("facebook/sam2.1-hiera-small", device="cpu")
100+
101+
polygons = process_image_group(
102+
image_path="OSBS_029.tif",
103+
detections=df,
104+
model=model,
105+
processor=processor,
106+
device="cpu",
107+
image_root=test_image_dir,
108+
box_batch_size=2
109+
)
110+
111+
# Check that each polygon overlaps with its corresponding box
112+
for idx, (_, row) in enumerate(df.iterrows()):
113+
poly = wkt.loads(polygons[idx])
114+
115+
# Skip empty polygons
116+
if poly.is_empty:
117+
continue
118+
119+
# Create box polygon from detection
120+
bbox = box(row["xmin"], row["ymin"], row["xmax"], row["ymax"])
121+
122+
# Calculate intersection
123+
intersection = poly.intersection(bbox)
124+
125+
# Assert there is overlap
126+
assert intersection.area > 0, f"Polygon {idx} has no overlap with its bounding box"
127+
128+
129+
def test_sam_cli_end_to_end(tmp_path):
130+
"""Test complete CLI workflow with visualization."""
131+
test_csv = get_data("OSBS_029.csv")
132+
test_image_dir = os.path.dirname(get_data("OSBS_029.tif"))
133+
134+
df = pd.read_csv(test_csv)
135+
136+
output_csv = tmp_path / "polygons.csv"
137+
viz_dir = tmp_path / "viz"
138+
139+
args = [
140+
sys.executable,
141+
str(SAM_SCRIPT),
142+
test_csv,
143+
"-o", str(output_csv),
144+
"--image-root", test_image_dir,
145+
"--box-batch", "2",
146+
"--device", "cpu",
147+
"--mask-threshold", "0.5",
148+
"--iou-threshold", "0.5",
149+
"--visualize",
150+
"--viz-output-dir", str(viz_dir)
151+
]
152+
153+
result = subprocess.run(
154+
args,
155+
stdout=subprocess.PIPE,
156+
stderr=subprocess.PIPE,
157+
text=True,
158+
timeout=300 # 5 minute timeout for model loading
159+
)
160+
161+
assert result.returncode == 0, f"stderr:\n{result.stderr}\nstdout:\n{result.stdout}"
162+
assert output_csv.exists(), f"Expected output file not found: {output_csv}"
163+
164+
# Check output CSV
165+
df_out = pd.read_csv(output_csv)
166+
assert "polygon_geometry" in df_out.columns
167+
assert len(df_out) == len(df)
168+
169+
# Verify polygons are valid WKT
170+
valid_polygons = 0
171+
for poly_wkt in df_out["polygon_geometry"]:
172+
poly = wkt.loads(poly_wkt)
173+
assert poly is not None
174+
if not poly.is_empty:
175+
valid_polygons += 1
176+
177+
# At least one polygon should be non-empty
178+
assert valid_polygons > 0, "All polygons are empty"
179+
180+
# Check visualization created
181+
assert viz_dir.exists()
182+
viz_files = list(viz_dir.glob("*.png"))
183+
assert len(viz_files) > 0, "No visualization files were created"

tests/test_utilities.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -652,3 +652,13 @@ def test_format_geometry_polygon():
652652
# Format geometry should raise ValueError since polygon predictions are not supported
653653
with pytest.raises(ValueError, match="Polygon predictions are not yet supported for formatting"):
654654
utilities.format_geometry(prediction, geom_type="polygon")
655+
656+
def test_empty_mask_to_poly():
657+
"""Test handling of empty masks."""
658+
659+
empty_mask = np.zeros((100, 100), dtype=np.uint8)
660+
661+
polygon = utilities.mask_to_polygon(empty_mask)
662+
663+
assert isinstance(polygon, shapely.geometry.Polygon)
664+
assert polygon.is_empty

0 commit comments

Comments
 (0)