Skip to content

Commit c6aa90f

Browse files
authored
Fix spmd sharding visualization when device index is >= 10 (#9475)
1 parent b612613 commit c6aa90f

File tree

1 file changed

+8
-6
lines changed

1 file changed

+8
-6
lines changed

torch_xla/distributed/spmd/debugging.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -57,24 +57,26 @@ def visualize_sharding(sharding: str,
5757
# eg: '{devices=[2,2]0,1,2,3}'
5858
# eg: '{replicated}'
5959
# eg: '{devices=[2,1,2]0,1,2,3 last_tile_dim_replicate}'
60+
print(f"Visualizing {sharding} (showing up to the first two dimensions)")
6061
if sharding == '{replicated}' or len(sharding) == 0:
6162
heights = 1
6263
widths = 1
6364
num_devices = xr.global_runtime_device_count()
6465
device_ids = list(range(num_devices))
6566
slices.setdefault((0, 0), device_ids)
6667
else:
67-
sharding_spac = sharding[sharding.index('['):sharding.index(']') + 1]
68+
sharding_spec = sharding[sharding.index('[') +
69+
1:sharding.index(']')].split(",")
6870
device_list_original = sharding.split(' last_tile_dim_replicate')
6971
if len(device_list_original) == 2 and device_list_original[1] == '}':
7072
try:
7173
device_list_original_first = device_list_original[0]
7274
device_list = device_list_original_first[device_list_original_first.
7375
index(']') + 1:]
7476
device_indices_map = [int(s) for s in device_list.split(',')]
75-
heights = int(sharding_spac[1])
76-
widths = int(sharding_spac[3])
77-
last_dim_depth = int(sharding_spac[5])
77+
heights = int(sharding_spec[0])
78+
widths = int(sharding_spec[1])
79+
last_dim_depth = int(sharding_spec[-1])
7880
devices_len = len(device_indices_map)
7981
len_after_dim_down = devices_len // last_dim_depth
8082
for i in range(len_after_dim_down):
@@ -96,8 +98,8 @@ def visualize_sharding(sharding: str,
9698
device_list = device_list_original_first[device_list_original_first.
9799
index(']') + 1:-1]
98100
device_indices_map = [int(i) for i in device_list.split(',')]
99-
heights = int(sharding_spac[1])
100-
widths = int(sharding_spac[3])
101+
heights = int(sharding_spec[0])
102+
widths = int(sharding_spec[1])
101103
devices_len = len(device_indices_map)
102104
for i in range(devices_len):
103105
slices.setdefault((i // widths, i % widths), device_indices_map[i])

0 commit comments

Comments
 (0)