88import torch
99
1010from deepforest import main as df_main
11- from deepforest .utilities import read_file
11+ from deepforest .utilities import read_file , format_geometry
1212from deepforest .visualize import plot_results
1313
1414from milliontrees import get_dataset
@@ -61,20 +61,20 @@ def format_deepforest_predictions(
6161 "images" )
6262 formatted_pred ["image_path" ] = basename
6363 else :
64- pred .root_dir = os .path .join (dataset ._data_dir ._str , "images" )
65- pred ["image_path" ] = basename
64+ formatted_pred = format_geometry (pred )
65+ formatted_pred .root_dir = os .path .join (dataset ._data_dir ._str ,
66+ "images" )
67+ formatted_pred ["image_path" ] = basename
6668
6769 y_pred = {
6870 "y" : torch .tensor (
69- pred [["xmin" , "ymin" , "xmax" , "ymax" ]].values .astype (
71+ formatted_pred [["xmin" , "ymin" , "xmax" , "ymax" ]].values .astype (
7072 "float32" )),
71- "labels" : torch .tensor (pred .label .values .astype (np .int64 )),
73+ "labels" : torch .tensor (formatted_pred .label .values .astype (np .int64 )),
7274 "scores" : torch .tensor (
73- pred .score .values .astype ("float32" )),
75+ formatted_pred .score .values .astype ("float32" )),
7476 }
7577
76- formatted_pred = pred
77-
7878 batch_y_pred .append (y_pred )
7979 formatted_predictions .append (formatted_pred )
8080
@@ -94,8 +94,8 @@ def plot_eval_result(
9494
9595 # Ground truth
9696 gt_df = read_file (
97- pd .DataFrame (image_targets ["y" ]. numpy () ,
98- columns = ["xmin" , "ymin" , "xmax" , "ymax" ]))
97+ pd .DataFrame (image_targets ["bboxes" ] ,
98+ columns = ["xmin" , "ymin" , "xmax" , "ymax" ]), label = "Tree" )
9999 gt_df ["label" ] = "Tree"
100100
101101 # Predictions
@@ -107,8 +107,8 @@ def plot_eval_result(
107107 image = image_tensor .permute (1 , 2 , 0 ).numpy () * 255
108108
109109 # Simple recall example for logging
110- recall = dataset .metrics ["recall" ]._recall (image_targets ["y " ],
111- y_pred .get ("y " ,
110+ recall = dataset .metrics ["recall" ]._recall (image_targets ["bboxes " ],
111+ y_pred .get ("bboxes " ,
112112 torch .zeros (
113113 (0 , 4 ))),
114114 iou_threshold = 0.3 )
@@ -159,6 +159,7 @@ def main():
159159 # Load model
160160 model = df_main .deepforest ()
161161 model .load_model ("weecology/deepforest-tree" )
162+ model .eval ()
162163
163164 # Load dataset
164165 polygon_dataset = get_dataset ("TreePolygons" ,
0 commit comments