Skip to content

Commit 50f9129

Browse files
EddyLXJ and facebook-github-bot
authored and committed
Fix state dict for partial rowwise adam optimizer (#3267)
Summary: Pull Request resolved: #3267. The partial rowwise Adam optimizer adds new FQNs into the state dict and checkpoint (e.g. momentum1 and momentum2). This diff supports KVZCH saving both momentum tensors into the state dict. Reviewed By: emlin Differential Revision: D79622295 fbshipit-source-id: b9d8af448aa4d1b923eda31f895c70251668ca98
1 parent 9bcde1b commit 50f9129

File tree

1 file changed

+9
-12
lines changed

1 file changed

+9
-12
lines changed

torchrec/distributed/batched_embedding_kernel.py

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929

3030
import torch
3131
import torch.distributed as dist
32+
from fbgemm_gpu.split_embedding_configs import EmbOptimType as OptimType
3233
from fbgemm_gpu.split_table_batched_embeddings_ops_common import (
3334
BackendType,
3435
EvictionPolicy,
@@ -432,29 +433,25 @@ def __init__( # noqa C901
432433
all_optimizer_states = emb_module.get_optimizer_state(
433434
sorted_id_tensor=sorted_id_tensors,
434435
)
435-
opt_param_list = [param["momentum1"] for param in all_optimizer_states]
436+
436437
emb_table_config_copy = copy.deepcopy(self._config.embedding_tables)
437438
for emb_table in emb_table_config_copy:
438439
emb_table.local_metadata.placement._device = torch.device("cpu")
439-
opt_sharded_t_list = create_virtual_sharded_tensors(
440-
emb_table_config_copy, opt_param_list, self._pg
441-
)
442440

443-
for (
444-
emb_config,
445-
sharded_weight,
446-
opt_sharded_t,
447-
) in zip(
441+
for emb_config, sharded_weight, opt_state in zip(
448442
emb_table_config_copy,
449443
sharded_embedding_weights_by_table,
450-
opt_sharded_t_list,
444+
all_optimizer_states,
451445
):
452446
param_key = emb_config.name + ".weight"
453447
state[sharded_weight] = {}
454448
param_group["params"].append(sharded_weight)
455449
params[param_key] = sharded_weight
456-
457-
state[sharded_weight][f"{emb_config.name}.momentum1"] = opt_sharded_t
450+
for key, value in opt_state.items():
451+
opt_sharded_t = create_virtual_sharded_tensors(
452+
[emb_config], [value], self._pg
453+
)[0]
454+
state[sharded_weight][f"{emb_config.name}.{key}"] = opt_sharded_t
458455

459456
super().__init__(params, state, [param_group])
460457

0 commit comments

Comments
 (0)