Skip to content

Commit 124be19

Browse files
stop1one, pre-commit-ci[bot], and Borda
authored
Fix MeanAverageRecall compute mAR@K using top-K detections per image [COCO-compliant] (#2136)
* Supersedes #1967
* fix: COCO-compliant mAR calculation
* Add complex test of mAP
* fix(metrics): cast optional detections fields in mAR metric for mypy
* Add `create_yolo_dataset` utility and refactor tests with fixtures for reusable scenarios
* Refine docstrings for clarity and consistency, adding inline formatting and fixing typos in test helpers and metrics.

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: jirka <6035284+Borda@users.noreply.github.com>
1 parent 839d117 commit 124be19

File tree

3 files changed

+740
-17
lines changed

3 files changed

+740
-17
lines changed

src/supervision/metrics/mean_average_recall.py

Lines changed: 16 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from copy import deepcopy
44
from dataclasses import dataclass
5-
from typing import TYPE_CHECKING, Any
5+
from typing import TYPE_CHECKING, Any, cast
66

77
import numpy as np
88
import numpy.typing as npt
@@ -376,7 +376,7 @@ def _compute(
376376
stats.append(
377377
(
378378
np.zeros((0, iou_thresholds.size), dtype=bool),
379-
np.zeros((0,), dtype=np.float32),
379+
np.zeros((0,), dtype=int),
380380
np.zeros((0,), dtype=int),
381381
targets.class_id,
382382
)
@@ -406,12 +406,18 @@ def _compute(
406406
iou,
407407
iou_thresholds,
408408
)
409+
410+
sorted_indices = np.argsort(
411+
-cast(npt.NDArray[np.float32], predictions.confidence)
412+
)
409413
stats.append(
410414
(
411-
matches,
412-
predictions.confidence,
413-
predictions.class_id,
414-
targets.class_id,
415+
matches[sorted_indices],
416+
np.arange(len(predictions)),
417+
cast(npt.NDArray[np.int32], predictions.class_id)[
418+
sorted_indices
419+
],
420+
cast(npt.NDArray[np.int32], targets.class_id),
415421
)
416422
)
417423

@@ -448,28 +454,24 @@ def _compute(
448454
def _compute_average_recall_for_classes(
449455
self,
450456
matches: npt.NDArray[np.bool_],
451-
prediction_confidence: npt.NDArray[np.float32],
457+
prediction_indices: npt.NDArray[np.int32],
452458
prediction_class_ids: npt.NDArray[np.int32],
453459
true_class_ids: npt.NDArray[np.int32],
454460
) -> tuple[
455461
npt.NDArray[np.float64],
456462
npt.NDArray[np.float64],
457463
npt.NDArray[np.int32],
458464
]:
459-
sorted_indices = np.argsort(-prediction_confidence)
460-
matches = matches[sorted_indices]
461-
prediction_class_ids = prediction_class_ids[sorted_indices]
462465
unique_classes, class_counts = np.unique(true_class_ids, return_counts=True)
463466

464467
recalls_at_k = []
465468
for max_detections in self.max_detections:
466469
# Shape: PxTh,P,C,C -> CxThx3
467470
confusion_matrix = self._compute_confusion_matrix(
468-
matches,
469-
prediction_class_ids,
471+
matches[prediction_indices < max_detections],
472+
prediction_class_ids[prediction_indices < max_detections],
470473
unique_classes,
471474
class_counts,
472-
max_detections=max_detections,
473475
)
474476

475477
# Shape: CxThx3 -> CxTh
@@ -522,7 +524,6 @@ def _compute_confusion_matrix(
522524
sorted_prediction_class_ids: npt.NDArray[np.int32],
523525
unique_classes: npt.NDArray[np.int32],
524526
class_counts: npt.NDArray[np.int32],
525-
max_detections: int | None = None,
526527
) -> npt.NDArray[np.float64]:
527528
"""
528529
Compute the confusion matrix for each class and IoU threshold.
@@ -567,7 +568,7 @@ class ids.
567568
false_positives = np.full(num_thresholds, num_predictions)
568569
false_negatives = np.zeros(num_thresholds)
569570
else:
570-
limited_matches = sorted_matches[is_class][slice(max_detections)]
571+
limited_matches = sorted_matches[is_class]
571572
true_positives = limited_matches.sum(0)
572573

573574
false_positives = (1 - limited_matches).sum(0)
@@ -641,8 +642,6 @@ def _make_empty_content(self) -> npt.NDArray[Any]:
641642

642643
raise ValueError(f"Invalid metric target: {self._metric_target}")
643644

644-
raise ValueError(f"Invalid metric target: {self._metric_target}")
645-
646645
def _filter_detections_by_size(
647646
self, detections: Detections, size_category: ObjectSizeCategory
648647
) -> Detections:

tests/helpers.py

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -303,3 +303,114 @@ class _FakeYoloNasResults:
303303

304304
def __init__(self, prediction: _FakeYoloNasPrediction):
305305
self.prediction = prediction
306+
307+
308+
def create_yolo_dataset(
    dataset_dir: str,
    num_images: int = 15,
    image_size: tuple[int, int, int] = (640, 640, 3),
    classes: list[str] | None = None,
    objects_per_image_range: tuple[int, int] = (2, 4),
    seed: int = 42,
) -> dict[str, Any]:
    """
    Create a synthetic YOLO-format dataset on disk.

    Generates dummy (all-black) images with YOLO-format annotations, a
    `data.yaml` file, and the `images/` + `labels/` directory structure
    suitable for testing dataset loading.

    Args:
        dataset_dir: Root directory path for the dataset. Created if missing.
        num_images: Number of images to generate.
        image_size: Image array shape as `(height, width, channels)`; it is
            passed unchanged to `np.zeros`.
        classes: List of class names. Defaults to `["class_0", "class_1"]`.
        objects_per_image_range: Inclusive range of objects per image as
            `(min, max)`. The actual count cycles through this range with
            the image index.
        seed: Random seed for reproducibility of the box positions.

    Returns:
        Dictionary containing:
            - `tmpdir`: Root dataset directory (`pathlib.Path`)
            - `images_dir`: Images directory path (`str`)
            - `labels_dir`: Labels directory path (`str`)
            - `data_yaml_path`: `data.yaml` file path (`str`)
            - `num_images`: Number of images created
            - `image_size`: Image dimensions, as passed in
            - `image_annotations`: One list per image of
              `(class_id, x_center, y_center, width, height)` tuples in
              normalized YOLO coordinates

    Raises:
        ValueError: If `objects_per_image_range` has `min > max`, or if
            `classes` is empty while at least one image needs an object.
        OSError: If an image file could not be written.

    Examples:
        >>> from pathlib import Path
        >>> import tempfile
        >>> tmpdir = Path(tempfile.mkdtemp())
        >>> dataset_info = create_yolo_dataset(str(tmpdir), num_images=5)
        >>> dataset_info["num_images"]
        5
        >>> len(list(Path(dataset_info["images_dir"]).glob("*.jpg")))
        5
    """
    from pathlib import Path

    import cv2

    if classes is None:
        classes = ["class_0", "class_1"]

    min_objects, max_objects = objects_per_image_range
    if min_objects > max_objects:
        raise ValueError(
            "objects_per_image_range must satisfy min <= max, "
            f"got {objects_per_image_range}"
        )

    # Object count per image cycles min, min + 1, ..., max, min, ...
    cycle_length = max_objects - min_objects + 1
    object_counts = [min_objects + (i % cycle_length) for i in range(num_images)]

    num_classes = len(classes)
    if num_classes == 0 and any(count > 0 for count in object_counts):
        raise ValueError("classes must not be empty when objects are generated")

    # A private RandomState yields the same draws as `np.random.seed(seed)`
    # followed by `np.random.uniform`, without reseeding the process-global
    # generator that other tests may share.
    rng = np.random.RandomState(seed)

    dataset_path = Path(dataset_dir)
    images_dir = dataset_path / "images"
    labels_dir = dataset_path / "labels"
    images_dir.mkdir(parents=True, exist_ok=True)
    labels_dir.mkdir(parents=True, exist_ok=True)

    # Every image is the same blank canvas, so build it once.
    blank_image = np.zeros(image_size, dtype=np.uint8)

    # All boxes share one fixed normalized size.
    width = 0.12
    height = 0.12

    image_annotations = []

    for i, num_objects in enumerate(object_counts):
        # Create dummy image
        img_path = images_dir / f"image_{i:03d}.jpg"
        if not cv2.imwrite(str(img_path), blank_image):
            raise OSError(f"Failed to write image: {img_path}")

        yolo_lines = []
        objects = []

        for j in range(num_objects):
            class_id = j % num_classes
            # Jittered positions along a diagonal. The fixed steps keep boxes
            # apart only while centers stay inside the image; centers that
            # would leave it are clipped below and may then coincide.
            x_center = 0.15 + (j * 0.25) + rng.uniform(-0.05, 0.05)
            y_center = 0.15 + (j * 0.2) + rng.uniform(-0.05, 0.05)

            # Clip so the whole box stays inside the normalized [0, 1] range
            x_center = np.clip(x_center, width / 2, 1 - width / 2)
            y_center = np.clip(y_center, height / 2, 1 - height / 2)

            yolo_lines.append(
                f"{class_id} {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}\n"
            )
            objects.append((class_id, x_center, y_center, width, height))

        # Write YOLO annotation file
        label_path = labels_dir / f"image_{i:03d}.txt"
        label_path.write_text("".join(yolo_lines), encoding="utf-8")
        image_annotations.append(objects)

    # Create data.yaml; explicit UTF-8 keeps non-ASCII class names portable
    data_yaml_path = dataset_path / "data.yaml"
    yaml_content = "names:\n" + "\n".join(f"- {cls}" for cls in classes) + "\n"
    data_yaml_path.write_text(yaml_content, encoding="utf-8")

    return {
        "tmpdir": dataset_path,
        "images_dir": str(images_dir),
        "labels_dir": str(labels_dir),
        "data_yaml_path": str(data_yaml_path),
        "num_images": num_images,
        "image_size": image_size,
        "image_annotations": image_annotations,
    }

0 commit comments

Comments
 (0)