
Commit 533f82b

aporialiao authored and facebook-github-bot committed
RowWiseAdaGrad Case - Handle 1D optimizer tensors (#3159)
Summary:
Pull Request resolved: #3159

Experimenting with the Reshard API on trainers, which use RowWiseAdaGrad. Needed to support 1D optimizer tensors for resharding.

Reviewed By: iamzainhuda

Differential Revision: D77692894

fbshipit-source-id: 3718868cd5dbe863f93e20b84bad7312ce0ea59e
1 parent 5a73ae5 commit 533f82b

File tree

2 files changed: +28 -16 lines


torchrec/distributed/sharding/dynamic_sharding.py

Lines changed: 16 additions & 15 deletions
@@ -167,8 +167,14 @@ def shards_all_to_all(
                     extend_shard_name(shard_name)
                 ][tmp_momentum_extender(shard_name)].local_shards()
                 assert len(local_optimizer) == 1
+                local_optimizer_tensor = local_optimizer[0].tensor
+                if len(local_optimizer_tensor.size()) == 1:  # 1D Optimizer Tensor
+                    # Convert to 2D Tensor, transpose, for AllToAll
+                    local_optimizer_tensor = local_optimizer_tensor.view(
+                        local_optimizer_tensor.size(0), 1
+                    )
                 padded_local_optimizer = pad_tensor_to_max_dims(
-                    local_optimizer[0].tensor, max_dim_0, max_dim_1
+                    local_optimizer_tensor, max_dim_0, max_dim_1
                 )
                 local_table_to_opt_by_dst_rank[dst_rank].append(
                     padded_local_optimizer
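
RowWiseAdaGrad keeps a single momentum value per embedding row, so its optimizer state arrives here as a 1D tensor rather than a 2D table. A minimal sketch of the reshape-then-pad step above, using made-up shard and max dims (in the real code max_dim_0 and max_dim_1 come from the resharding plan):

import torch
import torch.nn.functional as F

# Hypothetical sizes for illustration only.
num_rows, max_dim_0, max_dim_1 = 6, 8, 4

# 1D row-wise AdaGrad state: one value per embedding row.
local_optimizer_tensor = torch.rand(num_rows)

if len(local_optimizer_tensor.size()) == 1:  # 1D optimizer tensor
    # View it as a (num_rows, 1) column so it can ride the same padded
    # 2D AllToAll path as ordinary optimizer tensors.
    local_optimizer_tensor = local_optimizer_tensor.view(
        local_optimizer_tensor.size(0), 1
    )

# Pad on the right and bottom up to the maximum shard dims.
padded = F.pad(
    local_optimizer_tensor,
    (
        0,
        max_dim_1 - local_optimizer_tensor.size(1),
        0,
        max_dim_0 - local_optimizer_tensor.size(0),
    ),
)
assert padded.shape == (max_dim_0, max_dim_1)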
@@ -284,9 +290,7 @@ def update_state_dict_post_resharding(
     for shard_name, shard_size in ordered_shard_names_and_lengths:
         end_slice_index = slice_index + max_dim_0
         cur_t = output_tensor[slice_index:end_slice_index]
-        cur_t = pad_tensor_to_max_dims(
-            cur_t, shard_size[0], shard_size[1], remove_padding=True
-        )
+        cur_t = pad_tensor_to_max_dims(cur_t, shard_size[0], shard_size[1])
         shard_name_to_local_output_tensor[shard_name] = cur_t
         slice_index = end_slice_index

@@ -335,9 +339,7 @@ def update_optimizer_state_post_resharding(
     for shard_name, shard_size in ordered_shard_names_and_lengths:
         end_slice_index = slice_index + max_dim_0
         cur_t = output_tensor[slice_index:end_slice_index]
-        cur_t = pad_tensor_to_max_dims(
-            cur_t, shard_size[0], shard_size[1], remove_padding=True
-        )
+        cur_t = pad_tensor_to_max_dims(cur_t, shard_size[0], shard_size[1])
         shard_name_to_local_output_tensor[shard_name] = cur_t
         slice_index = end_slice_index
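
Both hunks above drop remove_padding=True from the call sites. Since pad_tensor_to_max_dims (see the last two hunks in this file) builds its pad amounts as expected_dim minus current_dim, passing the smaller true shard_size here yields negative pad values, and F.pad treats negative padding as cropping, so a separate un-padding mode appears to be unnecessary. A tiny standalone demonstration with illustrative sizes:

import torch
import torch.nn.functional as F

# A received shard padded up to (max_dim_0, max_dim_1) = (8, 4) for the
# AllToAll; its true shard_size is (6, 1).
padded = torch.rand(8, 4)
shard_size = (6, 1)

# Negative pad amounts crop from the right and bottom.
cropped = F.pad(
    padded,
    (0, shard_size[1] - padded.size(1), 0, shard_size[0] - padded.size(0)),
)
assert cropped.shape == (6, 1)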

@@ -352,9 +354,13 @@ def update_optimizer_state_post_resharding(
         sharded_t = item[momentum_name]
         assert len(sharded_t._local_shards) == 1
         # TODO: support multiple shards in CW sharding
+        local_tensor = shard_name_to_local_output_tensor[shard_name]
+        if len(sharded_t._local_shards[0].tensor.size()) == 1:
+            # Need to transpose 1D optimizer tensor, due to previous conversion
+            local_tensor = local_tensor.T[0]
         sharded_t._local_shards = [
             Shard(
-                tensor=shard_name_to_local_output_tensor[shard_name],
+                tensor=local_tensor,
                 metadata=shard.metadata,
             )
             for shard in sharded_t._local_shards
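
Because the 1D state was shipped as a (num_rows, 1) column, it has to be flattened back before being written into the ShardedTensor's local shard; that is what the .T[0] above does. A small round-trip sketch on an illustrative tensor:

import torch

original = torch.rand(6)                        # 1D row-wise AdaGrad state
as_column = original.view(original.size(0), 1)  # (6, 1), the shape used for AllToAll

# .T gives shape (1, 6); row 0 of that view is the original 1D tensor again.
restored = as_column.T[0]
assert restored.shape == original.shape
assert torch.equal(restored, original)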
@@ -426,7 +432,6 @@ def pad_tensor_to_max_dims(
     t: torch.Tensor,
     expected_dim_0: int,
     expected_dim_1: int,
-    remove_padding: bool = False,
 ) -> torch.Tensor:
     """
     Pads a tensor on the right and bottom with zeros.
@@ -441,14 +446,10 @@ def pad_tensor_to_max_dims(
     """
     pad_right = expected_dim_1 - t.size(1)
     pad_bottom = expected_dim_0 - t.size(0)
+    pad = (0, pad_right, 0, pad_bottom)
     return F.pad(
         input=t,
-        pad=(
-            0,
-            pad_right,
-            0,
-            pad_bottom,
-        ),  # right and bottom
+        pad=pad,
         mode="constant",
         value=0,
     )
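
For reference, this is how the refactored helper reads once the two hunks above are applied. It is assembled from the diff; the unchanged lines between the hunks, including most of the docstring, are not shown here:

import torch
import torch.nn.functional as F


def pad_tensor_to_max_dims(
    t: torch.Tensor,
    expected_dim_0: int,
    expected_dim_1: int,
) -> torch.Tensor:
    """
    Pads a tensor on the right and bottom with zeros.
    (Remainder of the docstring omitted; it is not part of this diff.)
    """
    pad_right = expected_dim_1 - t.size(1)
    pad_bottom = expected_dim_0 - t.size(0)
    pad = (0, pad_right, 0, pad_bottom)
    return F.pad(
        input=t,
        pad=pad,
        mode="constant",
        value=0,
    )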

torchrec/distributed/tests/test_dynamic_sharding.py

Lines changed: 12 additions & 1 deletion
@@ -21,7 +21,12 @@

 from torch import nn, optim

-from torchrec import distributed as trec_dist, EmbeddingBagCollection, KeyedJaggedTensor
+from torchrec import (
+    distributed as trec_dist,
+    EmbeddingBagCollection,
+    KeyedJaggedTensor,
+    optim as trec_optim,
+)
 from torchrec.distributed.embedding_types import EmbeddingComputeKernel
 from torchrec.distributed.embeddingbag import ShardedEmbeddingBagCollection
 from torchrec.distributed.fbgemm_qcomm_codec import CommType, QCommsConfig
@@ -534,6 +539,12 @@ class MultiRankDMPDynamicShardingTest(ModelParallelTestShared):
                 {
                     "embedding_bags": (torch.optim.SGD, {"lr": 0.01}),
                 },
+                {
+                    "embedding_bags": (
+                        trec_optim.RowWiseAdagrad,
+                        {"lr": 0.01},
+                    ),
+                },
             ]
         ),
         variable_batch_size=st.sampled_from(
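
The new parametrization entry lets the dynamic-sharding test exercise a fused RowWiseAdagrad optimizer in addition to SGD. A hedged sketch of how such a per-module (optimizer class, kwargs) pairing is typically fused into the backward pass in TorchRec tests, using apply_optimizer_in_backward; the table config and names below are made up, and the actual ModelParallelTestShared harness may wire this up differently:

import torch
from torch.distributed.optim import apply_optimizer_in_backward

from torchrec import EmbeddingBagCollection, optim as trec_optim
from torchrec.modules.embedding_configs import EmbeddingBagConfig

# Illustrative single-table collection; names and sizes are made up.
ebc = EmbeddingBagCollection(
    tables=[
        EmbeddingBagConfig(
            name="t1",
            embedding_dim=8,
            num_embeddings=16,
            feature_names=["f1"],
        )
    ],
    device=torch.device("cpu"),
)

# Fuse RowWiseAdagrad into the backward pass of the embedding parameters,
# mirroring the ("embedding_bags", (trec_optim.RowWiseAdagrad, {"lr": 0.01}))
# entry added above.
apply_optimizer_in_backward(
    trec_optim.RowWiseAdagrad, ebc.parameters(), {"lr": 0.01}
)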
