# CustomMatchingEvaluator.py
import numpy as np
import torch
from detectron2.evaluation.evaluator import DatasetEvaluator
from detectron2.structures import Boxes, pairwise_iou
from detectron2.utils import comm
def perform_box_matching(gt_boxes: Boxes, pred_boxes: Boxes, iou_threshold: float = 0.5):
    """
    Greedily match ground-truth boxes to predicted boxes by IoU.

    Ground-truth boxes are visited in index order; each is paired with the
    not-yet-used prediction of highest IoU, provided that IoU is at least
    ``iou_threshold`` (and strictly positive). Each prediction is used at
    most once; ties go to the lowest prediction index.

    Args:
        gt_boxes (Boxes): ground truth boxes.
        pred_boxes (Boxes): predicted boxes.
        iou_threshold (float): minimum IoU required to consider a match.

    Returns:
        list of tuples: each tuple is (gt_index, pred_index).
    """
    if len(gt_boxes) == 0 or len(pred_boxes) == 0:
        return []
    # Move gt_boxes to the device of pred_boxes before computing IoU.
    gt_boxes = Boxes(gt_boxes.tensor.to(pred_boxes.tensor.device))
    # Pairwise IoU matrix; shape: (num_pred, num_gt).
    iou_matrix = pairwise_iou(pred_boxes, gt_boxes)
    matches = []
    # Mask of predictions not yet consumed by an earlier ground truth.
    available = torch.ones(len(pred_boxes), dtype=torch.bool, device=iou_matrix.device)
    for gt_idx in range(len(gt_boxes)):
        if not available.any():
            # Every prediction is already matched; no further match possible.
            break
        # Best remaining prediction for this gt: mask used columns to -1
        # (below any real IoU, which is >= 0) and take the column max.
        # torch.max returns the first (lowest) index among tied maxima,
        # matching the original first-best greedy tie-break.
        column = iou_matrix[:, gt_idx].clone()
        column[~available] = -1.0
        best_iou, best_pred = torch.max(column, dim=0)
        iou = best_iou.item()
        # Same admission rule as the original scalar loop:
        # iou must meet the threshold and be strictly positive.
        if iou >= iou_threshold and iou > 0.0:
            pred_idx = int(best_pred.item())
            matches.append((gt_idx, pred_idx))
            available[pred_idx] = False
    return matches
