Add util func for kvzch eviction mask (#1645) #3246

Open · wants to merge 2 commits into main
82 changes: 65 additions & 17 deletions torchrec/distributed/batched_embedding_kernel.py
@@ -958,7 +958,7 @@ def _gen_named_parameters_by_table_ssd_pmt(
name as well as the parameter itself. The embedding table is in the form of
PartiallyMaterializedTensor to support windowed access.
"""
pmts, _, _ = emb_module.split_embedding_weights()
pmts, _, _, _ = emb_module.split_embedding_weights()
for table_config, pmt in zip(config.embedding_tables, pmts):
table_name = table_config.name
emb_table = pmt
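The hunk above (and the matching ones throughout this file) widens the return of split_embedding_weights() from a 3-tuple to a 4-tuple, appending an optional per-table metadata list. A minimal sketch of the new calling convention, using a stand-in module in place of the real FBGEMM SSD kernel (the names below are hypothetical, and the real lists hold PartiallyMaterializedTensor objects rather than plain tensors):

```python
from typing import List, Optional, Tuple

import torch


class FakeSSDEmbModule:
    """Stand-in for the SSD/KV embedding kernel touched by this diff."""

    def split_embedding_weights(
        self, no_snapshot: bool = True
    ) -> Tuple[
        List[torch.Tensor],            # embedding tables (PMTs in TorchRec)
        Optional[List[torch.Tensor]],  # weight ids
        Optional[List[torch.Tensor]],  # bucket counts
        Optional[List[torch.Tensor]],  # new: per-table metadata
    ]:
        tables = [torch.zeros(4, 8), torch.zeros(2, 8)]
        # Non-kvzch backends return None for all of the optional lists.
        return tables, None, None, None


emb_module = FakeSSDEmbModule()
# Callers that previously unpacked three values now unpack four.
pmts, weight_ids, bucket_cnts, metadata = emb_module.split_embedding_weights()
assert weight_ids is None and bucket_cnts is None and metadata is None
```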
@@ -1272,7 +1272,7 @@ def state_dict(
# in the case no_snapshot=False, a flush is required. we rely on the flush operation in
# ShardedEmbeddingBagCollection._pre_state_dict_hook()

emb_tables, _, _ = self.split_embedding_weights(no_snapshot=no_snapshot)
emb_tables, _, _, _ = self.split_embedding_weights(no_snapshot=no_snapshot)
emb_table_config_copy = copy.deepcopy(self._config.embedding_tables)
for emb_table in emb_table_config_copy:
emb_table.local_metadata.placement._device = torch.device("cpu")
@@ -1322,6 +1322,7 @@ def get_named_split_embedding_weights_snapshot(self, prefix: str = "") -> Iterat
Union[ShardedTensor, PartiallyMaterializedTensor],
Optional[ShardedTensor],
Optional[ShardedTensor],
Optional[ShardedTensor],
]
]:
"""
@@ -1330,13 +1331,14 @@ def get_named_split_embedding_weights_snapshot(self, prefix: str = "") -> Iterat
RocksDB snapshot to support windowed access.
optional ShardedTensor for weight_id, this won't be used here as this is non-kvzch
optional ShardedTensor for bucket_cnt, this won't be used here as this is non-kvzch
optional ShardedTensor for metadata, this won't be used here as this is non-kvzch
"""
for config, tensor in zip(
self._config.embedding_tables,
self.split_embedding_weights(no_snapshot=False)[0],
):
key = append_prefix(prefix, f"{config.name}")
yield key, tensor, None, None
yield key, tensor, None, None, None
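For the non-kvzch kernel, the snapshot iterator above keeps its one-entry-per-table contract but grows to a 5-tuple, with the weight_id, bucket_cnt, and metadata slots fixed to None. A self-contained sketch of that yield shape (fake_snapshot and the key format are illustrative stand-ins for the real method and append_prefix):

```python
from typing import Iterator, List, Optional, Tuple

import torch

Entry = Tuple[
    str, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]
]


def fake_snapshot(
    table_names: List[str], tables: List[torch.Tensor], prefix: str = ""
) -> Iterator[Entry]:
    for name, tensor in zip(table_names, tables):
        key = f"{prefix}{name}"  # append_prefix(prefix, name) in TorchRec
        # Non-kvzch: no weight_id / bucket_cnt / metadata sharded tensors.
        yield key, tensor, None, None, None


for key, weights, weight_id, bucket_cnt, metadata in fake_snapshot(
    ["t1"], [torch.zeros(4, 8)], prefix="embeddings."
):
    assert weight_id is None and bucket_cnt is None and metadata is None
```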

def flush(self) -> None:
"""
@@ -1364,6 +1366,7 @@ def split_embedding_weights(self, no_snapshot: bool = True) -> Tuple[
List[PartiallyMaterializedTensor],
Optional[List[torch.Tensor]],
Optional[List[torch.Tensor]],
Optional[List[torch.Tensor]],
]:
# pyre-fixme[7]: Expected `Tuple[List[PartiallyMaterializedTensor],
# Optional[List[Tensor]], Optional[List[Tensor]]]` but got
@@ -1415,6 +1418,7 @@ def __init__(
List[ShardedTensor],
List[ShardedTensor],
List[ShardedTensor],
List[ShardedTensor],
]
] = None

@@ -1490,7 +1494,7 @@ def state_dict(
# in the case no_snapshot=False, a flush is required. we rely on the flush operation in
# ShardedEmbeddingBagCollection._pre_state_dict_hook()

emb_tables, _, _ = self.split_embedding_weights(no_snapshot=no_snapshot)
emb_tables, _, _, _ = self.split_embedding_weights(no_snapshot=no_snapshot)
emb_table_config_copy = copy.deepcopy(self._config.embedding_tables)
for emb_table in emb_table_config_copy:
emb_table.local_metadata.placement._device = torch.device("cpu")
@@ -1546,8 +1550,10 @@ def _init_sharded_split_embedding_weights(
if not force_regenerate and self._split_weights_res is not None:
return

pmt_list, weight_ids_list, bucket_cnt_list = self.split_embedding_weights(
no_snapshot=False,
pmt_list, weight_ids_list, bucket_cnt_list, metadata_list = (
self.split_embedding_weights(
no_snapshot=False,
)
)
emb_table_config_copy = copy.deepcopy(self._config.embedding_tables)
for emb_table in emb_table_config_copy:
@@ -1581,17 +1587,31 @@ def _init_sharded_split_embedding_weights(
self._table_name_to_weight_count_per_rank,
use_param_size_as_rows=True,
)
# pyre-ignore
assert len(pmt_list) == len(weight_ids_list) == len(bucket_cnt_list)
metadata_sharded_t_list = create_virtual_sharded_tensors(
emb_table_config_copy,
metadata_list, # pyre-ignore [6]
self._pg,
prefix,
self._table_name_to_weight_count_per_rank,
)

assert (
len(pmt_list)
== len(weight_ids_list) # pyre-ignore
== len(bucket_cnt_list) # pyre-ignore
== len(metadata_list) # pyre-ignore
)
assert (
len(pmt_sharded_t_list)
== len(weight_id_sharded_t_list)
== len(bucket_cnt_sharded_t_list)
== len(metadata_sharded_t_list)
)
self._split_weights_res = (
pmt_sharded_t_list,
weight_id_sharded_t_list,
bucket_cnt_sharded_t_list,
metadata_sharded_t_list,
)
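The kvzch path above now builds four parallel per-table lists (PMTs, weight ids, bucket counts, metadata), wraps each in virtual ShardedTensors, checks that the lists stay index-aligned, and caches the 4-tuple in _split_weights_res. A minimal sketch of that caching pattern, with plain tensors standing in for the sharded-tensor wrapping done by create_virtual_sharded_tensors:

```python
from typing import List, Optional, Tuple

import torch

SplitWeights = Tuple[
    List[torch.Tensor], List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]
]


class SplitWeightsCache:
    """Illustrative stand-in for the kvzch kernel's _split_weights_res caching."""

    def __init__(self) -> None:
        self._split_weights_res: Optional[SplitWeights] = None

    def _split_embedding_weights(self) -> SplitWeights:
        # Stand-in for split_embedding_weights(no_snapshot=False).
        pmts = [torch.zeros(4, 8)]
        weight_ids = [torch.zeros(4, dtype=torch.int64)]
        bucket_cnts = [torch.zeros(2, dtype=torch.int64)]
        metadata = [torch.zeros(4, dtype=torch.int64)]
        return pmts, weight_ids, bucket_cnts, metadata

    def init_sharded_split_embedding_weights(self, force_regenerate: bool = False) -> None:
        if not force_regenerate and self._split_weights_res is not None:
            return  # reuse the cached result
        pmts, weight_ids, bucket_cnts, metadata = self._split_embedding_weights()
        # All four lists must stay index-aligned, one entry per table.
        assert len(pmts) == len(weight_ids) == len(bucket_cnts) == len(metadata)
        self._split_weights_res = (pmts, weight_ids, bucket_cnts, metadata)


cache = SplitWeightsCache()
cache.init_sharded_split_embedding_weights()
cache.init_sharded_split_embedding_weights()  # no-op unless force_regenerate=True
```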

def get_named_split_embedding_weights_snapshot(self, prefix: str = "") -> Iterator[
@@ -1600,6 +1620,7 @@
Union[ShardedTensor, PartiallyMaterializedTensor],
Optional[ShardedTensor],
Optional[ShardedTensor],
Optional[ShardedTensor],
]
]:
"""
@@ -1608,6 +1629,7 @@ def get_named_split_embedding_weights_snapshot(self, prefix: str = "") -> Iterat
PMT for embedding table with a valid RocksDB snapshot to support tensor IO
optional ShardedTensor for weight_id
optional ShardedTensor for bucket_cnt
optional ShardedTensor for metadata
"""
self._init_sharded_split_embedding_weights()
# pyre-ignore[16]
@@ -1616,13 +1638,14 @@ def get_named_split_embedding_weights_snapshot(self, prefix: str = "") -> Iterat
pmt_sharded_t_list = self._split_weights_res[0]
weight_id_sharded_t_list = self._split_weights_res[1]
bucket_cnt_sharded_t_list = self._split_weights_res[2]
metadata_sharded_t_list = self._split_weights_res[3]
for table_idx, pmt_sharded_t in enumerate(pmt_sharded_t_list):
table_config = self._config.embedding_tables[table_idx]
key = append_prefix(prefix, f"{table_config.name}")

yield key, pmt_sharded_t, weight_id_sharded_t_list[
table_idx
], bucket_cnt_sharded_t_list[table_idx]
], bucket_cnt_sharded_t_list[table_idx], metadata_sharded_t_list[table_idx]

def flush(self) -> None:
"""
@@ -1651,6 +1674,7 @@ def split_embedding_weights(
Union[List[PartiallyMaterializedTensor], List[torch.Tensor]],
Optional[List[torch.Tensor]],
Optional[List[torch.Tensor]],
Optional[List[torch.Tensor]],
]:
return self.emb_module.split_embedding_weights(no_snapshot, should_flush)

@@ -2079,7 +2103,7 @@ def state_dict(
# in the case no_snapshot=False, a flush is required. we rely on the flush operation in
# ShardedEmbeddingBagCollection._pre_state_dict_hook()

emb_tables, _, _ = self.split_embedding_weights(no_snapshot=no_snapshot)
emb_tables, _, _, _ = self.split_embedding_weights(no_snapshot=no_snapshot)
emb_table_config_copy = copy.deepcopy(self._config.embedding_tables)
for emb_table in emb_table_config_copy:
emb_table.local_metadata.placement._device = torch.device("cpu")
@@ -2129,6 +2153,7 @@ def get_named_split_embedding_weights_snapshot(self, prefix: str = "") -> Iterat
Union[ShardedTensor, PartiallyMaterializedTensor],
Optional[ShardedTensor],
Optional[ShardedTensor],
Optional[ShardedTensor],
]
]:
"""
@@ -2137,13 +2162,14 @@ def get_named_split_embedding_weights_snapshot(self, prefix: str = "") -> Iterat
RocksDB snapshot to support windowed access.
optional ShardedTensor for weight_id, this won't be used here as this is non-kvzch
optional ShardedTensor for bucket_cnt, this won't be used here as this is non-kvzch
optional ShardedTensor for metadata, this won't be used here as this is non-kvzch
"""
for config, tensor in zip(
self._config.embedding_tables,
self.split_embedding_weights(no_snapshot=False)[0],
):
key = append_prefix(prefix, f"{config.name}")
yield key, tensor, None, None
yield key, tensor, None, None, None

def flush(self) -> None:
"""
@@ -2170,6 +2196,7 @@ def split_embedding_weights(self, no_snapshot: bool = True) -> Tuple[
List[PartiallyMaterializedTensor],
Optional[List[torch.Tensor]],
Optional[List[torch.Tensor]],
Optional[List[torch.Tensor]],
]:
# pyre-fixme[7]: Expected `Tuple[List[PartiallyMaterializedTensor],
# Optional[List[Tensor]], Optional[List[Tensor]]]` but got
@@ -2223,6 +2250,7 @@ def __init__(
List[ShardedTensor],
List[ShardedTensor],
List[ShardedTensor],
List[ShardedTensor],
]
] = None

@@ -2298,7 +2326,7 @@ def state_dict(
# in the case no_snapshot=False, a flush is required. we rely on the flush operation in
# ShardedEmbeddingBagCollection._pre_state_dict_hook()

emb_tables, _, _ = self.split_embedding_weights(no_snapshot=no_snapshot)
emb_tables, _, _, _ = self.split_embedding_weights(no_snapshot=no_snapshot)
emb_table_config_copy = copy.deepcopy(self._config.embedding_tables)
for emb_table in emb_table_config_copy:
emb_table.local_metadata.placement._device = torch.device("cpu")
@@ -2354,8 +2382,10 @@ def _init_sharded_split_embedding_weights(
if not force_regenerate and self._split_weights_res is not None:
return

pmt_list, weight_ids_list, bucket_cnt_list = self.split_embedding_weights(
no_snapshot=False,
pmt_list, weight_ids_list, bucket_cnt_list, metadata_list = (
self.split_embedding_weights(
no_snapshot=False,
)
)
emb_table_config_copy = copy.deepcopy(self._config.embedding_tables)
for emb_table in emb_table_config_copy:
@@ -2389,17 +2419,31 @@ def _init_sharded_split_embedding_weights(
self._table_name_to_weight_count_per_rank,
use_param_size_as_rows=True,
)
# pyre-ignore
assert len(pmt_list) == len(weight_ids_list) == len(bucket_cnt_list)
metadata_sharded_t_list = create_virtual_sharded_tensors(
emb_table_config_copy,
metadata_list, # pyre-ignore [6]
self._pg,
prefix,
self._table_name_to_weight_count_per_rank,
)

assert (
len(pmt_list)
== len(weight_ids_list) # pyre-ignore
== len(bucket_cnt_list) # pyre-ignore
== len(metadata_list) # pyre-ignore
)
assert (
len(pmt_sharded_t_list)
== len(weight_id_sharded_t_list)
== len(bucket_cnt_sharded_t_list)
== len(metadata_sharded_t_list)
)
self._split_weights_res = (
pmt_sharded_t_list,
weight_id_sharded_t_list,
bucket_cnt_sharded_t_list,
metadata_sharded_t_list,
)

def get_named_split_embedding_weights_snapshot(self, prefix: str = "") -> Iterator[
@@ -2408,6 +2452,7 @@
Union[ShardedTensor, PartiallyMaterializedTensor],
Optional[ShardedTensor],
Optional[ShardedTensor],
Optional[ShardedTensor],
]
]:
"""
Expand All @@ -2416,6 +2461,7 @@ def get_named_split_embedding_weights_snapshot(self, prefix: str = "") -> Iterat
PMT for embedding table with a valid RocksDB snapshot to support tensor IO
optional ShardedTensor for weight_id
optional ShardedTensor for bucket_cnt
optional ShardedTensor for metadata
"""
self._init_sharded_split_embedding_weights()
# pyre-ignore[16]
Expand All @@ -2424,13 +2470,14 @@ def get_named_split_embedding_weights_snapshot(self, prefix: str = "") -> Iterat
pmt_sharded_t_list = self._split_weights_res[0]
weight_id_sharded_t_list = self._split_weights_res[1]
bucket_cnt_sharded_t_list = self._split_weights_res[2]
metadata_sharded_t_list = self._split_weights_res[3]
for table_idx, pmt_sharded_t in enumerate(pmt_sharded_t_list):
table_config = self._config.embedding_tables[table_idx]
key = append_prefix(prefix, f"{table_config.name}")

yield key, pmt_sharded_t, weight_id_sharded_t_list[
table_idx
], bucket_cnt_sharded_t_list[table_idx]
], bucket_cnt_sharded_t_list[table_idx], metadata_sharded_t_list[table_idx]

def flush(self) -> None:
"""
Expand Down Expand Up @@ -2459,6 +2506,7 @@ def split_embedding_weights(
Union[List[PartiallyMaterializedTensor], List[torch.Tensor]],
Optional[List[torch.Tensor]],
Optional[List[torch.Tensor]],
Optional[List[torch.Tensor]],
]:
return self.emb_module.split_embedding_weights(no_snapshot, should_flush)

13 changes: 13 additions & 0 deletions torchrec/distributed/embedding.py
@@ -698,10 +698,13 @@ def _pre_load_state_dict_hook(
weight_key = f"{prefix}embeddings.{table_name}.weight"
weight_id_key = f"{prefix}embeddings.{table_name}.weight_id"
bucket_key = f"{prefix}embeddings.{table_name}.bucket"
metadata_key = f"{prefix}embeddings.{table_name}.metadata"
if weight_id_key in state_dict:
del state_dict[weight_id_key]
if bucket_key in state_dict:
del state_dict[bucket_key]
if metadata_key in state_dict:
del state_dict[metadata_key]
assert weight_key in state_dict
assert (
len(self._model_parallel_name_to_local_shards[table_name]) == 1
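The hook above drops the auxiliary per-table entries (weight_id, bucket, and now metadata) from an incoming state dict before the weights themselves are loaded; only the .weight key is required. A small sketch of that stripping step with a toy state dict, following the f"{prefix}embeddings.{table_name}.*" key pattern used above:

```python
import torch

prefix, table_name = "", "t1"
state_dict = {
    f"{prefix}embeddings.{table_name}.weight": torch.zeros(4, 8),
    f"{prefix}embeddings.{table_name}.weight_id": torch.zeros(4, dtype=torch.int64),
    f"{prefix}embeddings.{table_name}.bucket": torch.zeros(2, dtype=torch.int64),
    f"{prefix}embeddings.{table_name}.metadata": torch.zeros(4, dtype=torch.int64),
}

# Drop the auxiliary entries; keep only the weight tensor for loading.
for suffix in ("weight_id", "bucket", "metadata"):
    state_dict.pop(f"{prefix}embeddings.{table_name}.{suffix}", None)

assert set(state_dict) == {"embeddings.t1.weight"}
```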
@@ -1037,6 +1040,7 @@ def post_state_dict_hook(
weights_t,
weight_ids_sharded_t,
id_cnt_per_bucket_sharded_t,
metadata_sharded_t,
) in (
lookup.get_named_split_embedding_weights_snapshot() # pyre-ignore
):
@@ -1048,19 +1052,22 @@
assert (
weight_ids_sharded_t is not None
and id_cnt_per_bucket_sharded_t is not None
and metadata_sharded_t is not None
)
# The logic here assumes there is only one shard per table on any particular rank
# if there are cases each rank has >1 shards, we need to update here accordingly
sharded_kvtensors_copy[table_name] = weights_t
virtual_table_sharded_t_map[table_name] = (
weight_ids_sharded_t,
id_cnt_per_bucket_sharded_t,
metadata_sharded_t,
)
else:
assert isinstance(weights_t, PartiallyMaterializedTensor)
assert (
weight_ids_sharded_t is None
and id_cnt_per_bucket_sharded_t is None
and metadata_sharded_t is None
)
# The logic here assumes there is only one shard per table on any particular rank
# if there are cases each rank has >1 shards, we need to update here accordingly
@@ -1099,6 +1106,12 @@ def update_destination(
destination,
virtual_table_sharded_t_map[table_name][1],
)
update_destination(
table_name,
"metadata",
destination,
virtual_table_sharded_t_map[table_name][2],
)

def _post_load_state_dict_hook(
module: "ShardedEmbeddingCollection",
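update_destination is only partially visible in this hunk, so the helper below is a hypothetical stand-in for the intent of the new call: writing the metadata sharded tensor into the destination state dict alongside the existing weight_id and bucket entries.

```python
from typing import Dict

import torch


def write_aux_entry(
    table_name: str,
    suffix: str,  # "weight_id", "bucket", or "metadata"
    destination: Dict[str, torch.Tensor],
    tensor: torch.Tensor,
    prefix: str = "",
) -> None:
    # Hypothetical helper; the real update_destination signature is not shown here.
    destination[f"{prefix}embeddings.{table_name}.{suffix}"] = tensor


destination: Dict[str, torch.Tensor] = {}
write_aux_entry("t1", "metadata", destination, torch.zeros(4, dtype=torch.int64))
assert "embeddings.t1.metadata" in destination
```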
2 changes: 2 additions & 0 deletions torchrec/distributed/embedding_lookup.py
@@ -381,6 +381,7 @@ def get_named_split_embedding_weights_snapshot(
Union[ShardedTensor, PartiallyMaterializedTensor],
Optional[ShardedTensor],
Optional[ShardedTensor],
Optional[ShardedTensor],
]
]:
"""
@@ -732,6 +733,7 @@ def get_named_split_embedding_weights_snapshot(
Union[ShardedTensor, PartiallyMaterializedTensor],
Optional[ShardedTensor],
Optional[ShardedTensor],
Optional[ShardedTensor],
]
]:
"""
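In embedding_lookup.py the grouped lookups only widen their iterator signatures; their bodies are collapsed in this diff. Under the assumption that a grouped lookup simply forwards each owned kernel's snapshot iterator, the plumbing looks roughly like the sketch below (all class and variable names are stand-ins, not TorchRec code):

```python
from typing import Iterator, List, Optional, Tuple

import torch

Entry = Tuple[
    str, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]
]


class FakeKernel:
    def __init__(self, names: List[str]) -> None:
        self._names = names

    def get_named_split_embedding_weights_snapshot(self, prefix: str = "") -> Iterator[Entry]:
        for name in self._names:
            yield f"{prefix}{name}", torch.zeros(2, 4), None, None, None


class FakeGroupedLookup:
    def __init__(self, kernels: List[FakeKernel]) -> None:
        self._emb_modules = kernels

    def get_named_split_embedding_weights_snapshot(self, prefix: str = "") -> Iterator[Entry]:
        # Forward each kernel's per-table 5-tuples unchanged.
        for emb_module in self._emb_modules:
            yield from emb_module.get_named_split_embedding_weights_snapshot(prefix)


lookup = FakeGroupedLookup([FakeKernel(["t1"]), FakeKernel(["t2"])])
keys = [k for k, *_ in lookup.get_named_split_embedding_weights_snapshot("embeddings.")]
assert keys == ["embeddings.t1", "embeddings.t2"]
```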