Skip to content

Commit e52ccb3

Browse files
shinkpytorchmergebot
authored andcommitted
[Device] Replace hardcoded devices with 'torch._C._get_accelerator()' (pytorch#139032)
I noticed that some hard-code like `"cuda" if torch.cuda.is_available() else "cpu"` which can be replaced with `torch._C._get_accelerator()` Pull Request resolved: pytorch#139032 Approved by: https://github.com/ezyang
1 parent a0865b0 commit e52ccb3

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

torch/distributed/_composable/fsdp/_fsdp_init.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,8 +58,8 @@ def _init_default_fully_shard_mesh() -> DeviceMesh:
5858
if not dist.distributed_c10d.is_initialized():
5959
dist.distributed_c10d.init_process_group()
6060
default_pg = dist.distributed_c10d._get_default_group()
61-
device_type = "cuda" if torch.cuda.is_available() else "cpu"
62-
mesh = init_device_mesh(device_type, mesh_shape=(default_pg.size(),))
61+
device = torch._C._get_accelerator()
62+
mesh = init_device_mesh(device.type, mesh_shape=(default_pg.size(),))
6363
return mesh
6464

6565

0 commit comments

Comments
 (0)