Skip to content

Commit bb9185c

Browse files
authored
fix(utils,exp): logger compat issue and exp check (Megvii-BaseDetection#1618)
fix(utils,exp): logger compat issue and exp check (Megvii-BaseDetection#1618)
1 parent 618fd8c commit bb9185c

File tree

5 files changed

+19
-10
lines changed

5 files changed

+19
-10
lines changed

tools/train.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import torch.backends.cudnn as cudnn
1212

1313
from yolox.core import launch
14-
from yolox.exp import Exp, get_exp
14+
from yolox.exp import Exp, check_exp_value, get_exp
1515
from yolox.utils import configure_module, configure_nccl, configure_omp, get_num_devices
1616

1717

@@ -123,6 +123,7 @@ def main(exp: Exp, args):
123123
args = make_parser().parse_args()
124124
exp = get_exp(args.exp_file, args.name)
125125
exp.merge(args.opts)
126+
check_exp_value(exp)
126127

127128
if not args.experiment_name:
128129
args.experiment_name = exp.exp_name

yolox/exp/__init__.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
#!/usr/bin/env python3
2-
# -*- coding:utf-8 -*-
32
# Copyright (c) Megvii Inc. All rights reserved.
43

54
from .base_exp import BaseExp
65
from .build import get_exp
7-
from .yolox_base import Exp
6+
from .yolox_base import Exp, check_exp_value

yolox/exp/base_exp.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
#!/usr/bin/env python3
2-
# -*- coding:utf-8 -*-
32
# Copyright (c) Megvii Inc. All rights reserved.
43

54
import ast
@@ -66,15 +65,15 @@ def __repr__(self):
6665
return tabulate(exp_table, headers=table_header, tablefmt="fancy_grid")
6766

6867
def merge(self, cfg_list):
69-
assert len(cfg_list) % 2 == 0
68+
assert len(cfg_list) % 2 == 0, f"length must be even, check value here: {cfg_list}"
7069
for k, v in zip(cfg_list[0::2], cfg_list[1::2]):
7170
# only update value with same key
7271
if hasattr(self, k):
7372
src_value = getattr(self, k)
7473
src_type = type(src_value)
7574

7675
# pre-process input if source type is list or tuple
77-
if isinstance(src_value, List) or isinstance(src_value, Tuple):
76+
if isinstance(src_value, (List, Tuple)):
7877
v = v.strip("[]()")
7978
v = [t.strip() for t in v.split(",")]
8079

yolox/exp/yolox_base.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
#!/usr/bin/env python3
2-
# -*- coding:utf-8 -*-
32
# Copyright (c) Megvii Inc. All rights reserved.
43

54
import os
@@ -11,6 +10,8 @@
1110

1211
from .base_exp import BaseExp
1312

13+
__all__ = ["Exp", "check_exp_value"]
14+
1415

1516
class Exp(BaseExp):
1617
def __init__(self):
@@ -350,3 +351,8 @@ def get_trainer(self, args):
350351

351352
def eval(self, model, evaluator, is_distributed, half=False, return_outputs=False):
352353
return evaluator.evaluate(model, is_distributed, half, return_outputs=return_outputs)
354+
355+
356+
def check_exp_value(exp: Exp):
357+
h, w = exp.input_size
358+
assert h % 32 == 0 and w % 32 == 0, "input size must be multiples of 32"

yolox/utils/logger.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
#!/usr/bin/env python3
2-
# -*- coding:utf-8 -*-
32
# Copyright (c) Megvii Inc. All rights reserved.
43

54
import inspect
@@ -58,15 +57,20 @@ def write(self, buf):
5857
sys.__stdout__.write(buf)
5958

6059
def flush(self):
61-
pass
60+
# flush is related with CPR(cursor position report) in terminal
61+
return sys.__stdout__.flush()
6262

6363
def isatty(self):
6464
# when using colab, jax is installed by default and issue like
6565
# https://github.com/Megvii-BaseDetection/YOLOX/issues/1437 might be raised
6666
# due to missing attribute like`isatty`.
6767
# For more details, checked the following link:
6868
# https://github.com/google/jax/blob/10720258ea7fb5bde997dfa2f3f71135ab7a6733/jax/_src/pretty_printer.py#L54 # noqa
69-
return True
69+
return sys.__stdout__.isatty()
70+
71+
def fileno(self):
72+
# To solve the issue when using debug tools like pdb
73+
return sys.__stdout__.fileno()
7074

7175

7276
def redirect_sys_output(log_level="INFO"):

0 commit comments

Comments
 (0)