@@ -32,7 +32,6 @@ def __init__(
3232 """
3333 super ().__init__ ()
3434
35- self .n_anchors = 1
3635 self .num_classes = num_classes
3736 self .decode_in_inference = True # for deploy, set to False
3837
@@ -97,7 +96,7 @@ def __init__(
             self.cls_preds.append(
                 nn.Conv2d(
                     in_channels=int(256 * width),
-                    out_channels=self.n_anchors * self.num_classes,
+                    out_channels=self.num_classes,
                     kernel_size=1,
                     stride=1,
                     padding=0,
@@ -115,7 +114,7 @@ def __init__(
             self.obj_preds.append(
                 nn.Conv2d(
                     in_channels=int(256 * width),
-                    out_channels=self.n_anchors * 1,
+                    out_channels=1,
                     kernel_size=1,
                     stride=1,
                     padding=0,
@@ -131,12 +130,12 @@ def __init__(
 
     def initialize_biases(self, prior_prob):
         for conv in self.cls_preds:
-            b = conv.bias.view(self.n_anchors, -1)
+            b = conv.bias.view(1, -1)
             b.data.fill_(-math.log((1 - prior_prob) / prior_prob))
             conv.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
 
         for conv in self.obj_preds:
-            b = conv.bias.view(self.n_anchors, -1)
+            b = conv.bias.view(1, -1)
             b.data.fill_(-math.log((1 - prior_prob) / prior_prob))
             conv.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
 
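Aside from the reshape, initialize_biases is unchanged: filling every classification and objectness bias with -log((1 - prior_prob) / prior_prob) makes the corresponding sigmoid start out at roughly prior_prob, so the head begins training by predicting mostly background. A minimal stand-alone sketch of that identity (the 1e-2 prior is an assumed typical value, not something stated in this diff):

    import math
    import torch

    prior_prob = 1e-2  # assumed typical prior probability, not taken from this diff
    bias = -math.log((1 - prior_prob) / prior_prob)
    # sigmoid(bias) recovers prior_prob, so every class/objectness score starts near 1%
    print(torch.sigmoid(torch.tensor(bias)).item())  # ~0.01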
@@ -177,7 +176,7 @@ def forward(self, xin, labels=None, imgs=None):
                     batch_size = reg_output.shape[0]
                     hsize, wsize = reg_output.shape[-2:]
                     reg_output = reg_output.view(
-                        batch_size, self.n_anchors, 4, hsize, wsize
+                        batch_size, 1, 4, hsize, wsize
                     )
                     reg_output = reg_output.permute(0, 1, 3, 4, 2).reshape(
                         batch_size, -1, 4
@@ -224,9 +223,9 @@ def get_output_and_grid(self, output, k, stride, dtype):
             grid = torch.stack((xv, yv), 2).view(1, 1, hsize, wsize, 2).type(dtype)
             self.grids[k] = grid
 
-        output = output.view(batch_size, self.n_anchors, n_ch, hsize, wsize)
+        output = output.view(batch_size, 1, n_ch, hsize, wsize)
         output = output.permute(0, 1, 3, 4, 2).reshape(
-            batch_size, self.n_anchors * hsize * wsize, -1
+            batch_size, hsize * wsize, -1
         )
         grid = grid.view(1, -1, 2)
         output[..., :2] = (output[..., :2] + grid) * stride
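get_output_and_grid now reshapes to a single prediction per cell, but the decode on the final context line is unchanged: the raw (x, y) output of each cell is an offset that is added to that cell's grid index and scaled by the stride of the feature level. A toy illustration of that decode, as a stand-alone sketch (shapes and values are made up, and torch.meshgrid is called directly rather than through YOLOX's meshgrid wrapper):

    import torch

    hsize, wsize, stride = 2, 3, 8
    yv, xv = torch.meshgrid(torch.arange(hsize), torch.arange(wsize), indexing="ij")
    grid = torch.stack((xv, yv), 2).view(1, -1, 2).float()  # [1, h*w, 2] cell indices

    xy_raw = torch.zeros(1, hsize * wsize, 2)   # pretend head output for x, y
    xy_img = (xy_raw + grid) * stride           # centers in input-image pixels
    print(xy_img[0, :3])                        # (0, 0), (8, 0), (16, 0) for the first row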
@@ -265,7 +264,7 @@ def get_losses(
         dtype,
     ):
         bbox_preds = outputs[:, :, :4]  # [batch, n_anchors_all, 4]
-        obj_preds = outputs[:, :, 4].unsqueeze(-1)  # [batch, n_anchors_all, 1]
+        obj_preds = outputs[:, :, 4:5]  # [batch, n_anchors_all, 1]
         cls_preds = outputs[:, :, 5:]  # [batch, n_anchors_all, n_cls]
 
         # calculate targets
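The objectness change above is cosmetic: slicing with 4:5 keeps the trailing singleton dimension, so it yields the same [batch, n_anchors_all, 1] tensor as indexing with 4 and unsqueezing. A quick check with made-up shapes:

    import torch

    outputs = torch.randn(2, 7, 85)  # hypothetical [batch, n_anchors_all, 5 + n_cls]
    a = outputs[:, :, 4].unsqueeze(-1)
    b = outputs[:, :, 4:5]
    assert a.shape == b.shape == (2, 7, 1)
    assert torch.equal(a, b)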
@@ -311,18 +310,14 @@ def get_losses(
                     ) = self.get_assignments(  # noqa
                         batch_idx,
                         num_gt,
-                        total_num_anchors,
                         gt_bboxes_per_image,
                         gt_classes,
                         bboxes_preds_per_image,
                         expanded_strides,
                         x_shifts,
                         y_shifts,
                         cls_preds,
-                        bbox_preds,
                         obj_preds,
-                        labels,
-                        imgs,
                     )
                 except RuntimeError as e:
                     # TODO: the string might change, consider a better way
@@ -344,18 +339,14 @@ def get_losses(
                     ) = self.get_assignments(  # noqa
                         batch_idx,
                         num_gt,
-                        total_num_anchors,
                         gt_bboxes_per_image,
                         gt_classes,
                         bboxes_preds_per_image,
                         expanded_strides,
                         x_shifts,
                         y_shifts,
                         cls_preds,
-                        bbox_preds,
                         obj_preds,
-                        labels,
-                        imgs,
                         "cpu",
                     )
 
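This second, identically trimmed call is the CPU fallback: if the GPU-side assignment raises a CUDA out-of-memory RuntimeError, the same assignment is recomputed on the CPU so training can continue. A rough stand-alone sketch of that pattern (the helper name and condensed signature are invented for illustration; only the OOM-substring check and empty_cache call mirror the surrounding code):

    import torch

    def assign_with_cpu_fallback(assign_fn, *args):
        """Hypothetical helper, not part of YOLOX: retry a label-assignment call
        on the CPU when the GPU runs out of memory, otherwise re-raise."""
        try:
            return assign_fn(*args, mode="gpu")
        except RuntimeError as e:
            if "CUDA out of memory. " not in str(e):
                raise  # unrelated runtime errors should still surface
            torch.cuda.empty_cache()
            return assign_fn(*args, mode="cpu")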
@@ -433,37 +424,31 @@ def get_assignments(
         self,
         batch_idx,
         num_gt,
-        total_num_anchors,
         gt_bboxes_per_image,
         gt_classes,
         bboxes_preds_per_image,
         expanded_strides,
         x_shifts,
         y_shifts,
         cls_preds,
-        bbox_preds,
         obj_preds,
-        labels,
-        imgs,
         mode="gpu",
     ):
 
         if mode == "cpu":
-            print("------------CPU Mode for This Batch-------------")
+            print("-----------Using CPU for the Current Batch-------------")
             gt_bboxes_per_image = gt_bboxes_per_image.cpu().float()
             bboxes_preds_per_image = bboxes_preds_per_image.cpu().float()
             gt_classes = gt_classes.cpu().float()
             expanded_strides = expanded_strides.cpu().float()
             x_shifts = x_shifts.cpu()
             y_shifts = y_shifts.cpu()
 
-        fg_mask, is_in_boxes_and_center = self.get_in_boxes_info(
+        fg_mask, geometry_relation = self.get_geometry_constraint(
             gt_bboxes_per_image,
             expanded_strides,
             x_shifts,
             y_shifts,
-            total_num_anchors,
-            num_gt,
         )
 
         bboxes_preds_per_image = bboxes_preds_per_image[fg_mask]
@@ -480,8 +465,6 @@ def get_assignments(
         gt_cls_per_image = (
             F.one_hot(gt_classes.to(torch.int64), self.num_classes)
             .float()
-            .unsqueeze(1)
-            .repeat(1, num_in_boxes_anchor, 1)
         )
         pair_wise_ious_loss = -torch.log(pair_wise_ious + 1e-8)
 
@@ -490,26 +473,27 @@ def get_assignments(
 
         with torch.cuda.amp.autocast(enabled=False):
             cls_preds_ = (
-                cls_preds_.float().unsqueeze(0).repeat(num_gt, 1, 1).sigmoid_()
-                * obj_preds_.float().unsqueeze(0).repeat(num_gt, 1, 1).sigmoid_()
-            )
+                cls_preds_.float().sigmoid_() * obj_preds_.float().sigmoid_()
+            ).sqrt()
             pair_wise_cls_loss = F.binary_cross_entropy(
-                cls_preds_.sqrt_(), gt_cls_per_image, reduction="none"
+                cls_preds_.unsqueeze(0).repeat(num_gt, 1, 1),
+                gt_cls_per_image.unsqueeze(1).repeat(1, num_in_boxes_anchor, 1),
+                reduction="none"
             ).sum(-1)
         del cls_preds_
 
         cost = (
             pair_wise_cls_loss
             + 3.0 * pair_wise_ious_loss
-            + 100000.0 * (~is_in_boxes_and_center)
+            + float(1e6) * (~geometry_relation)
         )
 
         (
             num_fg,
             gt_matched_classes,
             pred_ious_this_matching,
             matched_gt_inds,
-        ) = self.dynamic_k_matching(cost, pair_wise_ious, gt_classes, num_gt, fg_mask)
+        ) = self.simota_matching(cost, pair_wise_ious, gt_classes, num_gt, fg_mask)
         del pair_wise_cls_loss, cost, pair_wise_ious, pair_wise_ious_loss
 
         if mode == "cpu":
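The rewritten autocast block computes the same SimOTA cost with fewer repeated tensors: the joint score per candidate is sqrt(sigmoid(cls) * sigmoid(obj)), it is compared against each ground truth's one-hot label with binary cross-entropy, and the total cost adds 3 times the -log(IoU) term plus a very large constant for GT/candidate pairs that fail the geometry constraint. A condensed stand-alone sketch of that cost (every tensor and size below is made up):

    import torch
    import torch.nn.functional as F

    num_gt, num_candidates, num_classes = 3, 5, 80   # hypothetical sizes
    cls_logits = torch.randn(num_candidates, num_classes)
    obj_logits = torch.randn(num_candidates, 1)
    gt_onehot = F.one_hot(torch.randint(num_classes, (num_gt,)), num_classes).float()
    ious = torch.rand(num_gt, num_candidates)
    geometry_relation = torch.rand(num_gt, num_candidates) > 0.5  # True = inside the radius

    joint_score = (cls_logits.sigmoid() * obj_logits.sigmoid()).sqrt()
    pair_cls_loss = F.binary_cross_entropy(
        joint_score.unsqueeze(0).repeat(num_gt, 1, 1),
        gt_onehot.unsqueeze(1).repeat(1, num_candidates, 1),
        reduction="none",
    ).sum(-1)
    pair_iou_loss = -torch.log(ious + 1e-8)
    cost = pair_cls_loss + 3.0 * pair_iou_loss + 1e6 * (~geometry_relation)
    print(cost.shape)  # [num_gt, num_candidates]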
@@ -526,101 +510,49 @@ def get_assignments(
             num_fg,
         )
 
-    def get_in_boxes_info(
+    def get_geometry_constraint(
         self,
         gt_bboxes_per_image,
         expanded_strides,
         x_shifts,
         y_shifts,
-        total_num_anchors,
-        num_gt,
     ):
+        """
+        Calculate whether the center of an object is located in a fixed range of
+        an anchor. This is used to avert inappropriate matching. It can also reduce
+        the number of candidate anchors so that the GPU memory is saved.
+        """
         expanded_strides_per_image = expanded_strides[0]
-        x_shifts_per_image = x_shifts[0] * expanded_strides_per_image
-        y_shifts_per_image = y_shifts[0] * expanded_strides_per_image
-        x_centers_per_image = (
-            (x_shifts_per_image + 0.5 * expanded_strides_per_image)
-            .unsqueeze(0)
-            .repeat(num_gt, 1)
-        )  # [n_anchor] -> [n_gt, n_anchor]
-        y_centers_per_image = (
-            (y_shifts_per_image + 0.5 * expanded_strides_per_image)
-            .unsqueeze(0)
-            .repeat(num_gt, 1)
-        )
-
-        gt_bboxes_per_image_l = (
-            (gt_bboxes_per_image[:, 0] - 0.5 * gt_bboxes_per_image[:, 2])
-            .unsqueeze(1)
-            .repeat(1, total_num_anchors)
-        )
-        gt_bboxes_per_image_r = (
-            (gt_bboxes_per_image[:, 0] + 0.5 * gt_bboxes_per_image[:, 2])
-            .unsqueeze(1)
-            .repeat(1, total_num_anchors)
-        )
-        gt_bboxes_per_image_t = (
-            (gt_bboxes_per_image[:, 1] - 0.5 * gt_bboxes_per_image[:, 3])
-            .unsqueeze(1)
-            .repeat(1, total_num_anchors)
-        )
-        gt_bboxes_per_image_b = (
-            (gt_bboxes_per_image[:, 1] + 0.5 * gt_bboxes_per_image[:, 3])
-            .unsqueeze(1)
-            .repeat(1, total_num_anchors)
-        )
+        x_centers_per_image = ((x_shifts[0] + 0.5) * expanded_strides_per_image).unsqueeze(0)
+        y_centers_per_image = ((y_shifts[0] + 0.5) * expanded_strides_per_image).unsqueeze(0)
 
-        b_l = x_centers_per_image - gt_bboxes_per_image_l
-        b_r = gt_bboxes_per_image_r - x_centers_per_image
-        b_t = y_centers_per_image - gt_bboxes_per_image_t
-        b_b = gt_bboxes_per_image_b - y_centers_per_image
-        bbox_deltas = torch.stack([b_l, b_t, b_r, b_b], 2)
-
-        is_in_boxes = bbox_deltas.min(dim=-1).values > 0.0
-        is_in_boxes_all = is_in_boxes.sum(dim=0) > 0
         # in fixed center
-
-        center_radius = 2.5
-
-        gt_bboxes_per_image_l = (gt_bboxes_per_image[:, 0]).unsqueeze(1).repeat(
-            1, total_num_anchors
-        ) - center_radius * expanded_strides_per_image.unsqueeze(0)
-        gt_bboxes_per_image_r = (gt_bboxes_per_image[:, 0]).unsqueeze(1).repeat(
-            1, total_num_anchors
-        ) + center_radius * expanded_strides_per_image.unsqueeze(0)
-        gt_bboxes_per_image_t = (gt_bboxes_per_image[:, 1]).unsqueeze(1).repeat(
-            1, total_num_anchors
-        ) - center_radius * expanded_strides_per_image.unsqueeze(0)
-        gt_bboxes_per_image_b = (gt_bboxes_per_image[:, 1]).unsqueeze(1).repeat(
-            1, total_num_anchors
-        ) + center_radius * expanded_strides_per_image.unsqueeze(0)
+        center_radius = 1.5
+        center_dist = expanded_strides_per_image.unsqueeze(0) * center_radius
+        gt_bboxes_per_image_l = (gt_bboxes_per_image[:, 0:1]) - center_dist
+        gt_bboxes_per_image_r = (gt_bboxes_per_image[:, 0:1]) + center_dist
+        gt_bboxes_per_image_t = (gt_bboxes_per_image[:, 1:2]) - center_dist
+        gt_bboxes_per_image_b = (gt_bboxes_per_image[:, 1:2]) + center_dist
 
         c_l = x_centers_per_image - gt_bboxes_per_image_l
         c_r = gt_bboxes_per_image_r - x_centers_per_image
         c_t = y_centers_per_image - gt_bboxes_per_image_t
         c_b = gt_bboxes_per_image_b - y_centers_per_image
         center_deltas = torch.stack([c_l, c_t, c_r, c_b], 2)
         is_in_centers = center_deltas.min(dim=-1).values > 0.0
-        is_in_centers_all = is_in_centers.sum(dim=0) > 0
+        anchor_filter = is_in_centers.sum(dim=0) > 0
+        geometry_relation = is_in_centers[:, anchor_filter]
 
-        # in boxes and in centers
-        is_in_boxes_anchor = is_in_boxes_all | is_in_centers_all
-
-        is_in_boxes_and_center = (
-            is_in_boxes[:, is_in_boxes_anchor] & is_in_centers[:, is_in_boxes_anchor]
-        )
-        return is_in_boxes_anchor, is_in_boxes_and_center
+        return anchor_filter, geometry_relation
 
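get_geometry_constraint replaces get_in_boxes_info: the old candidate test (anchor center inside the GT box, or within 2.5 strides of the GT center) becomes a single fixed-radius test with a 1.5-stride window, and only locations that pass it for at least one GT are kept. An equivalent stand-alone formulation of that filter (the strides, shifts, and GT centers are illustrative; the absolute-difference form is just a compact way of writing the four deltas above):

    import torch

    strides = torch.tensor([8., 8., 16.])                    # stride of each location
    cell_xy = torch.tensor([[0., 0.], [1., 0.], [0., 0.]])   # (x_shift, y_shift) per location
    gt_centers = torch.tensor([[6., 5.], [40., 40.]])        # (cx, cy) of two ground truths

    centers = (cell_xy + 0.5) * strides.unsqueeze(1)         # location centers in pixels
    radius = 1.5 * strides                                   # per-location radius in pixels

    # keep a (gt, location) pair when the location center lies within the radius on both axes
    delta = (gt_centers.unsqueeze(1) - centers.unsqueeze(0)).abs()   # [n_gt, n_loc, 2]
    is_in_centers = (delta < radius.view(1, -1, 1)).all(dim=-1)      # [n_gt, n_loc]

    anchor_filter = is_in_centers.sum(dim=0) > 0             # locations kept as candidates
    geometry_relation = is_in_centers[:, anchor_filter]      # per-GT mask over kept locations
    print(anchor_filter, geometry_relation)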
-    def dynamic_k_matching(self, cost, pair_wise_ious, gt_classes, num_gt, fg_mask):
+    def simota_matching(self, cost, pair_wise_ious, gt_classes, num_gt, fg_mask):
         # Dynamic K
         # ---------------------------------------------------------------
         matching_matrix = torch.zeros_like(cost, dtype=torch.uint8)
 
-        ious_in_boxes_matrix = pair_wise_ious
-        n_candidate_k = min(10, ious_in_boxes_matrix.size(1))
-        topk_ious, _ = torch.topk(ious_in_boxes_matrix, n_candidate_k, dim=1)
+        n_candidate_k = min(10, pair_wise_ious.size(1))
+        topk_ious, _ = torch.topk(pair_wise_ious, n_candidate_k, dim=1)
         dynamic_ks = torch.clamp(topk_ious.sum(1).int(), min=1)
-        dynamic_ks = dynamic_ks.tolist()
         for gt_idx in range(num_gt):
             _, pos_idx = torch.topk(
                 cost[gt_idx], k=dynamic_ks[gt_idx], largest=False
@@ -630,11 +562,13 @@ def dynamic_k_matching(self, cost, pair_wise_ious, gt_classes, num_gt, fg_mask):
             del topk_ious, dynamic_ks, pos_idx
 
         anchor_matching_gt = matching_matrix.sum(0)
+        # deal with the case that one anchor matches multiple ground-truths
         if anchor_matching_gt.max() > 1:
-            _, cost_argmin = torch.min(cost[:, anchor_matching_gt > 1], dim=0)
-            matching_matrix[:, anchor_matching_gt > 1] *= 0
-            matching_matrix[cost_argmin, anchor_matching_gt > 1] = 1
-        fg_mask_inboxes = matching_matrix.sum(0) > 0
+            multiple_match_mask = anchor_matching_gt > 1
+            _, cost_argmin = torch.min(cost[:, multiple_match_mask], dim=0)
+            matching_matrix[:, multiple_match_mask] *= 0
+            matching_matrix[cost_argmin, multiple_match_mask] = 1
+        fg_mask_inboxes = anchor_matching_gt > 0
         num_fg = fg_mask_inboxes.sum().item()
 
         fg_mask[fg_mask.clone()] = fg_mask_inboxes
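simota_matching is a rename of dynamic_k_matching with the same logic: each GT's k is its clamped sum of the top-10 candidate IoUs, the k cheapest candidates by cost are assigned to that GT, and a candidate claimed by several GTs keeps only its lowest-cost assignment. A compact runnable sketch with toy tensors:

    import torch

    cost = torch.tensor([[0.2, 1.0, 0.5, 2.0],
                         [0.9, 0.1, 0.4, 1.5]])              # [n_gt, n_candidates]
    pair_wise_ious = torch.tensor([[0.8, 0.3, 0.6, 0.1],
                                   [0.2, 0.9, 0.7, 0.2]])

    matching_matrix = torch.zeros_like(cost, dtype=torch.uint8)
    n_candidate_k = min(10, pair_wise_ious.size(1))
    topk_ious, _ = torch.topk(pair_wise_ious, n_candidate_k, dim=1)
    dynamic_ks = torch.clamp(topk_ious.sum(1).int(), min=1)  # per-GT k

    for gt_idx in range(cost.size(0)):
        _, pos_idx = torch.topk(cost[gt_idx], k=int(dynamic_ks[gt_idx]), largest=False)
        matching_matrix[gt_idx][pos_idx] = 1

    # a candidate matched to more than one GT keeps only its cheapest match
    anchor_matching_gt = matching_matrix.sum(0)
    if anchor_matching_gt.max() > 1:
        multiple_match_mask = anchor_matching_gt > 1
        _, cost_argmin = torch.min(cost[:, multiple_match_mask], dim=0)
        matching_matrix[:, multiple_match_mask] *= 0
        matching_matrix[cost_argmin, multiple_match_mask] = 1
    print(matching_matrix)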