
Commit acbaf4f

panagiotamoraiti authored, with pre-commit-ci[bot], Borda, and Copilot

Correct confusion matrix calculation in function evaluate_detection_batch (#1853)

* Correct confusion matrix calculation in function evaluate_detection_batch
* Correct confusion matrix computation
* Add 3 more tests for empty detections / ground truths
* Fix minor issues with too-long lines
* Replace deprecated mock_detections with _create_detections
* Fix indentation in test assertions and add missing `self` parameter in `test_confusion_matrix`
* Update metric computations to improve numerical stability and replace deprecated NumPy functions
* Add IoU+class matching tests and synthetic dataset fixtures for detection metrics
* Use `Optional` for type hinting the `classes` parameter in `_yolo_dataset_factory`; import `Optional` from `typing`
* Refactor detection metric tests to simplify confusion matrix assertions and enhance IoU+class matching validation
* Apply suggestions from code review
* fix(pre_commit): 🎨 auto format pre-commit hooks

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent 124be19 commit acbaf4f
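
For orientation, the confusion matrix built in this commit is (num_classes + 1) x (num_classes + 1): rows index ground-truth classes, columns index predicted classes, and the extra row and column collect false positives and false negatives. A quick standalone sketch of the layout (illustration only, not part of the commit):

import numpy as np

num_classes = 2
m = np.zeros((num_classes + 1, num_classes + 1))
m[0, 0] += 1            # TP: GT class 0 matched by a class-0 detection
m[1, 0] += 1            # misclassification: GT class 1 predicted as class 0
m[1, num_classes] += 1  # FN: GT of class 1 with no matching detection
m[num_classes, 1] += 1  # FP: class-1 detection with no matching GT
print(m)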

File tree

7 files changed: +823 -63 lines changed


src/supervision/metrics/detection.py

Lines changed: 80 additions & 28 deletions
@@ -148,7 +148,6 @@ def from_detections(
 
         ```
         """
-
         prediction_tensors = []
         target_tensors = []
         for prediction, target in zip(predictions, targets):
@@ -274,9 +273,28 @@ def evaluate_detection_batch(
         """
         result_matrix = np.zeros((num_classes + 1, num_classes + 1))
 
+        # Filter predictions by confidence threshold
         conf_idx = 5
         confidence = predictions[:, conf_idx]
-        detection_batch_filtered = predictions[confidence > conf_threshold]
+        detection_batch_filtered = predictions[confidence >= conf_threshold]
+
+        if len(detection_batch_filtered) == 0:
+            # No detections pass confidence threshold - all GT are FN
+            class_id_idx = 4
+            true_classes = np.array(targets[:, class_id_idx], dtype=np.int16)
+            for gt_class in true_classes:
+                result_matrix[gt_class, num_classes] += 1
+            return result_matrix
+
+        if len(targets) == 0:
+            # No ground truth - all detections are FP
+            class_id_idx = 4
+            detection_classes = np.array(
+                detection_batch_filtered[:, class_id_idx], dtype=np.int16
+            )
+            for det_class in detection_classes:
+                result_matrix[num_classes, det_class] += 1
+            return result_matrix
 
         class_id_idx = 4
         true_classes = np.array(targets[:, class_id_idx], dtype=np.int16)
@@ -286,37 +304,71 @@ def evaluate_detection_batch(
         true_boxes = targets[:, :class_id_idx]
         detection_boxes = detection_batch_filtered[:, :class_id_idx]
 
+        # Calculate IoU matrix
         iou_batch = box_iou_batch(
             boxes_true=true_boxes, boxes_detection=detection_boxes
         )
-        matched_idx = np.asarray(iou_batch > iou_threshold).nonzero()
-
-        if matched_idx[0].shape[0]:
-            matches = np.stack(
-                (matched_idx[0], matched_idx[1], iou_batch[matched_idx]), axis=1
-            )
-            matches = ConfusionMatrix._drop_extra_matches(matches=matches)
-        else:
-            matches = np.zeros((0, 3))
 
-        matched_true_idx, matched_detection_idx, _ = matches.transpose().astype(
-            np.int16
-        )
+        # Find all valid matches (IoU > threshold, regardless of class)
+        # Use vectorized operations to avoid nested Python loops
+        iou_mask = iou_batch > iou_threshold
+        gt_indices, det_indices = np.nonzero(iou_mask)
 
-        for i, true_class_value in enumerate(true_classes):
-            j = matched_true_idx == i
-            if matches.shape[0] > 0 and sum(j) == 1:
-                result_matrix[
-                    true_class_value, detection_classes[matched_detection_idx[j]]
-                ] += 1  # TP
-            else:
-                result_matrix[true_class_value, num_classes] += 1  # FN
-
-        for i, detection_class_value in enumerate(detection_classes):
-            if not any(matched_detection_idx == i):
-                result_matrix[num_classes, detection_class_value] += 1  # FP
-        final_result_matrix: npt.NDArray[np.int32] = result_matrix
-        return final_result_matrix
+        # If no pairs exceed the IoU threshold, skip matching
+        if gt_indices.size == 0:
+            valid_matches = []
+        else:
+            ious = iou_batch[gt_indices, det_indices]
+            gt_match_classes = true_classes[gt_indices]
+            det_match_classes = detection_classes[det_indices]
+            class_matches = gt_match_classes == det_match_classes
+
+            # Sort matches by class match first (True before False),
+            # then by IoU descending.
+            # np.lexsort sorts by the last key first, in ascending order.
+            # We use ~class_matches so that True becomes 0
+            # and False becomes 1 (True first),
+            # and -ious so that larger IoUs come first.
+            sort_indices = np.lexsort((-ious, ~class_matches))
+
+            # Build list of matches in the same format as before:
+            # (gt_idx, det_idx, iou, class_match)
+            valid_matches = [
+                (
+                    int(gt_indices[idx]),
+                    int(det_indices[idx]),
+                    float(ious[idx]),
+                    bool(class_matches[idx]),
+                )
+                for idx in sort_indices
+            ]
+        # Greedily assign matches, ensuring each GT
+        # and detection is matched at most once
+        matched_gt_idx = set()
+        matched_det_idx = set()
+
+        for gt_idx, det_idx, iou, class_match in valid_matches:
+            if gt_idx not in matched_gt_idx and det_idx not in matched_det_idx:
+                # Valid spatial match - record the class prediction
+                gt_class = true_classes[gt_idx]
+                det_class = detection_classes[det_idx]
+
+                # This handles both correct classification (TP) and misclassification
+                result_matrix[gt_class, det_class] += 1
+                matched_gt_idx.add(gt_idx)
+                matched_det_idx.add(det_idx)
+
+        # Count unmatched ground truth as FN
+        for gt_idx, gt_class in enumerate(true_classes):
+            if gt_idx not in matched_gt_idx:
+                result_matrix[gt_class, num_classes] += 1
+
+        # Count unmatched detections as FP
+        for det_idx, det_class in enumerate(detection_classes):
+            if det_idx not in matched_det_idx:
+                result_matrix[num_classes, det_class] += 1
+
+        return result_matrix
 
     @staticmethod
     def _drop_extra_matches(
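
To see how the new matching prefers class-consistent pairs, here is a minimal standalone sketch (illustration only; the names mirror the code above) of the lexsort ordering on one GT box of class 0 overlapped by two detections: one with the wrong class at IoU 1.0, one with the correct class at IoU 0.9:

import numpy as np

true_classes = np.array([0])
detection_classes = np.array([1, 0])
iou_batch = np.array([[1.0, 0.9]])  # rows: GT boxes, cols: detections

gt_indices, det_indices = np.nonzero(iou_batch > 0.5)
ious = iou_batch[gt_indices, det_indices]
class_matches = true_classes[gt_indices] == detection_classes[det_indices]

# Class-consistent pairs sort first (~True == False), then IoU descending.
order = np.lexsort((-ious, ~class_matches))
print([(int(gt_indices[i]), int(det_indices[i])) for i in order])
# [(0, 1), (0, 0)]: the greedy pass pairs GT 0 with the correct-class
# detection despite its lower IoU; the wrong-class detection stays
# unmatched and is counted as a false positive.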

src/supervision/metrics/precision.py

Lines changed: 6 additions & 1 deletion
@@ -385,7 +385,12 @@ def _compute_precision(
         false_positives = confusion_matrix[..., 1]
 
         denominator = true_positives + false_positives
-        precision = np.where(denominator == 0, 0, true_positives / denominator)
+        precision = np.divide(
+            true_positives,
+            denominator,
+            out=np.zeros_like(true_positives),
+            where=denominator != 0,
+        )
 
         result_precision: npt.NDArray[np.float64] = precision
         return result_precision
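
The motivation for np.divide here: np.where evaluates true_positives / denominator for every element before masking, so zero denominators still trigger divide-by-zero warnings and produce NaN/inf intermediates, while np.divide with out= and where= never touches those entries. A standalone sketch (assuming float arrays, as in the metric code):

import numpy as np

tp = np.array([5.0, 0.0])
denominator = tp + np.array([5.0, 0.0])  # second entry is 0 / 0

# The masked entry keeps the 0.0 prefilled by np.zeros_like instead of NaN.
precision = np.divide(
    tp, denominator, out=np.zeros_like(tp), where=denominator != 0
)
print(precision)  # [0.5 0. ]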

src/supervision/metrics/recall.py

Lines changed: 6 additions & 1 deletion
@@ -383,7 +383,12 @@ def _compute_recall(
         false_negatives = confusion_matrix[..., 2]
 
         denominator = true_positives + false_negatives
-        recall = np.where(denominator == 0, 0, true_positives / denominator)
+        recall = np.divide(
+            true_positives,
+            denominator,
+            out=np.zeros_like(true_positives),
+            where=denominator != 0,
+        )
 
         result_recall: npt.NDArray[np.float64] = recall
         return result_recall

tests/helpers.py

Lines changed: 64 additions & 0 deletions
@@ -414,3 +414,67 @@ def create_yolo_dataset(
         "image_size": image_size,
         "image_annotations": image_annotations,
     }
+
+
+def create_predictions_with_class_iou_tests(
+    gt_detections: Detections, num_classes: int
+) -> Detections:
+    """
+    Create predictions that test IoU+class matching behavior.
+
+    For each ground truth detection, creates predictions with different patterns:
+    - Pattern 0 (i%3==0): Correct match (same bbox, correct class)
+    - Pattern 1 (i%3==1): Wrong class with perfect IoU + correct class with offset
+    - Pattern 2 (i%3==2): Correct class with slight offset
+
+    This tests that predictions with wrong class don't match even with high IoU,
+    which is the key fix in the confusion matrix calculation.
+
+    Args:
+        gt_detections: Ground truth detections to create predictions for
+        num_classes: Total number of classes in the dataset
+
+    Returns:
+        Detections object with predictions designed to test IoU+class matching
+    """
+    if len(gt_detections) == 0:
+        # No ground truth, return a single false positive
+        return _create_detections(
+            xyxy=[[10, 10, 50, 50]], class_id=[0], confidence=[0.9]
+        )
+
+    pred_boxes = []
+    pred_classes = []
+    pred_confs = []
+
+    for i, (box, cls) in enumerate(zip(gt_detections.xyxy, gt_detections.class_id)):
+        if i % 3 == 0:
+            # Pattern 1: Correct match
+            pred_boxes.append(box)
+            pred_classes.append(cls)
+            pred_confs.append(0.95)
+
+        elif i % 3 == 1:
+            # Pattern 2: Test the fix - add wrong class prediction with perfect IoU,
+            # then correct class with slightly offset bbox
+            wrong_cls = (cls + 1) % num_classes
+            pred_boxes.append(box)  # Perfect IoU
+            pred_classes.append(wrong_cls)  # Wrong class
+            pred_confs.append(0.90)
+
+            # Add correct class with slight offset
+            offset_box = box + np.array([2, 2, 2, 2], dtype=np.float32)
+            pred_boxes.append(offset_box)
+            pred_classes.append(cls)  # Correct class
+            pred_confs.append(0.85)
+
+        else:
+            # Pattern 3: Correct match with slight offset
+            offset_box = box + np.array([1, 1, 1, 1], dtype=np.float32)
+            pred_boxes.append(offset_box)
+            pred_classes.append(cls)
+            pred_confs.append(0.92)
+
+    return _create_detections(
+        xyxy=pred_boxes, class_id=pred_classes, confidence=pred_confs
+    )
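
A hypothetical usage sketch for this helper (assumptions: it is importable from tests.helpers, _create_detections returns sv.Detections, and sv.ConfusionMatrix.from_detections is invoked as in src/supervision/metrics/detection.py above):

import numpy as np
import supervision as sv

from tests.helpers import create_predictions_with_class_iou_tests

targets = sv.Detections(
    xyxy=np.array(
        [[0, 0, 50, 50], [100, 100, 150, 150], [200, 200, 250, 250]],
        dtype=np.float32,
    ),
    class_id=np.array([0, 1, 2]),
)
predictions = create_predictions_with_class_iou_tests(targets, num_classes=3)

matrix = sv.ConfusionMatrix.from_detections(
    predictions=[predictions],
    targets=[targets],
    classes=["dog", "cat", "person"],
)
# All three GTs should land on the diagonal; the wrong-class prediction
# from pattern 1 should show up as a false positive rather than stealing
# the match from the correct-class prediction.
print(matrix.matrix)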

tests/metrics/conftest.py

Lines changed: 84 additions & 0 deletions
@@ -1,3 +1,5 @@
+from typing import Optional
+
 import numpy as np
 import pytest
 
@@ -138,3 +140,85 @@ def target_class_1():
         xyxy=np.array([[60, 60, 100, 100]], dtype=np.float32),
         class_id=np.array([1]),
     )
+
+
+def _yolo_dataset_factory(
+    tmp_path,
+    num_images: int = 20,
+    classes: Optional[list[str]] = None,
+    objects_per_image_range: tuple[int, int] = (1, 3),
+):
+    """
+    Factory function to create synthetic YOLO-format datasets with custom parameters.
+
+    Args:
+        tmp_path: Pytest tmp_path fixture
+        num_images: Number of images to generate
+        classes: List of class names
+        objects_per_image_range: Range of objects per image as (min, max)
+
+    Returns:
+        dict with dataset paths and metadata
+    """
+    from tests.helpers import create_yolo_dataset
+
+    if classes is None:
+        classes = ["dog", "cat", "person"]
+
+    return create_yolo_dataset(
+        dataset_dir=str(tmp_path / "yolo_dataset"),
+        num_images=num_images,
+        image_size=(640, 640, 3),
+        classes=classes,
+        objects_per_image_range=objects_per_image_range,
+        seed=42,
+    )
+
+
+@pytest.fixture
+def yolo_dataset_structure(tmp_path):
+    """
+    Synthetic YOLO-format dataset for testing confusion matrix and detection metrics.
+
+    Configuration:
+    - 20 images
+    - 640x640 resolution
+    - 3 classes: ["dog", "cat", "person"]
+    - 1-3 objects per image
+
+    Use this for tests that need multi-class scenarios (3+ classes).
+
+    Returns:
+        dict with dataset paths and metadata
+    """
+    return _yolo_dataset_factory(
+        tmp_path,
+        num_images=20,
+        classes=["dog", "cat", "person"],
+        objects_per_image_range=(1, 3),
+    )
+
+
+@pytest.fixture
+def yolo_dataset_two_classes(tmp_path):
+    """
+    Synthetic YOLO-format dataset for testing mAR and binary classification metrics.
+
+    Configuration:
+    - 15 images
+    - 640x640 resolution
+    - 2 classes: ["class_0", "class_1"]
+    - 2-4 objects per image
+
+    Use this for tests that specifically need 2-class scenarios or depend on
+    specific class distributions (e.g., mAR @ K per-image limiting tests).
+
+    Returns:
+        dict with dataset paths and metadata
+    """
+    return _yolo_dataset_factory(
+        tmp_path,
+        num_images=15,
+        classes=["class_0", "class_1"],
+        objects_per_image_range=(2, 4),
+    )
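
A hypothetical sketch of a test consuming one of these fixtures (the keys "image_size" and "image_annotations" are the ones visible in create_yolo_dataset above; any other keys in the returned dict are not shown here):

def test_yolo_dataset_fixture(yolo_dataset_structure):
    # 20 images at 640x640 with 3 classes, per the fixture docstring.
    assert yolo_dataset_structure["image_size"] == (640, 640, 3)
    assert len(yolo_dataset_structure["image_annotations"]) == 20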
