@@ -9,7 +9,7 @@
 import torch.nn as nn
 import torch.nn.functional as F
 
-from yolox.utils import bboxes_iou, meshgrid
+from yolox.utils import bboxes_iou, cxcywh2xyxy, meshgrid, visualize_assign
 
 from .losses import IOUloss
 from .network_blocks import BaseConv, DWConv
@@ -511,11 +511,7 @@ def get_assignments(
         )
 
     def get_geometry_constraint(
-        self,
-        gt_bboxes_per_image,
-        expanded_strides,
-        x_shifts,
-        y_shifts,
+        self, gt_bboxes_per_image, expanded_strides, x_shifts, y_shifts,
     ):
         """
         Calculate whether the center of an object is located in a fixed range of
@@ -546,8 +542,6 @@ def get_geometry_constraint(
         return anchor_filter, geometry_relation
 
     def simota_matching(self, cost, pair_wise_ious, gt_classes, num_gt, fg_mask):
-        # Dynamic K
-        # ---------------------------------------------------------------
         matching_matrix = torch.zeros_like(cost, dtype=torch.uint8)
 
         n_candidate_k = min(10, pair_wise_ious.size(1))
@@ -580,3 +574,68 @@ def simota_matching(self, cost, pair_wise_ious, gt_classes, num_gt, fg_mask):
             fg_mask_inboxes
         ]
         return num_fg, gt_matched_classes, pred_ious_this_matching, matched_gt_inds
+
+    def visualize_assign_result(self, xin, labels=None, imgs=None, save_prefix="assign_vis_"):
+        # original forward logic
+        outputs, x_shifts, y_shifts, expanded_strides = [], [], [], []
+        # TODO: use forward logic here.
+
+        for k, (cls_conv, reg_conv, stride_this_level, x) in enumerate(
+            zip(self.cls_convs, self.reg_convs, self.strides, xin)
+        ):
+            x = self.stems[k](x)
+            cls_x = x
+            reg_x = x
+
+            cls_feat = cls_conv(cls_x)
+            cls_output = self.cls_preds[k](cls_feat)
+            reg_feat = reg_conv(reg_x)
+            reg_output = self.reg_preds[k](reg_feat)
+            obj_output = self.obj_preds[k](reg_feat)
+
+            output = torch.cat([reg_output, obj_output, cls_output], 1)
+            output, grid = self.get_output_and_grid(output, k, stride_this_level, xin[0].type())
+            x_shifts.append(grid[:, :, 0])
+            y_shifts.append(grid[:, :, 1])
+            expanded_strides.append(
+                torch.full((1, grid.shape[1]), stride_this_level).type_as(xin[0])
+            )
+            outputs.append(output)
+
+        outputs = torch.cat(outputs, 1)
+        bbox_preds = outputs[:, :, :4]  # [batch, n_anchors_all, 4]
+        obj_preds = outputs[:, :, 4:5]  # [batch, n_anchors_all, 1]
+        cls_preds = outputs[:, :, 5:]  # [batch, n_anchors_all, n_cls]
+
+        # calculate targets
+        total_num_anchors = outputs.shape[1]
+        x_shifts = torch.cat(x_shifts, 1)  # [1, n_anchors_all]
+        y_shifts = torch.cat(y_shifts, 1)  # [1, n_anchors_all]
+        expanded_strides = torch.cat(expanded_strides, 1)
+
+        nlabel = (labels.sum(dim=2) > 0).sum(dim=1)  # number of objects
+        for batch_idx, (img, num_gt, label) in enumerate(zip(imgs, nlabel, labels)):
+            img = imgs[batch_idx].permute(1, 2, 0).to(torch.uint8)
+            num_gt = int(num_gt)
+            if num_gt == 0:
+                fg_mask = outputs.new_zeros(total_num_anchors).bool()
+            else:
+                gt_bboxes_per_image = label[:num_gt, 1:5]
+                gt_classes = label[:num_gt, 0]
+                bboxes_preds_per_image = bbox_preds[batch_idx]
+                _, fg_mask, _, matched_gt_inds, _ = self.get_assignments(  # noqa
+                    batch_idx, num_gt, gt_bboxes_per_image, gt_classes,
+                    bboxes_preds_per_image, expanded_strides, x_shifts,
+                    y_shifts, cls_preds, obj_preds,
+                )
+
+            img = img.cpu().numpy().copy()  # copy is crucial here
+            coords = torch.stack([
+                ((x_shifts + 0.5) * expanded_strides).flatten()[fg_mask],
+                ((y_shifts + 0.5) * expanded_strides).flatten()[fg_mask],
+            ], 1)
+
+            xyxy_boxes = cxcywh2xyxy(gt_bboxes_per_image)
+            save_name = save_prefix + str(batch_idx) + ".png"
+            img = visualize_assign(img, xyxy_boxes, coords, matched_gt_inds, save_name)
+            logger.info(f"save img to {save_name}")
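The added visualize_assign_result method reruns the head's forward pass, performs SimOTA assignment per image via get_assignments, and writes one annotated PNG per batch element showing which anchor centers were matched to each ground-truth box. A minimal driver sketch follows (not part of this commit): it assumes a full YOLOX model whose backbone returns the FPN features the head consumes, and a loader that yields batches in the usual YOLOX training layout, imgs of shape [B, 3, H, W] in 0-255 range and targets of shape [B, max_objects, 5] ordered as (class, cx, cy, w, h); the function and variable names below are illustrative only.

import torch

@torch.no_grad()
def dump_assignments(model, dataloader, device="cuda", save_prefix="assign_vis_"):
    # Illustrative driver, not from this diff: visualize SimOTA assignment on one batch.
    model.to(device).eval()
    for batch in dataloader:
        imgs, targets = batch[0].to(device).float(), batch[1].to(device).float()
        fpn_outs = model.backbone(imgs)  # FPN features expected by the head as xin
        model.head.visualize_assign_result(
            fpn_outs, labels=targets, imgs=imgs, save_prefix=save_prefix
        )
        break  # one batch is usually enough for a quick sanity check

Each saved image should show the ground-truth boxes (converted to xyxy via cxcywh2xyxy) alongside the (x_shifts + 0.5) * stride anchor centers selected for each object, which gives a quick visual check that the assignment behaves sensibly before committing to a full training run.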