Skip to content

Commit b3a80be

Browse files
authored
feat(tools): add assignment visualizer (Megvii-BaseDetection#1616)
feat(tools): add assignment visualizer (Megvii-BaseDetection#1616)
1 parent 4f8f1d7 commit b3a80be

File tree

10 files changed

+249
-15
lines changed

10 files changed

+249
-15
lines changed

README.md

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@ This repo is an implementation of PyTorch version YOLOX, there is also a [MegEng
1010
<img src="assets/git_fig.png" width="1000" >
1111

1212
## Updates!!
13-
* 【2022/04/14】 We suport jit compile op.
13+
* 【2023/02/28】 We support assignment visualization tool, see doc [here](./docs/assignment_visualization.md).
14+
* 【2022/04/14】 We support jit compile op.
1415
* 【2021/08/19】 We optimize the training process with **2x** faster training and **~1%** higher performance! See [notes](docs/updates_note.md) for more details.
1516
* 【2021/08/05】 We release [MegEngine version YOLOX](https://github.com/MegEngine/YOLOX).
1617
* 【2021/07/28】 We fix the fatal error of [memory leak](https://github.com/Megvii-BaseDetection/YOLOX/issues/103)
@@ -206,6 +207,7 @@ python -m yolox.tools.eval -n yolox-s -c yolox_s.pth -b 1 -d 1 --conf 0.001 --f
206207
* [Training on custom data](docs/train_custom_data.md)
207208
* [Caching for custom data](docs/cache.md)
208209
* [Manipulating training image size](docs/manipulate_training_image_size.md)
210+
* [Assignment visualization](docs/assignment_visualization.md)
209211
* [Freezing model](docs/freeze_module.md)
210212

211213
</details>
@@ -243,8 +245,8 @@ If you use YOLOX in your research, please cite our work by using the following B
243245
}
244246
```
245247
## In memory of Dr. Jian Sun
246-
Without the guidance of [Dr. Sun Jian](http://www.jiansun.org/), YOLOX would not have been released and open sourced to the community.
247-
The passing away of Dr. Sun Jian is a great loss to the Computer Vision field. We have added this section here to express our remembrance and condolences to our captain Dr. Sun.
248+
Without the guidance of [Dr. Jian Sun](http://www.jiansun.org/), YOLOX would not have been released and open sourced to the community.
249+
The passing away of Dr. Jian is a huge loss to the Computer Vision field. We add this section here to express our remembrance and condolences to our captain Dr. Jian.
248250
It is hoped that every AI practitioner in the world will stick to the concept of "continuous innovation to expand cognitive boundaries, and extraordinary technology to achieve product value" and move forward all the way.
249251

250252
<div align="center"><img src="assets/sunjian.png" width="200"></div>

assets/assignment.png

650 KB
Loading

docs/assignment_visualization.md

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
# Visualize label assignment
2+
3+
This tutorial explains how to visualize your label asssignment result when training with YOLOX.
4+
5+
## 1. Visualization command
6+
7+
We provide a visualization tool to help you visualize your label assignment result. You can find it in [`tools/visualize_assignment.py`](../tools/visualize_assign.py).
8+
9+
Here is an example of command to visualize your label assignment result:
10+
11+
```shell
12+
python3 tools/visualize_assign.py -f /path/to/your/exp.py yolox-s -d 1 -b 8 --max-batch 2
13+
```
14+
15+
`max-batch` here means the maximum number of batches to visualize. The default value is 1, which the tool means only visualize the first batch.
16+
17+
By the way, the mosaic augmentation is used in default dataloader, so you can also see the mosaic result here.
18+
19+
After running the command, the logger will show you where the visualization result is saved, let's open it and into the step 2.
20+
21+
## 2. Check the visualization result
22+
23+
Here is an example of visualization result:
24+
<div align="center"><img src="../assets/assignment.png" width="640"></div>
25+
26+
Those dots in one box is the matched anchor of gt box. **The color of dots is the same as the color of the box** to help you determine which object is assigned to the anchor. Note the box and dots are **instance level** visualization, which means the same class may have different colors.
27+
**If the gt box doesn't match any anchor, the box will be marked as red and the red text "unmatched" will be drawn over the box**.
28+
29+
Please feel free to open an issue if you have any questions.

tools/visualize_assign.py

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Megvii, Inc. and its affiliates.
3+
4+
import os
5+
import sys
6+
import random
7+
import time
8+
import warnings
9+
from loguru import logger
10+
11+
import torch
12+
import torch.backends.cudnn as cudnn
13+
14+
from yolox.exp import Exp, get_exp
15+
from yolox.core import Trainer
16+
from yolox.utils import configure_module, configure_omp
17+
from yolox.tools.train import make_parser
18+
19+
20+
class AssignVisualizer(Trainer):
21+
22+
def __init__(self, exp: Exp, args):
23+
super().__init__(exp, args)
24+
self.batch_cnt = 0
25+
self.vis_dir = os.path.join(self.file_name, "vis")
26+
os.makedirs(self.vis_dir, exist_ok=True)
27+
28+
def train_one_iter(self):
29+
iter_start_time = time.time()
30+
31+
inps, targets = self.prefetcher.next()
32+
inps = inps.to(self.data_type)
33+
targets = targets.to(self.data_type)
34+
targets.requires_grad = False
35+
inps, targets = self.exp.preprocess(inps, targets, self.input_size)
36+
data_end_time = time.time()
37+
38+
with torch.cuda.amp.autocast(enabled=self.amp_training):
39+
path_prefix = os.path.join(self.vis_dir, f"assign_vis_{self.batch_cnt}_")
40+
self.model.visualize(inps, targets, path_prefix)
41+
42+
if self.use_model_ema:
43+
self.ema_model.update(self.model)
44+
45+
iter_end_time = time.time()
46+
self.meter.update(
47+
iter_time=iter_end_time - iter_start_time,
48+
data_time=data_end_time - iter_start_time,
49+
)
50+
self.batch_cnt += 1
51+
if self.batch_cnt >= self.args.max_batch:
52+
sys.exit(0)
53+
54+
def after_train(self):
55+
logger.info("Finish visualize assignment, exit...")
56+
57+
58+
def assign_vis_parser():
59+
parser = make_parser()
60+
parser.add_argument("--max-batch", type=int, default=1, help="max batch of images to visualize")
61+
return parser
62+
63+
64+
@logger.catch
65+
def main(exp: Exp, args):
66+
if exp.seed is not None:
67+
random.seed(exp.seed)
68+
torch.manual_seed(exp.seed)
69+
cudnn.deterministic = True
70+
warnings.warn(
71+
"You have chosen to seed training. This will turn on the CUDNN deterministic setting, "
72+
"which can slow down your training considerably! You may see unexpected behavior "
73+
"when restarting from checkpoints."
74+
)
75+
76+
# set environment variables for distributed training
77+
configure_omp()
78+
cudnn.benchmark = True
79+
80+
visualizer = AssignVisualizer(exp, args)
81+
visualizer.train()
82+
83+
84+
if __name__ == "__main__":
85+
configure_module()
86+
args = assign_vis_parser().parse_args()
87+
exp = get_exp(args.exp_file, args.name)
88+
exp.merge(args.opts)
89+
90+
if not args.experiment_name:
91+
args.experiment_name = exp.exp_name
92+
93+
main(exp, args)

yolox/core/trainer.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
#!/usr/bin/env python3
2-
# -*- coding:utf-8 -*-
32
# Copyright (c) Megvii, Inc. and its affiliates.
43

54
import datetime

yolox/models/yolo_head.py

Lines changed: 67 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import torch.nn as nn
1010
import torch.nn.functional as F
1111

12-
from yolox.utils import bboxes_iou, meshgrid
12+
from yolox.utils import bboxes_iou, cxcywh2xyxy, meshgrid, visualize_assign
1313

1414
from .losses import IOUloss
1515
from .network_blocks import BaseConv, DWConv
@@ -511,11 +511,7 @@ def get_assignments(
511511
)
512512

513513
def get_geometry_constraint(
514-
self,
515-
gt_bboxes_per_image,
516-
expanded_strides,
517-
x_shifts,
518-
y_shifts,
514+
self, gt_bboxes_per_image, expanded_strides, x_shifts, y_shifts,
519515
):
520516
"""
521517
Calculate whether the center of an object is located in a fixed range of
@@ -546,8 +542,6 @@ def get_geometry_constraint(
546542
return anchor_filter, geometry_relation
547543

548544
def simota_matching(self, cost, pair_wise_ious, gt_classes, num_gt, fg_mask):
549-
# Dynamic K
550-
# ---------------------------------------------------------------
551545
matching_matrix = torch.zeros_like(cost, dtype=torch.uint8)
552546

553547
n_candidate_k = min(10, pair_wise_ious.size(1))
@@ -580,3 +574,68 @@ def simota_matching(self, cost, pair_wise_ious, gt_classes, num_gt, fg_mask):
580574
fg_mask_inboxes
581575
]
582576
return num_fg, gt_matched_classes, pred_ious_this_matching, matched_gt_inds
577+
578+
def visualize_assign_result(self, xin, labels=None, imgs=None, save_prefix="assign_vis_"):
579+
# original forward logic
580+
outputs, x_shifts, y_shifts, expanded_strides = [], [], [], []
581+
# TODO: use forward logic here.
582+
583+
for k, (cls_conv, reg_conv, stride_this_level, x) in enumerate(
584+
zip(self.cls_convs, self.reg_convs, self.strides, xin)
585+
):
586+
x = self.stems[k](x)
587+
cls_x = x
588+
reg_x = x
589+
590+
cls_feat = cls_conv(cls_x)
591+
cls_output = self.cls_preds[k](cls_feat)
592+
reg_feat = reg_conv(reg_x)
593+
reg_output = self.reg_preds[k](reg_feat)
594+
obj_output = self.obj_preds[k](reg_feat)
595+
596+
output = torch.cat([reg_output, obj_output, cls_output], 1)
597+
output, grid = self.get_output_and_grid(output, k, stride_this_level, xin[0].type())
598+
x_shifts.append(grid[:, :, 0])
599+
y_shifts.append(grid[:, :, 1])
600+
expanded_strides.append(
601+
torch.full((1, grid.shape[1]), stride_this_level).type_as(xin[0])
602+
)
603+
outputs.append(output)
604+
605+
outputs = torch.cat(outputs, 1)
606+
bbox_preds = outputs[:, :, :4] # [batch, n_anchors_all, 4]
607+
obj_preds = outputs[:, :, 4:5] # [batch, n_anchors_all, 1]
608+
cls_preds = outputs[:, :, 5:] # [batch, n_anchors_all, n_cls]
609+
610+
# calculate targets
611+
total_num_anchors = outputs.shape[1]
612+
x_shifts = torch.cat(x_shifts, 1) # [1, n_anchors_all]
613+
y_shifts = torch.cat(y_shifts, 1) # [1, n_anchors_all]
614+
expanded_strides = torch.cat(expanded_strides, 1)
615+
616+
nlabel = (labels.sum(dim=2) > 0).sum(dim=1) # number of objects
617+
for batch_idx, (img, num_gt, label) in enumerate(zip(imgs, nlabel, labels)):
618+
img = imgs[batch_idx].permute(1, 2, 0).to(torch.uint8)
619+
num_gt = int(num_gt)
620+
if num_gt == 0:
621+
fg_mask = outputs.new_zeros(total_num_anchors).bool()
622+
else:
623+
gt_bboxes_per_image = label[:num_gt, 1:5]
624+
gt_classes = label[:num_gt, 0]
625+
bboxes_preds_per_image = bbox_preds[batch_idx]
626+
_, fg_mask, _, matched_gt_inds, _ = self.get_assignments( # noqa
627+
batch_idx, num_gt, gt_bboxes_per_image, gt_classes,
628+
bboxes_preds_per_image, expanded_strides, x_shifts,
629+
y_shifts, cls_preds, obj_preds,
630+
)
631+
632+
img = img.cpu().numpy().copy() # copy is crucial here
633+
coords = torch.stack([
634+
((x_shifts + 0.5) * expanded_strides).flatten()[fg_mask],
635+
((y_shifts + 0.5) * expanded_strides).flatten()[fg_mask],
636+
], 1)
637+
638+
xyxy_boxes = cxcywh2xyxy(gt_bboxes_per_image)
639+
save_name = save_prefix + str(batch_idx) + ".png"
640+
img = visualize_assign(img, xyxy_boxes, coords, matched_gt_inds, save_name)
641+
logger.info(f"save img to {save_name}")

yolox/models/yolox.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,3 +46,7 @@ def forward(self, x, targets=None):
4646
outputs = self.head(fpn_outs)
4747

4848
return outputs
49+
50+
def visualize(self, x, targets, save_prefix="assign_vis_"):
51+
fpn_outs = self.backbone(x)
52+
self.head.visualize_assign_result(fpn_outs, targets, x, save_prefix)

yolox/utils/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
#!/usr/bin/env python3
2-
# -*- coding:utf-8 -*-
32
# Copyright (c) Megvii Inc. All rights reserved.
43

54
from .allreduce_norm import *

yolox/utils/boxes.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
#!/usr/bin/env python3
2-
# -*- coding:utf-8 -*-
32
# Copyright (c) Megvii Inc. All rights reserved.
43

54
import numpy as np
@@ -15,6 +14,7 @@
1514
"adjust_box_anns",
1615
"xyxy2xywh",
1716
"xyxy2cxcywh",
17+
"cxcywh2xyxy",
1818
]
1919

2020

@@ -133,3 +133,11 @@ def xyxy2cxcywh(bboxes):
133133
bboxes[:, 0] = bboxes[:, 0] + bboxes[:, 2] * 0.5
134134
bboxes[:, 1] = bboxes[:, 1] + bboxes[:, 3] * 0.5
135135
return bboxes
136+
137+
138+
def cxcywh2xyxy(bboxes):
139+
bboxes[:, 0] = bboxes[:, 0] - bboxes[:, 2] * 0.5
140+
bboxes[:, 1] = bboxes[:, 1] - bboxes[:, 3] * 0.5
141+
bboxes[:, 2] = bboxes[:, 0] + bboxes[:, 2]
142+
bboxes[:, 3] = bboxes[:, 1] + bboxes[:, 3]
143+
return bboxes

yolox/utils/demo_utils.py

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,51 @@
22
# Copyright (c) Megvii Inc. All rights reserved.
33

44
import os
5+
import random
56

7+
import cv2
68
import numpy as np
79

8-
__all__ = ["mkdir", "nms", "multiclass_nms", "demo_postprocess"]
10+
__all__ = [
11+
"mkdir", "nms", "multiclass_nms", "demo_postprocess", "random_color", "visualize_assign"
12+
]
13+
14+
15+
def random_color():
16+
return random.randint(0, 255), random.randint(0, 255), random.randint(0, 255)
17+
18+
19+
def visualize_assign(img, boxes, coords, match_results, save_name=None) -> np.ndarray:
20+
"""visualize label assign result.
21+
22+
Args:
23+
img: img to visualize
24+
boxes: gt boxes in xyxy format
25+
coords: coords of matched anchors
26+
match_results: match results of each gt box and coord.
27+
save_name: name of save image, if None, image will not be saved. Default: None.
28+
"""
29+
for box_id, box in enumerate(boxes):
30+
x1, y1, x2, y2 = box
31+
color = random_color()
32+
assign_coords = coords[match_results == box_id]
33+
if assign_coords.numel() == 0:
34+
# unmatched boxes are red
35+
color = (0, 0, 255)
36+
cv2.putText(
37+
img, "unmatched", (int(x1), int(y1) - 5),
38+
cv2.FONT_HERSHEY_SIMPLEX, 0.6, color, 1
39+
)
40+
else:
41+
for coord in assign_coords:
42+
# draw assigned anchor
43+
cv2.circle(img, (int(coord[0]), int(coord[1])), 3, color, -1)
44+
cv2.rectangle(img, (int(x1), int(y1)), (int(x2), int(y2)), color, 2)
45+
46+
if save_name is not None:
47+
cv2.imwrite(save_name, img)
48+
49+
return img
950

1051

1152
def mkdir(path):

0 commit comments

Comments
 (0)