Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions tests/test_benchmark.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import pytest
from deepforest import main, get_data

def test_benchmark_release():
"""
Benchmark test to ensure the specific release version of the model
produces consistent results.
"""
# Load the model using a SPECIFIC revision (Commit SHA)
release_sha = "cc21436bc5d572dde8ff5f93c1e71a32f563cace"

m = main.deepforest()
m.load_model("weecology/deepforest-tree", revision=release_sha)

csv_file = get_data("OSBS_029.csv")
results = m.evaluate(csv_file, iou_threshold=0.4)

# Strict Assertions (for The "Benchmark")
assert results["box_precision"] == pytest.approx(0.8, abs=0.01)
assert results["box_recall"] == pytest.approx(0.7213, abs=0.01)
11 changes: 7 additions & 4 deletions tests/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -511,12 +511,16 @@ def test_predict_tile_from_array(m, path):

assert not prediction.empty

def test_evaluate(m, tmpdir):
def test_evaluate(m):
csv_file = get_data("OSBS_029.csv")
results = m.evaluate(csv_file, iou_threshold=0.4)

assert np.round(results["box_precision"], 2) > 0.5
assert np.round(results["box_recall"], 2) > 0.5
# Relaxed assertions (Sanity Check only)
# Allows model improvements without breaking tests
assert results["box_precision"] > 0.7
assert results["box_recall"] > 0.5

# Structure and Label checks
assert len(results["results"].predicted_label.dropna().unique()) == 1
assert results["results"].predicted_label.dropna().unique()[0] == "Tree"
assert results["predictions"].shape[0] > 0
Expand All @@ -525,7 +529,6 @@ def test_evaluate(m, tmpdir):
df = pd.read_csv(csv_file)
assert results["results"].shape[0] == df.shape[0]


def test_train_callbacks(m):
csv_file = get_data("example.csv")
root_dir = os.path.dirname(csv_file)
Expand Down