Commit e9c6625

Add Accelerator API to DDP tutorial to support multiple accelerators and improve backend initialization
Signed-off-by: jafraustro <[email protected]>
1 parent ef98a6b commit e9c6625

File tree

1 file changed: +15 -6 lines changed

intermediate_source/ddp_tutorial.rst

Lines changed: 15 additions & 6 deletions
@@ -102,8 +102,12 @@ be found in
     os.environ['MASTER_ADDR'] = 'localhost'
     os.environ['MASTER_PORT'] = '12355'

+    # We want to be able to train our model on an `accelerator <https://pytorch.org/docs/stable/torch.html#accelerators>`__
+    # such as CUDA, MPS, MTIA, or XPU.
+    acc = torch.accelerator.current_accelerator()
+    backend = torch.distributed.get_default_backend_for_device(acc)
     # initialize the process group
-    dist.init_process_group("gloo", rank=rank, world_size=world_size)
+    dist.init_process_group(backend, rank=rank, world_size=world_size)

 def cleanup():
     dist.destroy_process_group()
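For context, a minimal sketch of how the patched setup()/cleanup() pair reads when pulled out into a standalone script; the imports are assumed from the surrounding tutorial code, and an accelerator is assumed to be present:

    import os

    import torch
    import torch.distributed as dist


    def setup(rank, world_size):
        os.environ['MASTER_ADDR'] = 'localhost'
        os.environ['MASTER_PORT'] = '12355'

        # Query the active accelerator (CUDA, MPS, MTIA, XPU, ...) and pick its
        # matching default communication backend (e.g. nccl for CUDA) instead of
        # hard-coding "gloo".
        acc = torch.accelerator.current_accelerator()
        backend = torch.distributed.get_default_backend_for_device(acc)

        # initialize the process group
        dist.init_process_group(backend, rank=rank, world_size=world_size)


    def cleanup():
        dist.destroy_process_group()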
@@ -216,8 +220,11 @@ and elasticity support, please refer to `TorchElastic <https://pytorch.org/elast
     # Use a barrier() to make sure that process 1 loads the model after process
     # 0 saves it.
     dist.barrier()
+    # We want to be able to train our model on an `accelerator <https://pytorch.org/docs/stable/torch.html#accelerators>`__
+    # such as CUDA, MPS, MTIA, or XPU.
+    acc = torch.accelerator.current_accelerator()
     # configure map_location properly
-    map_location = {'cuda:%d' % 0: 'cuda:%d' % rank}
+    map_location = {f'{acc}:0': f'{acc}:{rank}'}
     ddp_model.load_state_dict(
         torch.load(CHECKPOINT_PATH, map_location=map_location, weights_only=True))
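The same accelerator query drives the checkpoint remapping above. A hedged sketch of the load step, where load_checkpoint and checkpoint_path are hypothetical names and ddp_model/dist are assumed from the tutorial:

    def load_checkpoint(ddp_model, rank, checkpoint_path):
        # Wait until rank 0 has finished writing the checkpoint.
        dist.barrier()

        # Remap tensors saved on device 0 onto this rank's device of the same
        # accelerator type, e.g. {'cuda:0': 'cuda:3'} on CUDA or {'xpu:0': 'xpu:3'} on XPU.
        acc = torch.accelerator.current_accelerator()
        map_location = {f'{acc}:0': f'{acc}:{rank}'}
        ddp_model.load_state_dict(
            torch.load(checkpoint_path, map_location=map_location, weights_only=True))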
@@ -295,7 +302,7 @@ either the application or the model ``forward()`` method.


 if __name__ == "__main__":
-    n_gpus = torch.cuda.device_count()
+    n_gpus = torch.accelerator.device_count()
     assert n_gpus >= 2, f"Requires at least 2 GPUs to run, but got {n_gpus}"
     world_size = n_gpus
     run_demo(demo_basic, world_size)
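Assembled, the driver block looks roughly like the sketch below; run_demo is assumed to be the tutorial's mp.spawn wrapper and demo_basic the worker function defined earlier:

    import torch.multiprocessing as mp


    def run_demo(demo_fn, world_size):
        # Launch one worker process per visible accelerator device; mp.spawn passes
        # the process rank as the first argument of demo_fn.
        mp.spawn(demo_fn, args=(world_size,), nprocs=world_size, join=True)


    if __name__ == "__main__":
        # Counts devices of whichever accelerator is active, not only CUDA GPUs.
        n_gpus = torch.accelerator.device_count()
        assert n_gpus >= 2, f"Requires at least 2 GPUs to run, but got {n_gpus}"
        world_size = n_gpus
        run_demo(demo_basic, world_size)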
@@ -331,12 +338,14 @@ Let's still use the Toymodel example and create a file named ``elastic_ddp.py``.


 def demo_basic():
-    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
-    dist.init_process_group("nccl")
+    torch.accelerator.set_device_index(int(os.environ["LOCAL_RANK"]))
+    acc = torch.accelerator.current_accelerator()
+    backend = torch.distributed.get_default_backend_for_device(acc)
+    dist.init_process_group(backend)
     rank = dist.get_rank()
     print(f"Start running basic DDP example on rank {rank}.")
     # create model and move it to GPU with id rank
-    device_id = rank % torch.cuda.device_count()
+    device_id = rank % torch.accelerator.device_count()
     model = ToyModel().to(device_id)
     ddp_model = DDP(model, device_ids=[device_id])
     loss_fn = nn.MSELoss()
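Put together, the elastic entry point reads roughly as below; ToyModel and the rest of the training loop are assumed from the tutorial, and the script is expected to be launched with torchrun, which sets LOCAL_RANK (the launch flags shown are only illustrative):

    import os

    import torch
    import torch.distributed as dist
    import torch.nn as nn
    from torch.nn.parallel import DistributedDataParallel as DDP


    def demo_basic():
        # torchrun sets LOCAL_RANK; bind this process to its local accelerator device.
        torch.accelerator.set_device_index(int(os.environ["LOCAL_RANK"]))

        # Initialize the process group with the backend that matches the device.
        acc = torch.accelerator.current_accelerator()
        backend = torch.distributed.get_default_backend_for_device(acc)
        dist.init_process_group(backend)

        rank = dist.get_rank()
        print(f"Start running basic DDP example on rank {rank}.")

        # create model and move it to the accelerator device with id rank
        device_id = rank % torch.accelerator.device_count()
        model = ToyModel().to(device_id)
        ddp_model = DDP(model, device_ids=[device_id])
        loss_fn = nn.MSELoss()

    # Illustrative single-node launch:
    #   torchrun --nnodes=1 --nproc_per_node=2 elastic_ddp.py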
