
Commit 3209ebb

Use opt default eps (arg defaults to None), torch.compile after DDP

1 parent: d562154


train.py (10 additions, 6 deletions)
@@ -111,8 +111,8 @@
 # Optimizer parameters
 parser.add_argument('--opt', default='momentum', type=str, metavar='OPTIMIZER',
                     help='Optimizer (default: "momentum"')
-parser.add_argument('--opt-eps', default=1e-3, type=float, metavar='EPSILON',
-                    help='Optimizer Epsilon (default: 1e-3)')
+parser.add_argument('--opt-eps', default=None, type=float, metavar='EPSILON',
+                    help='Optimizer Epsilon (default: None, optimizer default)')
 parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
                     help='SGD momentum (default: 0.9)')
 parser.add_argument('--weight-decay', type=float, default=4e-5,
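
With --opt-eps defaulting to None, the eps override is only forwarded to the optimizer when the user actually sets it; otherwise each optimizer keeps its own library default (e.g. 1e-8 for Adam) instead of the old hard-coded 1e-3. A minimal sketch of that pattern, using a hypothetical build_optimizer factory (the repo's real create_optimizer is not shown in this diff):

import torch

def build_optimizer(parameters, opt='momentum', lr=0.01, opt_eps=None,
                    momentum=0.9, weight_decay=4e-5):
    # Hypothetical factory: forward eps only when --opt-eps was given,
    # so each optimizer otherwise falls back to its own default.
    kwargs = dict(lr=lr, weight_decay=weight_decay)
    if opt_eps is not None:
        kwargs['eps'] = opt_eps
    if opt == 'momentum':
        kwargs.pop('eps', None)  # SGD takes no eps argument
        return torch.optim.SGD(parameters, momentum=momentum, **kwargs)
    if opt == 'adam':
        return torch.optim.Adam(parameters, **kwargs)  # eps defaults to 1e-8
    raise ValueError(f'Unknown optimizer: {opt}')
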
@@ -330,14 +330,12 @@ def main():
             'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.')

     if args.torchscript:
+        assert not args.torchcompile, 'Cannot use torch.compile() with torch.jit.script()'
         assert not use_amp == 'apex', \
             'Cannot use APEX AMP with torchscripted model, force native amp with `--native-amp` flag'
         assert not args.sync_bn, \
             'Cannot use SyncBatchNorm with torchscripted model. Use `--dist-bn reduce` instead of `--sync-bn`'
         model = torch.jit.script(model)
-    elif args.torchcompile:
-        # FIXME dynamo might need move below DDP wrapping? TBD
-        model = torch.compile(model, backend=args.torchcompile)

     optimizer = create_optimizer(args, model)

@@ -390,6 +388,9 @@ def main():
         # NOTE: ModelEma init could be moved after DDP wrapper if using PyTorch DDP, not Apex.
         model_ema.set(model)

+    if args.torchcompile:
+        model = torch.compile(model, backend=args.torchcompile)
+
     lr_scheduler, num_epochs = create_scheduler(args, optimizer)
     start_epoch = 0
     if args.start_epoch is not None:
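
Compiling after the DDP/EMA setup (rather than right after model creation, as the removed elif did) means torch.compile() traces the module in the form it will actually run each step. A rough sketch of the resulting order, with a hypothetical wrap_model helper standing in for the surrounding setup code in main():

from typing import Optional

import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

def wrap_model(model: nn.Module, distributed: bool = False,
               compile_backend: Optional[str] = None, device_id: int = 0) -> nn.Module:
    # Hypothetical helper illustrating the ordering this commit settles on:
    # wrap in DDP first, apply torch.compile() last so dynamo sees the
    # module that will actually execute during training.
    if distributed:
        model = DDP(model, device_ids=[device_id])
    if compile_backend:
        model = torch.compile(model, backend=compile_backend)
    return model
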
@@ -521,7 +522,10 @@ def create_datasets_and_loaders(
     labeler = None
     if not args.bench_labeler:
         labeler = AnchorLabeler(
-            Anchors.from_config(model_config), model_config.num_classes, match_threshold=0.5)
+            Anchors.from_config(model_config),
+            model_config.num_classes,
+            match_threshold=0.5,
+        )

     loader_train = create_loader(
         dataset_train,
