|
| 1 | +"""Tests for SAM polygon generation CLI tool.""" |
| 2 | +import os |
| 3 | +import subprocess |
| 4 | +import sys |
| 5 | +from importlib.resources import files |
| 6 | + |
| 7 | +import pandas as pd |
| 8 | +from shapely import wkt |
| 9 | +from shapely.geometry import box |
| 10 | + |
| 11 | +from deepforest import get_data |
| 12 | +from deepforest.scripts.sam import ( |
| 13 | + convert_boxes_to_polygons, |
| 14 | + load_sam2_model, |
| 15 | + process_image_group, |
| 16 | +) |
| 17 | + |
| 18 | +SAM_SCRIPT = files("deepforest.scripts").joinpath("sam.py") |
| 19 | + |
| 20 | + |
| 21 | +def test_load_sam2_model(): |
| 22 | + """Test SAM2 model sucessfully loads.""" |
| 23 | + from transformers import Sam2Model, Sam2Processor |
| 24 | + |
| 25 | + model, processor = load_sam2_model("facebook/sam2.1-hiera-small", device="cpu") |
| 26 | + |
| 27 | + assert isinstance(model, Sam2Model) |
| 28 | + assert isinstance(processor, Sam2Processor) |
| 29 | + |
| 30 | + |
| 31 | +def test_process_image_group(): |
| 32 | + """Test processing a single image with detections.""" |
| 33 | + test_csv = get_data("OSBS_029.csv") |
| 34 | + test_image_dir = os.path.dirname(get_data("OSBS_029.tif")) |
| 35 | + |
| 36 | + df = pd.read_csv(test_csv) |
| 37 | + |
| 38 | + model, processor = load_sam2_model("facebook/sam2.1-hiera-small", device="cpu") |
| 39 | + |
| 40 | + polygons = process_image_group( |
| 41 | + image_path="OSBS_029.tif", |
| 42 | + detections=df, |
| 43 | + model=model, |
| 44 | + processor=processor, |
| 45 | + device="cpu", |
| 46 | + image_root=test_image_dir, |
| 47 | + box_batch_size=2 |
| 48 | + ) |
| 49 | + |
| 50 | + assert len(polygons) == len(df) |
| 51 | + |
| 52 | + # Verify all are valid WKT strings |
| 53 | + for poly_wkt in polygons: |
| 54 | + poly = wkt.loads(poly_wkt) |
| 55 | + assert poly is not None |
| 56 | + |
| 57 | + |
| 58 | +def test_convert_boxes_to_polygons(tmp_path): |
| 59 | + """Test the main conversion function directly.""" |
| 60 | + test_csv = get_data("OSBS_029.csv") |
| 61 | + test_image_dir = os.path.dirname(get_data("OSBS_029.tif")) |
| 62 | + output_csv = tmp_path / "polygons.csv" |
| 63 | + viz_dir = tmp_path / "viz" |
| 64 | + |
| 65 | + input_df = pd.read_csv(test_csv) |
| 66 | + |
| 67 | + convert_boxes_to_polygons( |
| 68 | + input_csv=test_csv, |
| 69 | + output_csv=str(output_csv), |
| 70 | + image_root=test_image_dir, |
| 71 | + box_batch_size=2, |
| 72 | + visualize=True, |
| 73 | + viz_output_dir=viz_dir |
| 74 | + ) |
| 75 | + |
| 76 | + assert output_csv.exists() |
| 77 | + |
| 78 | + result_df = pd.read_csv(output_csv) |
| 79 | + assert "polygon_geometry" in result_df.columns |
| 80 | + assert len(result_df) == len(input_df) |
| 81 | + |
| 82 | + # Validate all polygons are valid WKT |
| 83 | + for poly_wkt in result_df["polygon_geometry"]: |
| 84 | + poly = wkt.loads(poly_wkt) |
| 85 | + assert poly is not None |
| 86 | + |
| 87 | + assert viz_dir.exists() |
| 88 | + viz_files = list(viz_dir.glob("*.png")) |
| 89 | + assert len(viz_files) > 0, "No visualization files were created" |
| 90 | + |
| 91 | + |
| 92 | +def test_polygon_box_overlap(): |
| 93 | + """Test that output polygons overlap with input bounding boxes.""" |
| 94 | + test_csv = get_data("OSBS_029.csv") |
| 95 | + test_image_dir = os.path.dirname(get_data("OSBS_029.tif")) |
| 96 | + |
| 97 | + df = pd.read_csv(test_csv) |
| 98 | + |
| 99 | + model, processor = load_sam2_model("facebook/sam2.1-hiera-small", device="cpu") |
| 100 | + |
| 101 | + polygons = process_image_group( |
| 102 | + image_path="OSBS_029.tif", |
| 103 | + detections=df, |
| 104 | + model=model, |
| 105 | + processor=processor, |
| 106 | + device="cpu", |
| 107 | + image_root=test_image_dir, |
| 108 | + box_batch_size=2 |
| 109 | + ) |
| 110 | + |
| 111 | + # Check that each polygon overlaps with its corresponding box |
| 112 | + for idx, (_, row) in enumerate(df.iterrows()): |
| 113 | + poly = wkt.loads(polygons[idx]) |
| 114 | + |
| 115 | + # Skip empty polygons |
| 116 | + if poly.is_empty: |
| 117 | + continue |
| 118 | + |
| 119 | + # Create box polygon from detection |
| 120 | + bbox = box(row["xmin"], row["ymin"], row["xmax"], row["ymax"]) |
| 121 | + |
| 122 | + # Calculate intersection |
| 123 | + intersection = poly.intersection(bbox) |
| 124 | + |
| 125 | + # Assert there is overlap |
| 126 | + assert intersection.area > 0, f"Polygon {idx} has no overlap with its bounding box" |
| 127 | + |
| 128 | + |
| 129 | +def test_sam_cli_end_to_end(tmp_path): |
| 130 | + """Test complete CLI workflow with visualization.""" |
| 131 | + test_csv = get_data("OSBS_029.csv") |
| 132 | + test_image_dir = os.path.dirname(get_data("OSBS_029.tif")) |
| 133 | + |
| 134 | + df = pd.read_csv(test_csv) |
| 135 | + |
| 136 | + output_csv = tmp_path / "polygons.csv" |
| 137 | + viz_dir = tmp_path / "viz" |
| 138 | + |
| 139 | + args = [ |
| 140 | + sys.executable, |
| 141 | + str(SAM_SCRIPT), |
| 142 | + test_csv, |
| 143 | + "-o", str(output_csv), |
| 144 | + "--image-root", test_image_dir, |
| 145 | + "--box-batch", "2", |
| 146 | + "--device", "cpu", |
| 147 | + "--mask-threshold", "0.5", |
| 148 | + "--iou-threshold", "0.5", |
| 149 | + "--visualize", |
| 150 | + "--viz-output-dir", str(viz_dir) |
| 151 | + ] |
| 152 | + |
| 153 | + result = subprocess.run( |
| 154 | + args, |
| 155 | + stdout=subprocess.PIPE, |
| 156 | + stderr=subprocess.PIPE, |
| 157 | + text=True, |
| 158 | + timeout=300 # 5 minute timeout for model loading |
| 159 | + ) |
| 160 | + |
| 161 | + assert result.returncode == 0, f"stderr:\n{result.stderr}\nstdout:\n{result.stdout}" |
| 162 | + assert output_csv.exists(), f"Expected output file not found: {output_csv}" |
| 163 | + |
| 164 | + # Check output CSV |
| 165 | + df_out = pd.read_csv(output_csv) |
| 166 | + assert "polygon_geometry" in df_out.columns |
| 167 | + assert len(df_out) == len(df) |
| 168 | + |
| 169 | + # Verify polygons are valid WKT |
| 170 | + valid_polygons = 0 |
| 171 | + for poly_wkt in df_out["polygon_geometry"]: |
| 172 | + poly = wkt.loads(poly_wkt) |
| 173 | + assert poly is not None |
| 174 | + if not poly.is_empty: |
| 175 | + valid_polygons += 1 |
| 176 | + |
| 177 | + # At least one polygon should be non-empty |
| 178 | + assert valid_polygons > 0, "All polygons are empty" |
| 179 | + |
| 180 | + # Check visualization created |
| 181 | + assert viz_dir.exists() |
| 182 | + viz_files = list(viz_dir.glob("*.png")) |
| 183 | + assert len(viz_files) > 0, "No visualization files were created" |
0 commit comments