@@ -102,8 +102,12 @@ be found in
     os.environ['MASTER_ADDR'] = 'localhost'
     os.environ['MASTER_PORT'] = '12355'
 
+    # We want to be able to train our model on an `accelerator <https://pytorch.org/docs/stable/torch.html#accelerators>`__
+    # such as CUDA, MPS, MTIA, or XPU.
+    acc = torch.accelerator.current_accelerator()
+    backend = torch.distributed.get_default_backend_for_device(acc)
     # initialize the process group
-    dist.init_process_group("gloo", rank=rank, world_size=world_size)
+    dist.init_process_group(backend, rank=rank, world_size=world_size)
 
 def cleanup():
     dist.destroy_process_group()
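Taken together, the updated ``setup``/``cleanup`` pair reads roughly as the sketch below (assuming PyTorch 2.6+, where ``torch.accelerator`` and ``torch.distributed.get_default_backend_for_device`` are available; the ``"gloo"`` fallback for CPU-only machines is an added assumption, not something this hunk includes):

    import os

    import torch
    import torch.distributed as dist

    def setup(rank, world_size):
        os.environ['MASTER_ADDR'] = 'localhost'
        os.environ['MASTER_PORT'] = '12355'

        # Ask PyTorch which accelerator (CUDA, MPS, MTIA, XPU, ...) is present
        # and which process-group backend it maps to (e.g. CUDA -> nccl).
        acc = torch.accelerator.current_accelerator()
        # Assumption: fall back to gloo when no accelerator is available.
        backend = (torch.distributed.get_default_backend_for_device(acc)
                   if acc is not None else "gloo")
        dist.init_process_group(backend, rank=rank, world_size=world_size)

    def cleanup():
        dist.destroy_process_group()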
@@ -216,8 +220,11 @@ and elasticity support, please refer to `TorchElastic <https://pytorch.org/elast
     # Use a barrier() to make sure that process 1 loads the model after process
     # 0 saves it.
     dist.barrier()
+    # We want to be able to train our model on an `accelerator <https://pytorch.org/docs/stable/torch.html#accelerators>`__
+    # such as CUDA, MPS, MTIA, or XPU.
+    acc = torch.accelerator.current_accelerator()
     # configure map_location properly
-    map_location = {'cuda:%d' % 0: 'cuda:%d' % rank}
+    map_location = {f'{acc}:0': f'{acc}:{rank}'}
     ddp_model.load_state_dict(
         torch.load(CHECKPOINT_PATH, map_location=map_location, weights_only=True))
 
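The ``map_location`` remap matters because rank 0 saved its tensors on device 0, and each loading process must redirect them onto its own device; since ``acc`` is a ``torch.device``, ``f'{acc}:0'`` renders as, e.g., ``'cuda:0'``. A minimal sketch of the semantics (the checkpoint path and the rank value are illustrative assumptions):

    import torch

    acc = torch.accelerator.current_accelerator()  # e.g. device(type='cuda')
    rank = 1                                       # illustrative rank
    # Tensors stored under "<acc>:0" in the checkpoint load onto "<acc>:<rank>".
    map_location = {f'{acc}:0': f'{acc}:{rank}'}   # e.g. {'cuda:0': 'cuda:1'}
    # "model.checkpoint" is a hypothetical path for this sketch.
    state = torch.load("model.checkpoint",
                       map_location=map_location, weights_only=True)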
@@ -295,7 +302,7 @@ either the application or the model ``forward()`` method.
 
 
 if __name__ == "__main__":
-    n_gpus = torch.cuda.device_count()
+    n_gpus = torch.accelerator.device_count()
     assert n_gpus >= 2, f"Requires at least 2 GPUs to run, but got {n_gpus}"
     world_size = n_gpus
     run_demo(demo_basic, world_size)
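``run_demo`` is defined earlier in the tutorial; a typical implementation spawns one process per device with ``torch.multiprocessing``. A sketch, assuming the ``demo_fn(rank, world_size)`` signature used in this file:

    import torch.multiprocessing as mp

    def run_demo(demo_fn, world_size):
        # Each spawned process receives its rank as the first argument,
        # followed by the contents of args.
        mp.spawn(demo_fn, args=(world_size,), nprocs=world_size, join=True)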
@@ -331,12 +338,14 @@ Let's still use the Toymodel example and create a file named ``elastic_ddp.py``.
 
 
 def demo_basic():
-    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
-    dist.init_process_group("nccl")
+    torch.accelerator.set_device_index(int(os.environ["LOCAL_RANK"]))
+    acc = torch.accelerator.current_accelerator()
+    backend = torch.distributed.get_default_backend_for_device(acc)
+    dist.init_process_group(backend)
     rank = dist.get_rank()
     print(f"Start running basic DDP example on rank {rank}.")
     # create model and move it to GPU with id rank
-    device_id = rank % torch.cuda.device_count()
+    device_id = rank % torch.accelerator.device_count()
     model = ToyModel().to(device_id)
     ddp_model = DDP(model, device_ids=[device_id])
     loss_fn = nn.MSELoss()
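``elastic_ddp.py`` is meant to be launched with ``torchrun`` (for example, ``torchrun --nnodes=1 --nproc_per_node=2 elastic_ddp.py``, where the worker count is illustrative), which sets the ``LOCAL_RANK`` variable read above. A small probe script can show what the accelerator calls resolve to on a given machine (a sketch; the values in the comments are examples, not guaranteed output):

    import torch
    import torch.distributed as dist

    acc = torch.accelerator.current_accelerator()
    print(f"accelerator:  {acc}")                             # e.g. cuda
    print(f"device count: {torch.accelerator.device_count()}")  # e.g. 2
    if acc is not None:
        backend = torch.distributed.get_default_backend_for_device(acc)
        print(f"backend:      {backend}")                     # e.g. nccl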