
Commit 14c62a7

fix(models): no-candidate anchor issue for tiny objects during label assign (Megvii-BaseDetection#1589)
1 parent a4152a5 commit 14c62a7
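
The substance of the fix: candidate selection for label assignment no longer depends on the size of the ground-truth box. Previously, an anchor point had to fall inside the GT box or within 2.5 strides of its center, and the cost penalty used the AND of the two tests; for a box smaller than one grid cell the in-box test is empty, so every candidate for that GT was flagged as geometrically invalid. The new get_geometry_constraint keeps a single fixed-radius center test (1.5 strides), which is independent of box size. A minimal standalone sketch of the new constraint on a toy grid (illustrative values, not part of the commit):

import torch

# one feature level with stride 8 and a 4x4 grid of anchor points
stride = 8
ys, xs = torch.meshgrid(torch.arange(4), torch.arange(4), indexing="ij")
x_centers = ((xs.flatten() + 0.5) * stride).unsqueeze(0)  # [1, 16]
y_centers = ((ys.flatten() + 0.5) * stride).unsqueeze(0)

# a tiny 2x2-pixel ground-truth box centered at (13, 13), cxcywh format
gt = torch.tensor([[13.0, 13.0, 2.0, 2.0]])

center_radius = 1.5
center_dist = center_radius * stride  # fixed range, independent of box size
c_l = x_centers - (gt[:, 0:1] - center_dist)
c_r = (gt[:, 0:1] + center_dist) - x_centers
c_t = y_centers - (gt[:, 1:2] - center_dist)
c_b = (gt[:, 1:2] + center_dist) - y_centers
is_in_centers = torch.stack([c_l, c_t, c_r, c_b], 2).min(dim=-1).values > 0.0
print(is_in_centers.sum().item())  # 9 candidates: the tiny box is never left empty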

File tree

1 file changed (+43 −109 lines)

yolox/models/yolo_head.py

Lines changed: 43 additions & 109 deletions
@@ -32,7 +32,6 @@ def __init__(
         """
         super().__init__()

-        self.n_anchors = 1
         self.num_classes = num_classes
         self.decode_in_inference = True  # for deploy, set to False

@@ -97,7 +96,7 @@ def __init__(
             self.cls_preds.append(
                 nn.Conv2d(
                     in_channels=int(256 * width),
-                    out_channels=self.n_anchors * self.num_classes,
+                    out_channels=self.num_classes,
                     kernel_size=1,
                     stride=1,
                     padding=0,
@@ -115,7 +114,7 @@ def __init__(
             self.obj_preds.append(
                 nn.Conv2d(
                     in_channels=int(256 * width),
-                    out_channels=self.n_anchors * 1,
+                    out_channels=1,
                     kernel_size=1,
                     stride=1,
                     padding=0,
@@ -131,12 +130,12 @@ def __init__(

     def initialize_biases(self, prior_prob):
         for conv in self.cls_preds:
-            b = conv.bias.view(self.n_anchors, -1)
+            b = conv.bias.view(1, -1)
             b.data.fill_(-math.log((1 - prior_prob) / prior_prob))
             conv.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)

         for conv in self.obj_preds:
-            b = conv.bias.view(self.n_anchors, -1)
+            b = conv.bias.view(1, -1)
             b.data.fill_(-math.log((1 - prior_prob) / prior_prob))
             conv.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
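
Side note on the unchanged math here: -math.log((1 - prior_prob) / prior_prob) is the logit of prior_prob, so right after initialization the sigmoid of every obj/cls output equals the prior (YOLOX passes 1e-2). A standalone check, assuming nothing beyond the formula:

import math

prior_prob = 0.01                              # the 1e-2 prior YOLOX uses
b = -math.log((1 - prior_prob) / prior_prob)   # value filled into the bias
p = 1.0 / (1.0 + math.exp(-b))                 # sigmoid(b)
assert abs(p - prior_prob) < 1e-12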

@@ -177,7 +176,7 @@ def forward(self, xin, labels=None, imgs=None):
                     batch_size = reg_output.shape[0]
                     hsize, wsize = reg_output.shape[-2:]
                     reg_output = reg_output.view(
-                        batch_size, self.n_anchors, 4, hsize, wsize
+                        batch_size, 1, 4, hsize, wsize
                     )
                     reg_output = reg_output.permute(0, 1, 3, 4, 2).reshape(
                         batch_size, -1, 4
@@ -224,9 +223,9 @@ def get_output_and_grid(self, output, k, stride, dtype):
             grid = torch.stack((xv, yv), 2).view(1, 1, hsize, wsize, 2).type(dtype)
             self.grids[k] = grid

-        output = output.view(batch_size, self.n_anchors, n_ch, hsize, wsize)
+        output = output.view(batch_size, 1, n_ch, hsize, wsize)
         output = output.permute(0, 1, 3, 4, 2).reshape(
-            batch_size, self.n_anchors * hsize * wsize, -1
+            batch_size, hsize * wsize, -1
         )
         grid = grid.view(1, -1, 2)
         output[..., :2] = (output[..., :2] + grid) * stride
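
With n_anchors pinned to 1, each grid cell yields exactly one prediction, and the last line above maps cell-relative offsets to image coordinates. A standalone toy decode mirroring that logic (assumed values, not the commit's code):

import torch

stride, hsize, wsize = 8, 2, 2
yv, xv = torch.meshgrid(torch.arange(hsize), torch.arange(wsize), indexing="ij")
grid = torch.stack((xv, yv), 2).view(1, -1, 2)  # [1, h*w, 2] cell indices

# pretend the head predicted a 0.5-cell offset everywhere
xy_pred = torch.full((1, hsize * wsize, 2), 0.5)
xy_img = (xy_pred + grid) * stride
print(xy_img)  # cell centers in pixels: (4,4), (12,4), (4,12), (12,12)
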
@@ -265,7 +264,7 @@ def get_losses(
         dtype,
     ):
         bbox_preds = outputs[:, :, :4]  # [batch, n_anchors_all, 4]
-        obj_preds = outputs[:, :, 4].unsqueeze(-1)  # [batch, n_anchors_all, 1]
+        obj_preds = outputs[:, :, 4:5]  # [batch, n_anchors_all, 1]
         cls_preds = outputs[:, :, 5:]  # [batch, n_anchors_all, n_cls]

         # calculate targets
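
The obj_preds change is purely cosmetic: slicing with 4:5 keeps the singleton channel that the old code restored with unsqueeze(-1). Quick standalone check:

import torch

outputs = torch.randn(2, 8, 10)  # [batch, n_anchors_all, 5 + n_cls]
assert torch.equal(outputs[:, :, 4].unsqueeze(-1), outputs[:, :, 4:5])
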
@@ -311,18 +310,14 @@ def get_losses(
                 ) = self.get_assignments(  # noqa
                     batch_idx,
                     num_gt,
-                    total_num_anchors,
                     gt_bboxes_per_image,
                     gt_classes,
                     bboxes_preds_per_image,
                     expanded_strides,
                     x_shifts,
                     y_shifts,
                     cls_preds,
-                    bbox_preds,
                     obj_preds,
-                    labels,
-                    imgs,
                 )
             except RuntimeError as e:
                 # TODO: the string might change, consider a better way
@@ -344,18 +339,14 @@ def get_losses(
                 ) = self.get_assignments(  # noqa
                     batch_idx,
                     num_gt,
-                    total_num_anchors,
                     gt_bboxes_per_image,
                     gt_classes,
                     bboxes_preds_per_image,
                     expanded_strides,
                     x_shifts,
                     y_shifts,
                     cls_preds,
-                    bbox_preds,
                     obj_preds,
-                    labels,
-                    imgs,
                     "cpu",
                 )

@@ -433,37 +424,31 @@ def get_assignments(
         self,
         batch_idx,
         num_gt,
-        total_num_anchors,
         gt_bboxes_per_image,
         gt_classes,
         bboxes_preds_per_image,
         expanded_strides,
         x_shifts,
         y_shifts,
         cls_preds,
-        bbox_preds,
         obj_preds,
-        labels,
-        imgs,
         mode="gpu",
     ):

         if mode == "cpu":
-            print("------------CPU Mode for This Batch-------------")
+            print("-----------Using CPU for the Current Batch-------------")
            gt_bboxes_per_image = gt_bboxes_per_image.cpu().float()
             bboxes_preds_per_image = bboxes_preds_per_image.cpu().float()
             gt_classes = gt_classes.cpu().float()
             expanded_strides = expanded_strides.cpu().float()
             x_shifts = x_shifts.cpu()
             y_shifts = y_shifts.cpu()

-        fg_mask, is_in_boxes_and_center = self.get_in_boxes_info(
+        fg_mask, geometry_relation = self.get_geometry_constraint(
             gt_bboxes_per_image,
             expanded_strides,
             x_shifts,
             y_shifts,
-            total_num_anchors,
-            num_gt,
         )

         bboxes_preds_per_image = bboxes_preds_per_image[fg_mask]
@@ -480,8 +465,6 @@ def get_assignments(
         gt_cls_per_image = (
             F.one_hot(gt_classes.to(torch.int64), self.num_classes)
             .float()
-            .unsqueeze(1)
-            .repeat(1, num_in_boxes_anchor, 1)
         )
         pair_wise_ious_loss = -torch.log(pair_wise_ious + 1e-8)

@@ -490,26 +473,27 @@ def get_assignments(

         with torch.cuda.amp.autocast(enabled=False):
             cls_preds_ = (
-                cls_preds_.float().unsqueeze(0).repeat(num_gt, 1, 1).sigmoid_()
-                * obj_preds_.float().unsqueeze(0).repeat(num_gt, 1, 1).sigmoid_()
-            )
+                cls_preds_.float().sigmoid_() * obj_preds_.float().sigmoid_()
+            ).sqrt()
             pair_wise_cls_loss = F.binary_cross_entropy(
-                cls_preds_.sqrt_(), gt_cls_per_image, reduction="none"
+                cls_preds_.unsqueeze(0).repeat(num_gt, 1, 1),
+                gt_cls_per_image.unsqueeze(1).repeat(1, num_in_boxes_anchor, 1),
+                reduction="none"
             ).sum(-1)
         del cls_preds_

         cost = (
             pair_wise_cls_loss
             + 3.0 * pair_wise_ious_loss
-            + 100000.0 * (~is_in_boxes_and_center)
+            + float(1e6) * (~geometry_relation)
         )

         (
             num_fg,
             gt_matched_classes,
             pred_ious_this_matching,
             matched_gt_inds,
-        ) = self.dynamic_k_matching(cost, pair_wise_ious, gt_classes, num_gt, fg_mask)
+        ) = self.simota_matching(cost, pair_wise_ious, gt_classes, num_gt, fg_mask)
         del pair_wise_cls_loss, cost, pair_wise_ious, pair_wise_ious_loss

         if mode == "cpu":
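
The autocast block above now applies sigmoid and sqrt once to the un-repeated [n_anchor, n_cls] scores and defers the num_gt-fold repeat to the binary_cross_entropy call, which should cut peak memory during assignment; the resulting pairwise loss is numerically the same. A standalone equivalence check with toy shapes (variable names here are illustrative):

import torch
import torch.nn.functional as F

num_gt, n_anchor, n_cls = 3, 5, 4
cls_logits = torch.randn(n_anchor, n_cls)
obj_logits = torch.randn(n_anchor, 1)
gt_cls = F.one_hot(torch.randint(n_cls, (num_gt,)), n_cls).float()

# old formulation: repeat first, then sigmoid and sqrt
old = F.binary_cross_entropy(
    (cls_logits.unsqueeze(0).repeat(num_gt, 1, 1).sigmoid()
     * obj_logits.unsqueeze(0).repeat(num_gt, 1, 1).sigmoid()).sqrt(),
    gt_cls.unsqueeze(1).repeat(1, n_anchor, 1),
    reduction="none",
).sum(-1)

# new formulation: sigmoid and sqrt once, repeat only inside the BCE call
scores = (cls_logits.sigmoid() * obj_logits.sigmoid()).sqrt()
new = F.binary_cross_entropy(
    scores.unsqueeze(0).repeat(num_gt, 1, 1),
    gt_cls.unsqueeze(1).repeat(1, n_anchor, 1),
    reduction="none",
).sum(-1)

assert torch.allclose(old, new)
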
@@ -526,101 +510,49 @@ def get_assignments(
             num_fg,
         )

-    def get_in_boxes_info(
+    def get_geometry_constraint(
         self,
         gt_bboxes_per_image,
         expanded_strides,
         x_shifts,
         y_shifts,
-        total_num_anchors,
-        num_gt,
     ):
+        """
+        Calculate whether the center of an object is located in a fixed range of
+        an anchor. This is used to avert inappropriate matching. It can also reduce
+        the number of candidate anchors so that the GPU memory is saved.
+        """
         expanded_strides_per_image = expanded_strides[0]
-        x_shifts_per_image = x_shifts[0] * expanded_strides_per_image
-        y_shifts_per_image = y_shifts[0] * expanded_strides_per_image
-        x_centers_per_image = (
-            (x_shifts_per_image + 0.5 * expanded_strides_per_image)
-            .unsqueeze(0)
-            .repeat(num_gt, 1)
-        )  # [n_anchor] -> [n_gt, n_anchor]
-        y_centers_per_image = (
-            (y_shifts_per_image + 0.5 * expanded_strides_per_image)
-            .unsqueeze(0)
-            .repeat(num_gt, 1)
-        )
-
-        gt_bboxes_per_image_l = (
-            (gt_bboxes_per_image[:, 0] - 0.5 * gt_bboxes_per_image[:, 2])
-            .unsqueeze(1)
-            .repeat(1, total_num_anchors)
-        )
-        gt_bboxes_per_image_r = (
-            (gt_bboxes_per_image[:, 0] + 0.5 * gt_bboxes_per_image[:, 2])
-            .unsqueeze(1)
-            .repeat(1, total_num_anchors)
-        )
-        gt_bboxes_per_image_t = (
-            (gt_bboxes_per_image[:, 1] - 0.5 * gt_bboxes_per_image[:, 3])
-            .unsqueeze(1)
-            .repeat(1, total_num_anchors)
-        )
-        gt_bboxes_per_image_b = (
-            (gt_bboxes_per_image[:, 1] + 0.5 * gt_bboxes_per_image[:, 3])
-            .unsqueeze(1)
-            .repeat(1, total_num_anchors)
-        )
+        x_centers_per_image = ((x_shifts[0] + 0.5) * expanded_strides_per_image).unsqueeze(0)
+        y_centers_per_image = ((y_shifts[0] + 0.5) * expanded_strides_per_image).unsqueeze(0)

-        b_l = x_centers_per_image - gt_bboxes_per_image_l
-        b_r = gt_bboxes_per_image_r - x_centers_per_image
-        b_t = y_centers_per_image - gt_bboxes_per_image_t
-        b_b = gt_bboxes_per_image_b - y_centers_per_image
-        bbox_deltas = torch.stack([b_l, b_t, b_r, b_b], 2)
-
-        is_in_boxes = bbox_deltas.min(dim=-1).values > 0.0
-        is_in_boxes_all = is_in_boxes.sum(dim=0) > 0
         # in fixed center
-
-        center_radius = 2.5
-
-        gt_bboxes_per_image_l = (gt_bboxes_per_image[:, 0]).unsqueeze(1).repeat(
-            1, total_num_anchors
-        ) - center_radius * expanded_strides_per_image.unsqueeze(0)
-        gt_bboxes_per_image_r = (gt_bboxes_per_image[:, 0]).unsqueeze(1).repeat(
-            1, total_num_anchors
-        ) + center_radius * expanded_strides_per_image.unsqueeze(0)
-        gt_bboxes_per_image_t = (gt_bboxes_per_image[:, 1]).unsqueeze(1).repeat(
-            1, total_num_anchors
-        ) - center_radius * expanded_strides_per_image.unsqueeze(0)
-        gt_bboxes_per_image_b = (gt_bboxes_per_image[:, 1]).unsqueeze(1).repeat(
-            1, total_num_anchors
-        ) + center_radius * expanded_strides_per_image.unsqueeze(0)
+        center_radius = 1.5
+        center_dist = expanded_strides_per_image.unsqueeze(0) * center_radius
+        gt_bboxes_per_image_l = (gt_bboxes_per_image[:, 0:1]) - center_dist
+        gt_bboxes_per_image_r = (gt_bboxes_per_image[:, 0:1]) + center_dist
+        gt_bboxes_per_image_t = (gt_bboxes_per_image[:, 1:2]) - center_dist
+        gt_bboxes_per_image_b = (gt_bboxes_per_image[:, 1:2]) + center_dist

         c_l = x_centers_per_image - gt_bboxes_per_image_l
         c_r = gt_bboxes_per_image_r - x_centers_per_image
         c_t = y_centers_per_image - gt_bboxes_per_image_t
         c_b = gt_bboxes_per_image_b - y_centers_per_image
         center_deltas = torch.stack([c_l, c_t, c_r, c_b], 2)
         is_in_centers = center_deltas.min(dim=-1).values > 0.0
-        is_in_centers_all = is_in_centers.sum(dim=0) > 0
+        anchor_filter = is_in_centers.sum(dim=0) > 0
+        geometry_relation = is_in_centers[:, anchor_filter]

-        # in boxes and in centers
-        is_in_boxes_anchor = is_in_boxes_all | is_in_centers_all
-
-        is_in_boxes_and_center = (
-            is_in_boxes[:, is_in_boxes_anchor] & is_in_centers[:, is_in_boxes_anchor]
-        )
-        return is_in_boxes_anchor, is_in_boxes_and_center
+        return anchor_filter, geometry_relation

-    def dynamic_k_matching(self, cost, pair_wise_ious, gt_classes, num_gt, fg_mask):
+    def simota_matching(self, cost, pair_wise_ious, gt_classes, num_gt, fg_mask):
         # Dynamic K
         # ---------------------------------------------------------------
         matching_matrix = torch.zeros_like(cost, dtype=torch.uint8)

-        ious_in_boxes_matrix = pair_wise_ious
-        n_candidate_k = min(10, ious_in_boxes_matrix.size(1))
-        topk_ious, _ = torch.topk(ious_in_boxes_matrix, n_candidate_k, dim=1)
+        n_candidate_k = min(10, pair_wise_ious.size(1))
+        topk_ious, _ = torch.topk(pair_wise_ious, n_candidate_k, dim=1)
         dynamic_ks = torch.clamp(topk_ious.sum(1).int(), min=1)
-        dynamic_ks = dynamic_ks.tolist()
         for gt_idx in range(num_gt):
             _, pos_idx = torch.topk(
                 cost[gt_idx], k=dynamic_ks[gt_idx], largest=False
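
For context on dynamic_ks: each ground truth receives a budget of positive anchors equal to the truncated sum of its ten highest IoUs, clamped to at least one, so better-localized objects get more positives. Dropping .tolist() keeps dynamic_ks a tensor; the commit relies on torch.topk accepting the 0-dim integer element dynamic_ks[gt_idx] as k. A toy illustration with made-up IoUs:

import torch

pair_wise_ious = torch.tensor([
    [0.8, 0.7, 0.6, 0.1, 0.0, 0.0],  # well-localized gt
    [0.2, 0.1, 0.0, 0.0, 0.0, 0.0],  # poorly-localized gt
])
n_candidate_k = min(10, pair_wise_ious.size(1))
topk_ious, _ = torch.topk(pair_wise_ious, n_candidate_k, dim=1)
dynamic_ks = torch.clamp(topk_ious.sum(1).int(), min=1)
print(dynamic_ks)  # tensor([2, 1], dtype=torch.int32)
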
@@ -630,11 +562,13 @@ def dynamic_k_matching(self, cost, pair_wise_ious, gt_classes, num_gt, fg_mask):
         del topk_ious, dynamic_ks, pos_idx

         anchor_matching_gt = matching_matrix.sum(0)
+        # deal with the case that one anchor matches multiple ground-truths
         if anchor_matching_gt.max() > 1:
-            _, cost_argmin = torch.min(cost[:, anchor_matching_gt > 1], dim=0)
-            matching_matrix[:, anchor_matching_gt > 1] *= 0
-            matching_matrix[cost_argmin, anchor_matching_gt > 1] = 1
-        fg_mask_inboxes = matching_matrix.sum(0) > 0
+            multiple_match_mask = anchor_matching_gt > 1
+            _, cost_argmin = torch.min(cost[:, multiple_match_mask], dim=0)
+            matching_matrix[:, multiple_match_mask] *= 0
+            matching_matrix[cost_argmin, multiple_match_mask] = 1
+        fg_mask_inboxes = anchor_matching_gt > 0
         num_fg = fg_mask_inboxes.sum().item()

         fg_mask[fg_mask.clone()] = fg_mask_inboxes
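
When one anchor ends up matched to several ground truths, the rewritten block keeps only the lowest-cost match; and since that correction never empties a previously matched anchor, fg_mask_inboxes can now be read straight off anchor_matching_gt. A toy run of the same logic (illustrative values):

import torch

cost = torch.tensor([[1.0, 9.0],
                     [2.0, 3.0]])
matching_matrix = torch.tensor([[1, 0],
                                [1, 1]], dtype=torch.uint8)

anchor_matching_gt = matching_matrix.sum(0)  # anchor 0 is matched twice
if anchor_matching_gt.max() > 1:
    multiple_match_mask = anchor_matching_gt > 1
    _, cost_argmin = torch.min(cost[:, multiple_match_mask], dim=0)
    matching_matrix[:, multiple_match_mask] *= 0
    matching_matrix[cost_argmin, multiple_match_mask] = 1
fg_mask_inboxes = anchor_matching_gt > 0

print(matching_matrix)   # anchor 0 now belongs only to gt 0 (lower cost)
print(fg_mask_inboxes)   # tensor([True, True]): both anchors stay positive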
