diff --git a/intermediate_source/ddp_tutorial.rst b/intermediate_source/ddp_tutorial.rst
index 1f7221680b1..c63321ad14c 100644
--- a/intermediate_source/ddp_tutorial.rst
+++ b/intermediate_source/ddp_tutorial.rst
@@ -102,8 +102,12 @@ be found in
         os.environ['MASTER_ADDR'] = 'localhost'
         os.environ['MASTER_PORT'] = '12355'
 
+        # We want to be able to train our model on an `accelerator <https://pytorch.org/docs/stable/torch.html#accelerators>`__
+        # such as CUDA, MPS, MTIA, or XPU.
+        acc = torch.accelerator.current_accelerator()
+        backend = torch.distributed.get_default_backend_for_device(acc)
         # initialize the process group
-        dist.init_process_group("gloo", rank=rank, world_size=world_size)
+        dist.init_process_group(backend, rank=rank, world_size=world_size)
 
     def cleanup():
         dist.destroy_process_group()
@@ -216,8 +220,11 @@ and elasticity support, please refer to `TorchElastic <https://pytorch.org/elastic>`__.
+        # We want to be able to train our model on an `accelerator <https://pytorch.org/docs/stable/torch.html#accelerators>`__
+        # such as CUDA, MPS, MTIA, or XPU.
+        acc = torch.accelerator.current_accelerator()
         # configure map_location properly
-        map_location = {'cuda:%d' % 0: 'cuda:%d' % rank}
+        map_location = {f'{acc}:0': f'{acc}:{rank}'}
         ddp_model.load_state_dict(
             torch.load(CHECKPOINT_PATH, map_location=map_location, weights_only=True))
@@ -295,7 +302,7 @@ either the application or the model ``forward()`` method.
     if __name__ == "__main__":
-        n_gpus = torch.cuda.device_count()
+        n_gpus = torch.accelerator.device_count()
         assert n_gpus >= 2, f"Requires at least 2 GPUs to run, but got {n_gpus}"
         world_size = n_gpus
         run_demo(demo_basic, world_size)
@@ -331,12 +338,14 @@ Let's still use the Toymodel example and create a file named ``elastic_ddp.py``.
     def demo_basic():
-        torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
-        dist.init_process_group("nccl")
+        torch.accelerator.set_device_index(int(os.environ["LOCAL_RANK"]))
+        acc = torch.accelerator.current_accelerator()
+        backend = torch.distributed.get_default_backend_for_device(acc)
+        dist.init_process_group(backend)
         rank = dist.get_rank()
         print(f"Start running basic DDP example on rank {rank}.")
         # create model and move it to GPU with id rank
-        device_id = rank % torch.cuda.device_count()
+        device_id = rank % torch.accelerator.device_count()
         model = ToyModel().to(device_id)
         ddp_model = DDP(model, device_ids=[device_id])
         loss_fn = nn.MSELoss()
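
For reference, a minimal standalone sketch of the device-agnostic setup pattern these hunks introduce, written against the same torch.accelerator and torch.distributed.get_default_backend_for_device APIs the diff uses (available in recent PyTorch releases). The CPU/gloo fallback below is an illustrative addition, not part of the tutorial diff:

    # Sketch, not the tutorial's code: assumes a PyTorch build that provides
    # torch.accelerator and torch.distributed.get_default_backend_for_device.
    import os

    import torch
    import torch.distributed as dist


    def setup(rank, world_size):
        os.environ['MASTER_ADDR'] = 'localhost'
        os.environ['MASTER_PORT'] = '12355'

        # current_accelerator() returns the active accelerator device
        # (e.g. cuda, mps, mtia, xpu), or None on an accelerator-free build;
        # falling back to CPU here is an assumption for illustration.
        acc = torch.accelerator.current_accelerator()
        device = acc if acc is not None else torch.device('cpu')

        # Map the device type to its default communication backend,
        # e.g. cuda -> nccl, cpu -> gloo.
        backend = dist.get_default_backend_for_device(device)
        dist.init_process_group(backend, rank=rank, world_size=world_size)


    def cleanup():
        dist.destroy_process_group()

Launching is unchanged from the tutorial: mp.spawn(demo_fn, args=(world_size,), nprocs=world_size, join=True) for the spawn variant, or torchrun for elastic_ddp.py.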