@@ -167,8 +167,14 @@ def shards_all_to_all(
                 extend_shard_name(shard_name)
             ][tmp_momentum_extender(shard_name)].local_shards()
             assert len(local_optimizer) == 1
+            local_optimizer_tensor = local_optimizer[0].tensor
+            if len(local_optimizer_tensor.size()) == 1:  # 1D Optimizer Tensor
+                # Convert to 2D Tensor, transpose, for AllToAll
+                local_optimizer_tensor = local_optimizer_tensor.view(
+                    local_optimizer_tensor.size(0), 1
+                )
             padded_local_optimizer = pad_tensor_to_max_dims(
-                local_optimizer[0].tensor, max_dim_0, max_dim_1
+                local_optimizer_tensor, max_dim_0, max_dim_1
             )
             local_table_to_opt_by_dst_rank[dst_rank].append(
                 padded_local_optimizer
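The new block above views a 1D row-wise optimizer state as an (N, 1) column so it can go through the same padding and AllToAll path as 2D tables. A minimal sketch of that reshape (hypothetical tensor names and values, not part of the patch):

```python
import torch

# 1D optimizer state (e.g. row-wise momentum), hypothetical values
momentum = torch.arange(4, dtype=torch.float32)  # shape (4,)
as_2d = momentum.view(momentum.size(0), 1)       # shape (4, 1), same storage
assert as_2d.shape == (4, 1)
```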
@@ -284,9 +290,7 @@ def update_state_dict_post_resharding(
     for shard_name, shard_size in ordered_shard_names_and_lengths:
         end_slice_index = slice_index + max_dim_0
         cur_t = output_tensor[slice_index:end_slice_index]
-        cur_t = pad_tensor_to_max_dims(
-            cur_t, shard_size[0], shard_size[1], remove_padding=True
-        )
+        cur_t = pad_tensor_to_max_dims(cur_t, shard_size[0], shard_size[1])
         shard_name_to_local_output_tensor[shard_name] = cur_t
         slice_index = end_slice_index
@@ -335,9 +339,7 @@ def update_optimizer_state_post_resharding(
     for shard_name, shard_size in ordered_shard_names_and_lengths:
         end_slice_index = slice_index + max_dim_0
         cur_t = output_tensor[slice_index:end_slice_index]
-        cur_t = pad_tensor_to_max_dims(
-            cur_t, shard_size[0], shard_size[1], remove_padding=True
-        )
+        cur_t = pad_tensor_to_max_dims(cur_t, shard_size[0], shard_size[1])
         shard_name_to_local_output_tensor[shard_name] = cur_t
         slice_index = end_slice_index
@@ -352,9 +354,13 @@ def update_optimizer_state_post_resharding(
         sharded_t = item[momentum_name]
         assert len(sharded_t._local_shards) == 1
         # TODO: support multiple shards in CW sharding
+        local_tensor = shard_name_to_local_output_tensor[shard_name]
+        if len(sharded_t._local_shards[0].tensor.size()) == 1:
+            # Need to transpose 1D optimizer tensor, due to previous conversion
+            local_tensor = local_tensor.T[0]
         sharded_t._local_shards = [
             Shard(
-                tensor=shard_name_to_local_output_tensor[shard_name],
+                tensor=local_tensor,
                 metadata=shard.metadata,
             )
             for shard in sharded_t._local_shards
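On the receiving side, the added branch undoes that conversion: for a state that was originally 1D, the (N, 1) column coming out of AllToAll is turned back into a 1D tensor with `.T[0]`. A small sketch of the round trip (hypothetical names):

```python
import torch

received = torch.arange(4, dtype=torch.float32).view(4, 1)  # (N, 1) column after AllToAll
restored = received.T[0]                                     # back to shape (4,)
assert restored.shape == (4,)
```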
@@ -426,7 +432,6 @@ def pad_tensor_to_max_dims(
     t: torch.Tensor,
     expected_dim_0: int,
     expected_dim_1: int,
-    remove_padding: bool = False,
 ) -> torch.Tensor:
     """
     Pads a tensor on the right and bottom with zeros.
@@ -441,14 +446,10 @@ def pad_tensor_to_max_dims(
     """
     pad_right = expected_dim_1 - t.size(1)
     pad_bottom = expected_dim_0 - t.size(0)
+    pad = (0, pad_right, 0, pad_bottom)
     return F.pad(
         input=t,
-        pad=(
-            0,
-            pad_right,
-            0,
-            pad_bottom,
-        ),  # right and bottom
+        pad=pad,
         mode="constant",
         value=0,
     )
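Dropping the `remove_padding` flag appears to rely on `F.pad` accepting negative padding, which trims rather than pads; the post-resharding callers now pass the true shard dims through the same helper to crop off the padding added before AllToAll. A sketch of that behavior under this assumption (hypothetical sizes):

```python
import torch
import torch.nn.functional as F

t = torch.ones(2, 3)
padded = F.pad(t, (0, 2, 0, 1), mode="constant", value=0)  # pad out to (3, 5)
cropped = F.pad(padded, (0, 3 - 5, 0, 2 - 3))              # negative pads crop back to (2, 3)
assert cropped.shape == (2, 3)
```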