@@ -57,24 +57,26 @@ def visualize_sharding(sharding: str,
57
57
# eg: '{devices=[2,2]0,1,2,3}'
58
58
# eg: '{replicated}'
59
59
# eg: '{devices=[2,1,2]0,1,2,3 last_tile_dim_replicate}'
60
+ print (f"Visualizing { sharding } (showing up to the first two dimensions)" )
60
61
if sharding == '{replicated}' or len (sharding ) == 0 :
61
62
heights = 1
62
63
widths = 1
63
64
num_devices = xr .global_runtime_device_count ()
64
65
device_ids = list (range (num_devices ))
65
66
slices .setdefault ((0 , 0 ), device_ids )
66
67
else :
67
- sharding_spac = sharding [sharding .index ('[' ):sharding .index (']' ) + 1 ]
68
+ sharding_spec = sharding [sharding .index ('[' ) +
69
+ 1 :sharding .index (']' )].split ("," )
68
70
device_list_original = sharding .split (' last_tile_dim_replicate' )
69
71
if len (device_list_original ) == 2 and device_list_original [1 ] == '}' :
70
72
try :
71
73
device_list_original_first = device_list_original [0 ]
72
74
device_list = device_list_original_first [device_list_original_first .
73
75
index (']' ) + 1 :]
74
76
device_indices_map = [int (s ) for s in device_list .split (',' )]
75
- heights = int (sharding_spac [ 1 ])
76
- widths = int (sharding_spac [ 3 ])
77
- last_dim_depth = int (sharding_spac [ 5 ])
77
+ heights = int (sharding_spec [ 0 ])
78
+ widths = int (sharding_spec [ 1 ])
79
+ last_dim_depth = int (sharding_spec [ - 1 ])
78
80
devices_len = len (device_indices_map )
79
81
len_after_dim_down = devices_len // last_dim_depth
80
82
for i in range (len_after_dim_down ):
@@ -96,8 +98,8 @@ def visualize_sharding(sharding: str,
96
98
device_list = device_list_original_first [device_list_original_first .
97
99
index (']' ) + 1 :- 1 ]
98
100
device_indices_map = [int (i ) for i in device_list .split (',' )]
99
- heights = int (sharding_spac [ 1 ])
100
- widths = int (sharding_spac [ 3 ])
101
+ heights = int (sharding_spec [ 0 ])
102
+ widths = int (sharding_spec [ 1 ])
101
103
devices_len = len (device_indices_map )
102
104
for i in range (devices_len ):
103
105
slices .setdefault ((i // widths , i % widths ), device_indices_map [i ])
0 commit comments