Skip to content

Commit 48fd2a9

Browse files
authored
chore(logger): log predictions during training to wandb tables (Megvii-BaseDetection#1181)
chore(logger): log predictions during training to wandb tables
1 parent 6c68260 commit 48fd2a9

File tree

8 files changed

+257
-29
lines changed

8 files changed

+257
-29
lines changed

README.md

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,19 @@ On the second machine, run
150150
python tools/train.py -n yolox-s -b 128 --dist-url tcp://123.123.123.123:12312 --num_machines 2 --machine_rank 1
151151
```
152152

153+
**Logging to Weights & Biases**
154+
155+
To log metrics, predictions, and model checkpoints to [W&B](https://docs.wandb.ai/guides/integrations/other/yolox), pass the command line argument `--logger wandb` and use the prefix "wandb-" on any argument that should be forwarded to the wandb run initialization.
156+
157+
```shell
158+
python tools/train.py -n yolox-s -d 8 -b 64 --fp16 -o [--cache] --logger wandb wandb-project <project name>
159+
yolox-m
160+
yolox-l
161+
yolox-x
162+
```
163+
164+
An example wandb dashboard is available [here](https://wandb.ai/manan-goel/yolox-nano/runs/3pzfeom0).
165+
153166
**Others**
154167
See more information with the following command:
155168
```shell

docs/quick_run.md

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,19 @@ python tools/train.py -n yolox-s -d 8 -b 64 --fp16 -o [--cache] --logger wandb w
7676
yolox-x
7777
```
7878

79+
Additional WandbLogger arguments include:
80+
81+
```shell
82+
python tools/train.py .... --logger wandb wandb-project <project-name> \
83+
wandb-name <run-name> \
84+
wandb-id <run-id> \
85+
wandb-save_dir <save-dir> \
86+
wandb-num_eval_images <num-images> \
87+
wandb-log_checkpoints <bool>
88+
```
89+
90+
More information is available [here](https://docs.wandb.ai/guides/integrations/other/yolox).
91+
7992
**Multi Machine Training**
8093

8194
We also support multi-nodes training. Just add the following args:

tools/train.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,8 @@ def make_parser():
8484
"-l",
8585
"--logger",
8686
type=str,
87-
help="Logger to be used for metrics",
87+
help="Logger to be used for metrics. \
88+
Implemented loggers include `tensorboard` and `wandb`.",
8889
default="tensorboard"
8990
)
9091
parser.add_argument(

yolox/core/trainer.py

Lines changed: 29 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -180,11 +180,11 @@ def before_train(self):
180180
if self.args.logger == "tensorboard":
181181
self.tblogger = SummaryWriter(os.path.join(self.file_name, "tensorboard"))
182182
elif self.args.logger == "wandb":
183-
wandb_params = dict()
184-
for k, v in zip(self.args.opts[0::2], self.args.opts[1::2]):
185-
if k.startswith("wandb-"):
186-
wandb_params.update({k[len("wandb-"):]: v})
187-
self.wandb_logger = WandbLogger(config=vars(self.exp), **wandb_params)
183+
self.wandb_logger = WandbLogger.initialize_wandb_logger(
184+
self.args,
185+
self.exp,
186+
self.evaluator.dataloader.dataset
187+
)
188188
else:
189189
raise ValueError("logger must be either 'tensorboard' or 'wandb'")
190190

@@ -263,8 +263,11 @@ def after_iter(self):
263263

264264
if self.rank == 0:
265265
if self.args.logger == "wandb":
266-
self.wandb_logger.log_metrics({k: v.latest for k, v in loss_meter.items()})
267-
self.wandb_logger.log_metrics({"lr": self.meter["lr"].latest})
266+
metrics = {"train/" + k: v.latest for k, v in loss_meter.items()}
267+
metrics.update({
268+
"train/lr": self.meter["lr"].latest
269+
})
270+
self.wandb_logger.log_metrics(metrics, step=self.progress_in_iter)
268271

269272
self.meter.clear_meters()
270273

@@ -322,8 +325,8 @@ def evaluate_and_save_model(self):
322325
evalmodel = evalmodel.module
323326

324327
with adjust_status(evalmodel, training=False):
325-
ap50_95, ap50, summary = self.exp.eval(
326-
evalmodel, self.evaluator, self.is_distributed
328+
(ap50_95, ap50, summary), predictions = self.exp.eval(
329+
evalmodel, self.evaluator, self.is_distributed, return_outputs=True
327330
)
328331

329332
update_best_ckpt = ap50_95 > self.best_ap
@@ -337,16 +340,17 @@ def evaluate_and_save_model(self):
337340
self.wandb_logger.log_metrics({
338341
"val/COCOAP50": ap50,
339342
"val/COCOAP50_95": ap50_95,
340-
"epoch": self.epoch + 1,
343+
"train/epoch": self.epoch + 1,
341344
})
345+
self.wandb_logger.log_images(predictions)
342346
logger.info("\n" + summary)
343347
synchronize()
344348

345-
self.save_ckpt("last_epoch", update_best_ckpt)
349+
self.save_ckpt("last_epoch", update_best_ckpt, ap=ap50_95)
346350
if self.save_history_ckpt:
347-
self.save_ckpt(f"epoch_{self.epoch + 1}")
351+
self.save_ckpt(f"epoch_{self.epoch + 1}", ap=ap50_95)
348352

349-
def save_ckpt(self, ckpt_name, update_best_ckpt=False):
353+
def save_ckpt(self, ckpt_name, update_best_ckpt=False, ap=None):
350354
if self.rank == 0:
351355
save_model = self.ema_model.ema if self.use_model_ema else self.model
352356
logger.info("Save weights to {}".format(self.file_name))
@@ -355,6 +359,7 @@ def save_ckpt(self, ckpt_name, update_best_ckpt=False):
355359
"model": save_model.state_dict(),
356360
"optimizer": self.optimizer.state_dict(),
357361
"best_ap": self.best_ap,
362+
"curr_ap": ap,
358363
}
359364
save_checkpoint(
360365
ckpt_state,
@@ -364,4 +369,14 @@ def save_ckpt(self, ckpt_name, update_best_ckpt=False):
364369
)
365370

366371
if self.args.logger == "wandb":
367-
self.wandb_logger.save_checkpoint(self.file_name, ckpt_name, update_best_ckpt)
372+
self.wandb_logger.save_checkpoint(
373+
self.file_name,
374+
ckpt_name,
375+
update_best_ckpt,
376+
metadata={
377+
"epoch": self.epoch + 1,
378+
"optimizer": self.optimizer.state_dict(),
379+
"best_ap": self.best_ap,
380+
"curr_ap": ap
381+
}
382+
)

yolox/data/datasets/coco.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,8 +65,8 @@ def __init__(
6565
remove_useless_info(self.coco)
6666
self.ids = self.coco.getImgIds()
6767
self.class_ids = sorted(self.coco.getCatIds())
68-
cats = self.coco.loadCats(self.coco.getCatIds())
69-
self._classes = tuple([c["name"] for c in cats])
68+
self.cats = self.coco.loadCats(self.coco.getCatIds())
69+
self._classes = tuple([c["name"] for c in self.cats])
7070
self.imgs = None
7171
self.name = name
7272
self.img_size = img_size

yolox/evaluators/coco_evaluator.py

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import json
99
import tempfile
1010
import time
11+
from collections import ChainMap, defaultdict
1112
from loguru import logger
1213
from tabulate import tabulate
1314
from tqdm import tqdm
@@ -120,6 +121,7 @@ def evaluate(
120121
trt_file=None,
121122
decoder=None,
122123
test_size=None,
124+
return_outputs=False
123125
):
124126
"""
125127
COCO average precision (AP) Evaluation. Iterate inference on the test dataset
@@ -142,6 +144,7 @@ def evaluate(
142144
model = model.half()
143145
ids = []
144146
data_list = []
147+
output_data = defaultdict()
145148
progress_bar = tqdm if is_main_process() else iter
146149

147150
inference_time = 0
@@ -184,20 +187,29 @@ def evaluate(
184187
nms_end = time_synchronized()
185188
nms_time += nms_end - infer_end
186189

187-
data_list.extend(self.convert_to_coco_format(outputs, info_imgs, ids))
190+
data_list_elem, image_wise_data = self.convert_to_coco_format(
191+
outputs, info_imgs, ids, return_outputs=True)
192+
data_list.extend(data_list_elem)
193+
output_data.update(image_wise_data)
188194

189195
statistics = torch.cuda.FloatTensor([inference_time, nms_time, n_samples])
190196
if distributed:
191197
data_list = gather(data_list, dst=0)
198+
output_data = gather(output_data, dst=0)
192199
data_list = list(itertools.chain(*data_list))
200+
output_data = dict(ChainMap(*output_data))
193201
torch.distributed.reduce(statistics, dst=0)
194202

195203
eval_results = self.evaluate_prediction(data_list, statistics)
196204
synchronize()
205+
206+
if return_outputs:
207+
return eval_results, output_data
197208
return eval_results
198209

199-
def convert_to_coco_format(self, outputs, info_imgs, ids):
210+
def convert_to_coco_format(self, outputs, info_imgs, ids, return_outputs=False):
200211
data_list = []
212+
image_wise_data = defaultdict(dict)
201213
for (output, img_h, img_w, img_id) in zip(
202214
outputs, info_imgs[0], info_imgs[1], ids
203215
):
@@ -212,10 +224,22 @@ def convert_to_coco_format(self, outputs, info_imgs, ids):
212224
self.img_size[0] / float(img_h), self.img_size[1] / float(img_w)
213225
)
214226
bboxes /= scale
215-
bboxes = xyxy2xywh(bboxes)
216-
217227
cls = output[:, 6]
218228
scores = output[:, 4] * output[:, 5]
229+
230+
image_wise_data.update({
231+
int(img_id): {
232+
"bboxes": [box.numpy().tolist() for box in bboxes],
233+
"scores": [score.numpy().item() for score in scores],
234+
"categories": [
235+
self.dataloader.dataset.class_ids[int(cls[ind])]
236+
for ind in range(bboxes.shape[0])
237+
],
238+
}
239+
})
240+
241+
bboxes = xyxy2xywh(bboxes)
242+
219243
for ind in range(bboxes.shape[0]):
220244
label = self.dataloader.dataset.class_ids[int(cls[ind])]
221245
pred_data = {
@@ -226,6 +250,9 @@ def convert_to_coco_format(self, outputs, info_imgs, ids):
226250
"segmentation": [],
227251
} # COCO json format
228252
data_list.append(pred_data)
253+
254+
if return_outputs:
255+
return data_list, image_wise_data
229256
return data_list
230257

231258
def evaluate_prediction(self, data_dict, statistics):

yolox/exp/yolox_base.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -318,5 +318,5 @@ def get_trainer(self, args):
318318
# NOTE: trainer shouldn't be an attribute of exp object
319319
return trainer
320320

321-
def eval(self, model, evaluator, is_distributed, half=False):
322-
return evaluator.evaluate(model, is_distributed, half)
321+
def eval(self, model, evaluator, is_distributed, half=False, return_outputs=False):
322+
return evaluator.evaluate(model, is_distributed, half, return_outputs=return_outputs)

0 commit comments

Comments
 (0)