@@ -29,6 +29,7 @@
 
 import torch
 import torch.distributed as dist
+from fbgemm_gpu.split_embedding_configs import EmbOptimType as OptimType
 from fbgemm_gpu.split_table_batched_embeddings_ops_common import (
     BackendType,
     EvictionPolicy,
@@ -432,29 +433,25 @@ def __init__( # noqa C901
         all_optimizer_states = emb_module.get_optimizer_state(
             sorted_id_tensor=sorted_id_tensors,
         )
-        opt_param_list = [param["momentum1"] for param in all_optimizer_states]
+
         emb_table_config_copy = copy.deepcopy(self._config.embedding_tables)
         for emb_table in emb_table_config_copy:
             emb_table.local_metadata.placement._device = torch.device("cpu")
-        opt_sharded_t_list = create_virtual_sharded_tensors(
-            emb_table_config_copy, opt_param_list, self._pg
-        )
 
-        for (
-            emb_config,
-            sharded_weight,
-            opt_sharded_t,
-        ) in zip(
+        for emb_config, sharded_weight, opt_state in zip(
             emb_table_config_copy,
             sharded_embedding_weights_by_table,
-            opt_sharded_t_list,
+            all_optimizer_states,
         ):
             param_key = emb_config.name + ".weight"
             state[sharded_weight] = {}
             param_group["params"].append(sharded_weight)
             params[param_key] = sharded_weight
-
-            state[sharded_weight][f"{emb_config.name}.momentum1"] = opt_sharded_t
+            for key, value in opt_state.items():
+                opt_sharded_t = create_virtual_sharded_tensors(
+                    [emb_config], [value], self._pg
+                )[0]
+                state[sharded_weight][f"{emb_config.name}.{key}"] = opt_sharded_t
 
         super().__init__(params, state, [param_group])
 
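For context, the previous code assumed each table's optimizer state contained exactly one `"momentum1"` entry; the new inner loop wraps whichever keys `get_optimizer_state` returns (for example a second moment tensor for Adam-style optimizers). Below is a minimal, self-contained sketch of that pattern, where `wrap_as_sharded`, the table names, and the example state dicts are stand-ins for `create_virtual_sharded_tensors` and the real FBGEMM output, not the actual torchrec code:

```python
import torch


def wrap_as_sharded(tensor: torch.Tensor) -> torch.Tensor:
    # Stand-in for create_virtual_sharded_tensors([config], [tensor], pg)[0],
    # which in torchrec wraps a local tensor into a virtual ShardedTensor.
    return tensor


# One state dict per table; which keys appear depends on the optimizer.
# Hypothetical example values: an Adam-style optimizer for table_0 and a
# rowwise-Adagrad-style optimizer for table_1.
all_optimizer_states = [
    {"momentum1": torch.zeros(4), "momentum2": torch.zeros(4)},
    {"momentum1": torch.zeros(8)},
]
table_names = ["table_0", "table_1"]

state = {}
for name, opt_state in zip(table_names, all_optimizer_states):
    weight_key = f"{name}.weight"
    state[weight_key] = {}
    # Old code hard-coded "momentum1"; the new loop covers every returned key.
    for key, value in opt_state.items():
        state[weight_key][f"{name}.{key}"] = wrap_as_sharded(value)

print(sorted(state["table_0.weight"]))  # ['table_0.momentum1', 'table_0.momentum2']
print(sorted(state["table_1.weight"]))  # ['table_1.momentum1']
```

Wrapping per key should let additional optimizer state tensors flow into the fused optimizer's state dict without further special-casing in this constructor.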