Skip to content

Commit f50d359

Browse files
ankurneogpytorchmergebot
authored andcommitted
[ c10d ] modify API to get device string from device with torch.device (pytorch#146290)
Modify the ```get_default_backend_for_device()``` API to extract the device string using ```torch.device()``` Pull Request resolved: pytorch#146290 Approved by: https://github.com/guangyey, https://github.com/H-Huang
1 parent 3a29992 commit f50d359

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

torch/distributed/distributed_c10d.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1382,7 +1382,7 @@ def get_default_backend_for_device(device: Union[str, torch.device]) -> str:
13821382
if isinstance(device, torch.device):
13831383
device_str = device.type
13841384
else:
1385-
device_str = device.split(":")[0]
1385+
device_str = torch.device(device).type
13861386

13871387
backend = Backend.default_device_backend_map.get(device_str)
13881388
if backend is None:

0 commit comments

Comments
 (0)