Commit d43c9e3

Switch to torchrun for distributed launch, use timm init_distributed_device helper for cluster training support
1 parent 94136af commit d43c9e3

2 files changed: +2 -9 lines changed

2 files changed

+2
-9
lines changed

distributed_train.sh

Lines changed: 1 addition & 1 deletion
@@ -1,5 +1,5 @@
 #!/bin/bash
 NUM_PROC=$1
 shift
-python -m torch.distributed.launch --nproc_per_node=$NUM_PROC train.py "$@"
+torchrun --nproc-per-node=$NUM_PROC train.py "$@"
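
Note: unlike torch.distributed.launch, torchrun does not pass a --local_rank argument to the script; it exports LOCAL_RANK, RANK, and WORLD_SIZE as environment variables for each worker, which is why the --local_rank argument is removed from train.py below. A minimal sketch (not this repo's code) of picking the process info up from the environment:

    import os
    import torch

    # torchrun exports these for every worker process it spawns
    local_rank = int(os.environ.get('LOCAL_RANK', 0))
    rank = int(os.environ.get('RANK', 0))
    world_size = int(os.environ.get('WORLD_SIZE', 1))

    if world_size > 1:
        # bind this process to its GPU and join the process group
        torch.cuda.set_device(local_rank)
        torch.distributed.init_process_group(backend='nccl', init_method='env://')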

train.py

Lines changed: 1 addition & 8 deletions
@@ -216,7 +216,6 @@
                     help='Best metric (default: "map"')
 parser.add_argument('--tta', type=int, default=0, metavar='N',
                     help='Test/inference time augmentation (oversampling) factor. 0=None (default: 0)')
-parser.add_argument("--local_rank", default=0, type=int)
 
 
 def _parse_args():
@@ -256,14 +255,8 @@ def main():
     args.device = 'cuda:0'
     args.world_size = 1
     args.rank = 0  # global rank
-    if args.distributed:
-        args.device = 'cuda:%d' % args.local_rank
-        torch.cuda.set_device(args.local_rank)
-        torch.distributed.init_process_group(backend='nccl', init_method='env://')
-        args.world_size = torch.distributed.get_world_size()
-        args.rank = torch.distributed.get_rank()
+    device = utils.init_distributed_device(args)
     assert args.rank >= 0
-
     if args.distributed:
         logging.info('Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
                      % (args.rank, args.world_size))
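
For reference, a rough sketch of what a helper like timm's utils.init_distributed_device(args) is doing for this script: read the launcher environment, initialize the process group, and populate args.device, args.world_size, args.rank, and args.distributed, returning a torch.device. This is an approximation for illustration (the function name below is mine), not timm's actual implementation:

    import os
    import torch

    def init_distributed_device_sketch(args):
        # values exported by torchrun (or a cluster wrapper exporting the same names)
        args.world_size = int(os.environ.get('WORLD_SIZE', 1))
        args.rank = int(os.environ.get('RANK', 0))
        args.local_rank = int(os.environ.get('LOCAL_RANK', 0))
        args.distributed = args.world_size > 1
        if args.distributed:
            torch.cuda.set_device(args.local_rank)
            torch.distributed.init_process_group(backend='nccl', init_method='env://')
            args.device = 'cuda:%d' % args.local_rank
        else:
            args.device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
        return torch.device(args.device)

Centralizing this in one helper keeps the rest of train.py launcher-agnostic, so the same script can run under torchrun, a cluster scheduler, or a plain single-GPU invocation, which is the "cluster training support" the commit message refers to.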
