@@ -111,8 +111,8 @@
 # Optimizer parameters
 parser.add_argument('--opt', default='momentum', type=str, metavar='OPTIMIZER',
                     help='Optimizer (default: "momentum"')
-parser.add_argument('--opt-eps', default=1e-3, type=float, metavar='EPSILON',
-                    help='Optimizer Epsilon (default: 1e-3)')
+parser.add_argument('--opt-eps', default=None, type=float, metavar='EPSILON',
+                    help='Optimizer Epsilon (default: None, optimizer default)')
 parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
                     help='SGD momentum (default: 0.9)')
 parser.add_argument('--weight-decay', type=float, default=4e-5,
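With `--opt-eps` defaulting to `None`, the flag no longer overrides an optimizer's own epsilon unless the user sets it explicitly. A minimal sketch of how a `None` epsilon can be resolved when the optimizer is built; `build_optimizer` below is a hypothetical stand-in for illustration, not the repo's actual `create_optimizer`:

```python
import torch


def build_optimizer(parameters, opt='momentum', lr=0.01, momentum=0.9,
                    weight_decay=4e-5, opt_eps=None):
    # Forward eps only when --opt-eps was given; otherwise each optimizer
    # keeps its own built-in default (e.g. 1e-8 for Adam).
    extra = {} if opt_eps is None else {'eps': opt_eps}
    if opt == 'momentum':
        return torch.optim.SGD(parameters, lr=lr, momentum=momentum, weight_decay=weight_decay)
    if opt == 'adam':
        return torch.optim.Adam(parameters, lr=lr, weight_decay=weight_decay, **extra)
    raise ValueError(f'unsupported optimizer: {opt}')
```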
@@ -330,14 +330,12 @@ def main(): |
                 'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.')

     if args.torchscript:
+        assert not args.torchcompile, 'Cannot use torch.compile() with torch.jit.script()'
         assert not use_amp == 'apex', \
             'Cannot use APEX AMP with torchscripted model, force native amp with `--native-amp` flag'
         assert not args.sync_bn, \
             'Cannot use SyncBatchNorm with torchscripted model. Use `--dist-bn reduce` instead of `--sync-bn`'
         model = torch.jit.script(model)
-    elif args.torchcompile:
-        # FIXME dynamo might need move below DDP wrapping? TBD
-        model = torch.compile(model, backend=args.torchcompile)

     optimizer = create_optimizer(args, model)

@@ -390,6 +388,9 @@ def main(): |
         # NOTE: ModelEma init could be moved after DDP wrapper if using PyTorch DDP, not Apex.
         model_ema.set(model)

+    if args.torchcompile:
+        model = torch.compile(model, backend=args.torchcompile)
+
     lr_scheduler, num_epochs = create_scheduler(args, optimizer)
     start_epoch = 0
     if args.start_epoch is not None:
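Together, the two hunks above make `--torchscript` and `--torchcompile` mutually exclusive and delay `torch.compile()` until after the EMA weights have been initialized from the eager model. A rough, self-contained sketch of the resulting order; the model, flag names, and the `AveragedModel` stand-in for the script's `ModelEma` are assumptions for illustration only:

```python
import torch
import torch.nn as nn

# Stand-in model and flags; names mirror the script's args but are assumed here.
model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 4))
use_torchscript, torchcompile_backend = False, 'inductor'

if use_torchscript:
    # Scripting and torch.compile() are treated as mutually exclusive.
    assert not torchcompile_backend, 'Cannot use torch.compile() with torch.jit.script()'
    model = torch.jit.script(model)

# EMA copy is created from the eager module first ...
ema = torch.optim.swa_utils.AveragedModel(model)  # stand-in for the script's ModelEma

# ... and only then is the training model compiled, matching the new ordering.
if torchcompile_backend:
    model = torch.compile(model, backend=torchcompile_backend)
```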
@@ -521,7 +522,10 @@ def create_datasets_and_loaders( |
     labeler = None
     if not args.bench_labeler:
         labeler = AnchorLabeler(
-            Anchors.from_config(model_config), model_config.num_classes, match_threshold=0.5)
+            Anchors.from_config(model_config),
+            model_config.num_classes,
+            match_threshold=0.5,
+        )

     loader_train = create_loader(
         dataset_train,