class CustomMatchingEvaluator(DatasetEvaluator):
    """
    A custom evaluator that:
    1. Retrieves for each image the ground truth boxes and predicted boxes.
    2. Performs matching using a greedy IoU-based algorithm with a specified threshold.
    3. Returns aggregated detection metrics (TP/FP/FN, precision, recall) and the
       mean absolute aperture error over matched pairs.
    """

    def __init__(self, iou_threshold: float = 0.5):
        # Minimum IoU for a prediction to be matched to a ground-truth box.
        self.iou_threshold = iou_threshold
        self.reset()

    def reset(self):
        # Per-image storage: image_id -> dict with keys
        # "gt_boxes" (Boxes), "pred_boxes" (Boxes),
        # "gt_apertures" / "pred_apertures" (flat numpy arrays, possibly empty).
        self.data = {}
        # Per-match records published by evaluate(); each item is
        # [pred_box [x0,y0,x1,y1], distance, gt_aperture, pred_aperture, abs_error].
        self.per_match_details = []

    def process(self, inputs, outputs):
        """
        Accumulate one batch of ground truth and predictions.

        Expects each input dict to have an "instances" field in ground truth (with gt_boxes)
        and each output to have an "instances" field in predictions (with pred_boxes).
        Missing instances fall back to empty Boxes so evaluate() never sees None.
        """
        for input_dict, output in zip(inputs, outputs):
            # Fall back to file_name, then "unknown", when no image_id is given.
            image_id = input_dict.get("image_id", input_dict.get("file_name", "unknown"))
            # Process ground truth.
            gt_instances = input_dict.get("instances", None)
            if gt_instances is not None and hasattr(gt_instances, "gt_boxes"):
                gt_boxes = gt_instances.gt_boxes
            else:
                gt_boxes = Boxes(torch.empty((0, 4)))
            # Process predictions.
            pred_instances = output.get("instances", None)
            if pred_instances is not None and hasattr(pred_instances, "pred_boxes"):
                pred_boxes = pred_instances.pred_boxes
            else:
                pred_boxes = Boxes(torch.empty((0, 4)))
            # Store data for this image. Apertures are optional extra fields;
            # absent fields become empty arrays so index checks in evaluate()
            # simply skip them.
            self.data[image_id] = {
                "gt_boxes": gt_boxes,
                "pred_boxes": pred_boxes,
                # Optionally add apertures if needed:
                "gt_apertures": gt_instances.gt_aperture.cpu().numpy().flatten() if (gt_instances is not None and hasattr(gt_instances, "gt_aperture")) else np.array([]),
                "pred_apertures": pred_instances.aperture.cpu().numpy().flatten() if (pred_instances is not None and hasattr(pred_instances, "aperture")) else np.array([]),
            }

    def evaluate(self):
        """
        For each image in self.data, perform box matching and compute:
        - Matches (TP)
        - False Positives (FP)
        - False Negatives (FN)
        Also compute aperture errors on the matched pairs.

        Runs on all workers; only the main process merges the gathered
        per-worker data and returns results (others return {}).

        Returns:
            A dictionary containing:
            - The aggregated aperture evaluation (mean absolute error, match count).
            - Detection metrics: total TP, FP, FN, Precision, and Recall.
        """
        # Gather every worker's per-image data; non-main workers stop here.
        all_data = comm.all_gather(self.data)
        if not comm.is_main_process():
            return {}
        # Merge all per-worker dictionaries (image_ids assumed unique across workers).
        merged_data = {}
        for data in all_data:
            merged_data.update(data)
        self.data = merged_data
        results = {}
        aperture_errors = []
        # Collect all per-match details here, then publish to self.per_match_details.
        per_match_details = []
        # Counters for detection metrics (across all images).
        total_tp = 0
        total_fp = 0
        total_fn = 0
        for image_id, item in self.data.items():
            gt_boxes = item["gt_boxes"]
            pred_boxes = item["pred_boxes"]
            gt_apertures = item["gt_apertures"]
            pred_apertures = item["pred_apertures"]
            # Perform box matching based on IoU:
            matches = perform_box_matching(gt_boxes, pred_boxes, self.iou_threshold)
            # Detection metrics for this image:
            tp = len(matches)  # True positives: number of matched boxes.
            fp = len(pred_boxes) - tp  # False positives: predictions that didn't match.
            fn = len(gt_boxes) - tp  # False negatives: ground truth objects with no match.
            total_tp += tp
            total_fp += fp
            total_fn += fn
            # Compute aperture errors for the matched boxes.
            for gt_idx, pred_idx in matches:
                # Skip pairs whose aperture data is missing (empty arrays from process()).
                if gt_idx < len(gt_apertures) and pred_idx < len(pred_apertures):
                    # NOTE(review): apertures are compared as stored — no
                    # denormalization is applied here despite an earlier note
                    # about a [0, 1] -> [0, 80] degree scale; confirm units.
                    gt_aperture = gt_apertures[gt_idx]
                    pred_aperture = pred_apertures[pred_idx]
                    error = abs(gt_aperture - pred_aperture)
                    aperture_errors.append(error)
                    # Grab the matched predicted box.
                    box = pred_boxes.tensor[pred_idx].detach().cpu().numpy().tolist()  # [x0,y0,x1,y1]
                    y0, y1 = float(box[1]), float(box[3])
                    # ((y1 - y0) + y0) simplifies to y1; the 103.0/512.0 factor
                    # presumably converts a 512-px image coordinate to physical
                    # units (103 of something) — TODO confirm with the dataset.
                    distance = ((y1 - y0) + y0) * (103.0 / 512.0)
                    per_match_details.append([box, float(distance), gt_aperture, pred_aperture, float(error)])
        # Compute overall detection precision and recall (0.0 when undefined).
        precision = total_tp / (total_tp + total_fp) if (total_tp + total_fp) > 0 else 0.0
        recall = total_tp / (total_tp + total_fn) if (total_tp + total_fn) > 0 else 0.0
        # Compute mean aperture error; NaN when no pair had aperture data.
        mean_aperture_error = np.mean(aperture_errors) if aperture_errors else float("nan")
        # Aggregate the results.
        results["aperture_evaluation"] = {
            "mean_absolute_error": mean_aperture_error,
            "num_matches": len(aperture_errors),
        }
        results["detection_metrics"] = {
            "true_positives": total_tp,
            "false_positives": total_fp,
            "false_negatives": total_fn,
            "precision": precision,
            "recall": recall,
        }
        self.per_match_details = per_match_details
        return results