diff --git a/torchrec/distributed/batched_embedding_kernel.py b/torchrec/distributed/batched_embedding_kernel.py
index 1c8536086..9dbfd142b 100644
--- a/torchrec/distributed/batched_embedding_kernel.py
+++ b/torchrec/distributed/batched_embedding_kernel.py
@@ -958,7 +958,7 @@ def _gen_named_parameters_by_table_ssd_pmt(
     name as well as the parameter itself. The embedding table is in the form of
     PartiallyMaterializedTensor to support windowed access.
     """
-    pmts, _, _ = emb_module.split_embedding_weights()
+    pmts, _, _, _ = emb_module.split_embedding_weights()
     for table_config, pmt in zip(config.embedding_tables, pmts):
         table_name = table_config.name
         emb_table = pmt
@@ -1272,7 +1272,7 @@ def state_dict(

         # in the case no_snapshot=False, a flush is required. we rely on the flush operation in
         # ShardedEmbeddingBagCollection._pre_state_dict_hook()
-        emb_tables, _, _ = self.split_embedding_weights(no_snapshot=no_snapshot)
+        emb_tables, _, _, _ = self.split_embedding_weights(no_snapshot=no_snapshot)
         emb_table_config_copy = copy.deepcopy(self._config.embedding_tables)
         for emb_table in emb_table_config_copy:
             emb_table.local_metadata.placement._device = torch.device("cpu")
@@ -1322,6 +1322,7 @@ def get_named_split_embedding_weights_snapshot(self, prefix: str = "") -> Iterat
             Union[ShardedTensor, PartiallyMaterializedTensor],
             Optional[ShardedTensor],
             Optional[ShardedTensor],
+            Optional[ShardedTensor],
         ]
     ]:
         """
@@ -1330,13 +1331,14 @@ def get_named_split_embedding_weights_snapshot(self, prefix: str = "") -> Iterat
         RocksDB snapshot to support windowed access.
         optional ShardedTensor for weight_id, this won't be used here as this is non-kvzch
         optional ShardedTensor for bucket_cnt, this won't be used here as this is non-kvzch
+        optional ShardedTensor for metadata, this won't be used here as this is non-kvzch
         """
         for config, tensor in zip(
             self._config.embedding_tables,
             self.split_embedding_weights(no_snapshot=False)[0],
         ):
             key = append_prefix(prefix, f"{config.name}")
-            yield key, tensor, None, None
+            yield key, tensor, None, None, None

     def flush(self) -> None:
         """
@@ -1364,6 +1366,7 @@ def split_embedding_weights(self, no_snapshot: bool = True) -> Tuple[
         List[PartiallyMaterializedTensor],
         Optional[List[torch.Tensor]],
         Optional[List[torch.Tensor]],
+        Optional[List[torch.Tensor]],
     ]:
         # pyre-fixme[7]: Expected `Tuple[List[PartiallyMaterializedTensor],
         #  Optional[List[Tensor]], Optional[List[Tensor]]]` but got
@@ -1415,6 +1418,7 @@ def __init__(
                 List[ShardedTensor],
                 List[ShardedTensor],
                 List[ShardedTensor],
+                List[ShardedTensor],
             ]
         ] = None

@@ -1490,7 +1494,7 @@ def state_dict(

         # in the case no_snapshot=False, a flush is required. we rely on the flush operation in
         # ShardedEmbeddingBagCollection._pre_state_dict_hook()
-        emb_tables, _, _ = self.split_embedding_weights(no_snapshot=no_snapshot)
+        emb_tables, _, _, _ = self.split_embedding_weights(no_snapshot=no_snapshot)
         emb_table_config_copy = copy.deepcopy(self._config.embedding_tables)
         for emb_table in emb_table_config_copy:
             emb_table.local_metadata.placement._device = torch.device("cpu")
@@ -1546,8 +1550,10 @@ def _init_sharded_split_embedding_weights(
         if not force_regenerate and self._split_weights_res is not None:
             return

-        pmt_list, weight_ids_list, bucket_cnt_list = self.split_embedding_weights(
-            no_snapshot=False,
+        pmt_list, weight_ids_list, bucket_cnt_list, metadata_list = (
+            self.split_embedding_weights(
+                no_snapshot=False,
+            )
         )
         emb_table_config_copy = copy.deepcopy(self._config.embedding_tables)
         for emb_table in emb_table_config_copy:
@@ -1581,17 +1587,31 @@ def _init_sharded_split_embedding_weights(
             self._table_name_to_weight_count_per_rank,
             use_param_size_as_rows=True,
         )
-        # pyre-ignore
-        assert len(pmt_list) == len(weight_ids_list) == len(bucket_cnt_list)
+        metadata_sharded_t_list = create_virtual_sharded_tensors(
+            emb_table_config_copy,
+            metadata_list,  # pyre-ignore [6]
+            self._pg,
+            prefix,
+            self._table_name_to_weight_count_per_rank,
+        )
+
+        assert (
+            len(pmt_list)
+            == len(weight_ids_list)  # pyre-ignore
+            == len(bucket_cnt_list)  # pyre-ignore
+            == len(metadata_list)  # pyre-ignore
+        )
         assert (
             len(pmt_sharded_t_list)
             == len(weight_id_sharded_t_list)
             == len(bucket_cnt_sharded_t_list)
+            == len(metadata_sharded_t_list)
         )
         self._split_weights_res = (
             pmt_sharded_t_list,
             weight_id_sharded_t_list,
             bucket_cnt_sharded_t_list,
+            metadata_sharded_t_list,
         )

     def get_named_split_embedding_weights_snapshot(self, prefix: str = "") -> Iterator[
@@ -1600,6 +1620,7 @@ def get_named_split_embedding_weights_snapshot(self, prefix: str = "") -> Iterat
             Union[ShardedTensor, PartiallyMaterializedTensor],
             Optional[ShardedTensor],
             Optional[ShardedTensor],
+            Optional[ShardedTensor],
         ]
     ]:
         """
@@ -1608,6 +1629,7 @@ def get_named_split_embedding_weights_snapshot(self, prefix: str = "") -> Iterat
         PMT for embedding table with a valid RocksDB snapshot to support tensor IO
         optional ShardedTensor for weight_id
         optional ShardedTensor for bucket_cnt
+        optional ShardedTensor for metadata
         """
         self._init_sharded_split_embedding_weights()
         # pyre-ignore[16]
@@ -1616,13 +1638,14 @@ def get_named_split_embedding_weights_snapshot(self, prefix: str = "") -> Iterat
         pmt_sharded_t_list = self._split_weights_res[0]
         weight_id_sharded_t_list = self._split_weights_res[1]
         bucket_cnt_sharded_t_list = self._split_weights_res[2]
+        metadata_sharded_t_list = self._split_weights_res[3]
         for table_idx, pmt_sharded_t in enumerate(pmt_sharded_t_list):
             table_config = self._config.embedding_tables[table_idx]
             key = append_prefix(prefix, f"{table_config.name}")

             yield key, pmt_sharded_t, weight_id_sharded_t_list[
                 table_idx
-            ], bucket_cnt_sharded_t_list[table_idx]
+            ], bucket_cnt_sharded_t_list[table_idx], metadata_sharded_t_list[table_idx]

     def flush(self) -> None:
         """
@@ -1651,6 +1674,7 @@ def split_embedding_weights(
         Union[List[PartiallyMaterializedTensor], List[torch.Tensor]],
         Optional[List[torch.Tensor]],
         Optional[List[torch.Tensor]],
+        Optional[List[torch.Tensor]],
     ]:
         return self.emb_module.split_embedding_weights(no_snapshot, should_flush)

@@ -2079,7 +2103,7 @@ def state_dict(

         # in the case no_snapshot=False, a flush is required. we rely on the flush operation in
         # ShardedEmbeddingBagCollection._pre_state_dict_hook()
-        emb_tables, _, _ = self.split_embedding_weights(no_snapshot=no_snapshot)
+        emb_tables, _, _, _ = self.split_embedding_weights(no_snapshot=no_snapshot)
         emb_table_config_copy = copy.deepcopy(self._config.embedding_tables)
         for emb_table in emb_table_config_copy:
             emb_table.local_metadata.placement._device = torch.device("cpu")
@@ -2129,6 +2153,7 @@ def get_named_split_embedding_weights_snapshot(self, prefix: str = "") -> Iterat
             Union[ShardedTensor, PartiallyMaterializedTensor],
             Optional[ShardedTensor],
             Optional[ShardedTensor],
+            Optional[ShardedTensor],
         ]
     ]:
         """
@@ -2137,13 +2162,14 @@ def get_named_split_embedding_weights_snapshot(self, prefix: str = "") -> Iterat
         RocksDB snapshot to support windowed access.
         optional ShardedTensor for weight_id, this won't be used here as this is non-kvzch
         optional ShardedTensor for bucket_cnt, this won't be used here as this is non-kvzch
+        optional ShardedTensor for metadata, this won't be used here as this is non-kvzch
         """
         for config, tensor in zip(
             self._config.embedding_tables,
             self.split_embedding_weights(no_snapshot=False)[0],
         ):
             key = append_prefix(prefix, f"{config.name}")
-            yield key, tensor, None, None
+            yield key, tensor, None, None, None

     def flush(self) -> None:
         """
@@ -2170,6 +2196,7 @@ def split_embedding_weights(self, no_snapshot: bool = True) -> Tuple[
         List[PartiallyMaterializedTensor],
         Optional[List[torch.Tensor]],
         Optional[List[torch.Tensor]],
+        Optional[List[torch.Tensor]],
     ]:
         # pyre-fixme[7]: Expected `Tuple[List[PartiallyMaterializedTensor],
         #  Optional[List[Tensor]], Optional[List[Tensor]]]` but got
@@ -2223,6 +2250,7 @@ def __init__(
                 List[ShardedTensor],
                 List[ShardedTensor],
                 List[ShardedTensor],
+                List[ShardedTensor],
             ]
         ] = None

@@ -2298,7 +2326,7 @@ def state_dict(

         # in the case no_snapshot=False, a flush is required. we rely on the flush operation in
         # ShardedEmbeddingBagCollection._pre_state_dict_hook()
-        emb_tables, _, _ = self.split_embedding_weights(no_snapshot=no_snapshot)
+        emb_tables, _, _, _ = self.split_embedding_weights(no_snapshot=no_snapshot)
         emb_table_config_copy = copy.deepcopy(self._config.embedding_tables)
         for emb_table in emb_table_config_copy:
             emb_table.local_metadata.placement._device = torch.device("cpu")
@@ -2354,8 +2382,10 @@ def _init_sharded_split_embedding_weights(
         if not force_regenerate and self._split_weights_res is not None:
             return

-        pmt_list, weight_ids_list, bucket_cnt_list = self.split_embedding_weights(
-            no_snapshot=False,
+        pmt_list, weight_ids_list, bucket_cnt_list, metadata_list = (
+            self.split_embedding_weights(
+                no_snapshot=False,
+            )
         )
         emb_table_config_copy = copy.deepcopy(self._config.embedding_tables)
         for emb_table in emb_table_config_copy:
@@ -2389,17 +2419,31 @@ def _init_sharded_split_embedding_weights(
             self._table_name_to_weight_count_per_rank,
             use_param_size_as_rows=True,
         )
-        # pyre-ignore
-        assert len(pmt_list) == len(weight_ids_list) == len(bucket_cnt_list)
+        metadata_sharded_t_list = create_virtual_sharded_tensors(
+            emb_table_config_copy,
+            metadata_list,  # pyre-ignore [6]
+            self._pg,
+            prefix,
+            self._table_name_to_weight_count_per_rank,
+        )
+
+        assert (
+            len(pmt_list)
+            == len(weight_ids_list)  # pyre-ignore
+            == len(bucket_cnt_list)  # pyre-ignore
+            == len(metadata_list)  # pyre-ignore
+        )
         assert (
             len(pmt_sharded_t_list)
             == len(weight_id_sharded_t_list)
             == len(bucket_cnt_sharded_t_list)
+            == len(metadata_sharded_t_list)
         )
         self._split_weights_res = (
             pmt_sharded_t_list,
             weight_id_sharded_t_list,
             bucket_cnt_sharded_t_list,
+            metadata_sharded_t_list,
         )

     def get_named_split_embedding_weights_snapshot(self, prefix: str = "") -> Iterator[
@@ -2408,6 +2452,7 @@ def get_named_split_embedding_weights_snapshot(self, prefix: str = "") -> Iterat
             Union[ShardedTensor, PartiallyMaterializedTensor],
             Optional[ShardedTensor],
             Optional[ShardedTensor],
+            Optional[ShardedTensor],
         ]
     ]:
         """
@@ -2416,6 +2461,7 @@ def get_named_split_embedding_weights_snapshot(self, prefix: str = "") -> Iterat
         PMT for embedding table with a valid RocksDB snapshot to support tensor IO
         optional ShardedTensor for weight_id
         optional ShardedTensor for bucket_cnt
+        optional ShardedTensor for metadata
         """
         self._init_sharded_split_embedding_weights()
         # pyre-ignore[16]
@@ -2424,13 +2470,14 @@ def get_named_split_embedding_weights_snapshot(self, prefix: str = "") -> Iterat
         pmt_sharded_t_list = self._split_weights_res[0]
         weight_id_sharded_t_list = self._split_weights_res[1]
         bucket_cnt_sharded_t_list = self._split_weights_res[2]
+        metadata_sharded_t_list = self._split_weights_res[3]
         for table_idx, pmt_sharded_t in enumerate(pmt_sharded_t_list):
             table_config = self._config.embedding_tables[table_idx]
             key = append_prefix(prefix, f"{table_config.name}")

             yield key, pmt_sharded_t, weight_id_sharded_t_list[
                 table_idx
-            ], bucket_cnt_sharded_t_list[table_idx]
+            ], bucket_cnt_sharded_t_list[table_idx], metadata_sharded_t_list[table_idx]

     def flush(self) -> None:
         """
@@ -2459,6 +2506,7 @@ def split_embedding_weights(
         Union[List[PartiallyMaterializedTensor], List[torch.Tensor]],
         Optional[List[torch.Tensor]],
         Optional[List[torch.Tensor]],
+        Optional[List[torch.Tensor]],
     ]:
         return self.emb_module.split_embedding_weights(no_snapshot, should_flush)

diff --git a/torchrec/distributed/embedding.py b/torchrec/distributed/embedding.py
index ef6a67098..dbb74459f 100644
--- a/torchrec/distributed/embedding.py
+++ b/torchrec/distributed/embedding.py
@@ -698,10 +698,13 @@ def _pre_load_state_dict_hook(
                 weight_key = f"{prefix}embeddings.{table_name}.weight"
                 weight_id_key = f"{prefix}embeddings.{table_name}.weight_id"
                 bucket_key = f"{prefix}embeddings.{table_name}.bucket"
+                metadata_key = f"{prefix}embeddings.{table_name}.metadata"
                 if weight_id_key in state_dict:
                     del state_dict[weight_id_key]
                 if bucket_key in state_dict:
                     del state_dict[bucket_key]
+                if metadata_key in state_dict:
+                    del state_dict[metadata_key]
                 assert weight_key in state_dict
                 assert (
                     len(self._model_parallel_name_to_local_shards[table_name]) == 1
@@ -1037,6 +1040,7 @@ def post_state_dict_hook(
                         weights_t,
                         weight_ids_sharded_t,
                         id_cnt_per_bucket_sharded_t,
+                        metadata_sharded_t,
                     ) in (
                         lookup.get_named_split_embedding_weights_snapshot()  # pyre-ignore
                     ):
@@ -1048,6 +1052,7 @@ def post_state_dict_hook(
                             assert (
                                 weight_ids_sharded_t is not None
                                 and id_cnt_per_bucket_sharded_t is not None
+                                and metadata_sharded_t is not None
                             )
                             # The logic here assumes there is only one shard per table on any particular rank
                             # if there are cases each rank has >1 shards, we need to update here accordingly
@@ -1055,12 +1060,14 @@ def post_state_dict_hook(
                             virtual_table_sharded_t_map[table_name] = (
                                 weight_ids_sharded_t,
                                 id_cnt_per_bucket_sharded_t,
+                                metadata_sharded_t,
                             )
                         else:
                             assert isinstance(weights_t, PartiallyMaterializedTensor)
                             assert (
                                 weight_ids_sharded_t is None
                                 and id_cnt_per_bucket_sharded_t is None
+                                and metadata_sharded_t is None
                             )
                             # The logic here assumes there is only one shard per table on any particular rank
                             # if there are cases each rank has >1 shards, we need to update here accordingly
@@ -1099,6 +1106,12 @@ def update_destination(
                         destination,
                         virtual_table_sharded_t_map[table_name][1],
                     )
+                    update_destination(
+                        table_name,
+                        "metadata",
+                        destination,
+                        virtual_table_sharded_t_map[table_name][2],
+                    )

         def _post_load_state_dict_hook(
             module: "ShardedEmbeddingCollection",
diff --git a/torchrec/distributed/embedding_lookup.py b/torchrec/distributed/embedding_lookup.py
index 248f133ac..46052c128 100644
--- a/torchrec/distributed/embedding_lookup.py
+++ b/torchrec/distributed/embedding_lookup.py
@@ -381,6 +381,7 @@ def get_named_split_embedding_weights_snapshot(
             Union[ShardedTensor, PartiallyMaterializedTensor],
             Optional[ShardedTensor],
             Optional[ShardedTensor],
+            Optional[ShardedTensor],
         ]
     ]:
         """
@@ -732,6 +733,7 @@ def get_named_split_embedding_weights_snapshot(
             Union[ShardedTensor, PartiallyMaterializedTensor],
             Optional[ShardedTensor],
             Optional[ShardedTensor],
+            Optional[ShardedTensor],
         ]
     ]:
         """
diff --git a/torchrec/distributed/embeddingbag.py b/torchrec/distributed/embeddingbag.py
index 2a7f2fa39..1481d2bbb 100644
--- a/torchrec/distributed/embeddingbag.py
+++ b/torchrec/distributed/embeddingbag.py
@@ -820,10 +820,13 @@ def _pre_load_state_dict_hook(
                 weight_key = f"{prefix}embedding_bags.{table_name}.weight"
                 weight_id_key = f"{prefix}embedding_bags.{table_name}.weight_id"
                 bucket_key = f"{prefix}embedding_bags.{table_name}.bucket"
+                metadata_key = f"{prefix}embedding_bags.{table_name}.metadata"
                 if weight_id_key in state_dict:
                     del state_dict[weight_id_key]
                 if bucket_key in state_dict:
                     del state_dict[bucket_key]
+                if metadata_key in state_dict:
+                    del state_dict[metadata_key]
                 assert weight_key in state_dict
                 assert (
                     len(self._model_parallel_name_to_local_shards[table_name]) == 1
@@ -1196,6 +1199,7 @@ def post_state_dict_hook(
                         weights_t,
                         weight_ids_sharded_t,
                         id_cnt_per_bucket_sharded_t,
+                        metadata_sharded_t,
                     ) in (
                         lookup.get_named_split_embedding_weights_snapshot()  # pyre-ignore
                     ):
@@ -1207,6 +1211,7 @@ def post_state_dict_hook(
                             assert (
                                 weight_ids_sharded_t is not None
                                 and id_cnt_per_bucket_sharded_t is not None
+                                and metadata_sharded_t is not None
                             )
                             # The logic here assumes there is only one shard per table on any particular rank
                             # if there are cases each rank has >1 shards, we need to update here accordingly
@@ -1214,12 +1219,14 @@ def post_state_dict_hook(
                             virtual_table_sharded_t_map[table_name] = (
                                 weight_ids_sharded_t,
                                 id_cnt_per_bucket_sharded_t,
+                                metadata_sharded_t,
                             )
                         else:
                             assert isinstance(weights_t, PartiallyMaterializedTensor)
                             assert (
                                 weight_ids_sharded_t is None
                                 and id_cnt_per_bucket_sharded_t is None
+                                and metadata_sharded_t is None
                             )
                             # The logic here assumes there is only one shard per table on any particular rank
                             # if there are cases each rank has >1 shards, we need to update here accordingly
@@ -1258,6 +1265,12 @@ def update_destination(
                         destination,
                         virtual_table_sharded_t_map[table_name][1],
                     )
+                    update_destination(
+                        table_name,
+                        "metadata",
+                        destination,
+                        virtual_table_sharded_t_map[table_name][2],
+                    )

         def _post_load_state_dict_hook(
             module: "ShardedEmbeddingBagCollection",
diff --git a/torchrec/distributed/tests/test_model_parallel_nccl_ssd_single_gpu.py b/torchrec/distributed/tests/test_model_parallel_nccl_ssd_single_gpu.py
index 9662e146c..2d253cb47 100644
--- a/torchrec/distributed/tests/test_model_parallel_nccl_ssd_single_gpu.py
+++ b/torchrec/distributed/tests/test_model_parallel_nccl_ssd_single_gpu.py
@@ -120,13 +120,14 @@ def _copy_ssd_emb_modules(

            emb1_kv = {
                t: pmt
-                for t, pmt, _, _ in emb_module1.get_named_split_embedding_weights_snapshot()
+                for t, pmt, _, _, _ in emb_module1.get_named_split_embedding_weights_snapshot()
            }
            for (
                t,
                pmt,
                _,
                _,
+                _,
            ) in emb_module2.get_named_split_embedding_weights_snapshot():
                pmt1 = emb1_kv[t]
                w1 = pmt1.full_tensor()
@@ -760,13 +761,14 @@ def _copy_ssd_emb_modules(

            emb1_kv = {
                t: pmt
-                for t, pmt, _, _ in emb_module1.get_named_split_embedding_weights_snapshot()
+                for t, pmt, _, _, _ in emb_module1.get_named_split_embedding_weights_snapshot()
            }
            for (
                t,
                pmt,
                _,
                _,
+                _,
            ) in emb_module2.get_named_split_embedding_weights_snapshot():
                pmt1 = emb1_kv[t]
                w1 = pmt1.full_tensor()
@@ -900,14 +902,15 @@ def _copy_ssd_emb_modules(
            emb_module2.flush()

            emb1_kv = {
-                t: (sharded_t, sharded_w_id, bucket)
-                for t, sharded_t, sharded_w_id, bucket in emb_module1.get_named_split_embedding_weights_snapshot()
+                t: (sharded_t, sharded_w_id, bucket, metadata)
+                for t, sharded_t, sharded_w_id, bucket, metadata in emb_module1.get_named_split_embedding_weights_snapshot()
            }
            for (
                t,
                sharded_t2,
                _,
                _,
+                _,
            ) in emb_module2.get_named_split_embedding_weights_snapshot():
                assert t in emb1_kv
                sharded_t1 = emb1_kv[t][0]
@@ -955,6 +958,7 @@ def _copy_fused_modules_into_ssd_emb_modules(
                sharded_t,
                _,
                _,
+                _,
            ) in ssd_emb_module.get_named_split_embedding_weights_snapshot():
                weight_key = f"{t}.weight"
                fused_sharded_t = fused_state_dict[weight_key]
@@ -1367,14 +1371,15 @@ def _copy_ssd_emb_modules(
            emb_module2.flush()

            emb1_kv = {
-                t: (sharded_t, sharded_w_id, bucket)
-                for t, sharded_t, sharded_w_id, bucket in emb_module1.get_named_split_embedding_weights_snapshot()
+                t: (sharded_t, sharded_w_id, bucket, metadata)
+                for t, sharded_t, sharded_w_id, bucket, metadata in emb_module1.get_named_split_embedding_weights_snapshot()
            }
            for (
                t,
                sharded_t2,
                _,
                _,
+                _,
            ) in emb_module2.get_named_split_embedding_weights_snapshot():
                assert t in emb1_kv
                sharded_t1 = emb1_kv[t][0]
@@ -1423,6 +1428,7 @@ def _copy_fused_modules_into_ssd_emb_modules(
                sharded_t,
                _,
                _,
+                _,
            ) in ssd_emb_module.get_named_split_embedding_weights_snapshot():
                weight_key = f"{t}.weight"
                fused_sharded_t = fused_state_dict[weight_key]
diff --git a/torchrec/modules/embedding_configs.py b/torchrec/modules/embedding_configs.py
index d477b6b26..90d24e85d 100644
--- a/torchrec/modules/embedding_configs.py
+++ b/torchrec/modules/embedding_configs.py
@@ -194,6 +194,9 @@ class CountBasedEvictionPolicy(VirtualTableEvictionPolicy):
         15  # eviction threshold for count based eviction policy. 0 means no eviction
     )
     decay_rate: float = 0.99  # default decay by default
+    inference_eviction_threshold: int = (
+        eviction_threshold  # eviction threshold for inference count based eviction policy. 0 means no eviction
+    )


 @dataclass
@@ -203,6 +206,7 @@ class TimestampBasedEvictionPolicy(VirtualTableEvictionPolicy):
     """

     eviction_ttl_mins: int = 24 * 60  # 1 day. 0 means no eviction
+    inference_eviction_ttl_mins: int = eviction_ttl_mins  # 0 means no eviction


 @dataclass
@@ -216,6 +220,13 @@ class CountTimestampMixedEvictionPolicy(VirtualTableEvictionPolicy):
     )
     decay_rate: float = 0.99  # default decay by default
     eviction_ttl_mins: int = 24 * 60  # 1 day. 0 means no eviction based on timestamp
+    inference_eviction_threshold: int = (
+        eviction_threshold  # eviction threshold for inference count based eviction policy. 0 means no eviction based on count
+    )
+
+    inference_eviction_ttl_mins: int = (
+        eviction_ttl_mins  # 0 means no eviction based on timestamp
+    )


 @dataclass
@@ -227,6 +238,7 @@ class FeatureL2NormBasedEvictionPolicy(VirtualTableEvictionPolicy):
     eviction_threshold: float = (
         0.0  # eviction threshold for feature l2 norm based eviction policy. 0.0 means no eviction
     )
+    inference_eviction_threshold: float = eviction_threshold


 @dataclass
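
Reviewer note (illustrative, not part of the patch): with this change, split_embedding_weights() returns a 4-tuple and get_named_split_embedding_weights_snapshot() yields 5-tuples, so every caller must unpack one extra element for the optional metadata tensors. A minimal sketch of the updated calling convention, assuming a kernel object named emb_kernel that exposes the two APIs touched above; the unpacked variable names are placeholders:

    # Sketch only: emb_kernel stands in for any of the SSD/KV embedding kernels
    # updated in this diff.
    pmts, weight_ids, bucket_cnts, metadata = emb_kernel.split_embedding_weights(
        no_snapshot=False,
    )

    # The snapshot iterator now yields 5-tuples; per the docstrings above, the
    # metadata ShardedTensor is None for non-kvzch tables and populated for
    # virtual (kvzch) tables.
    for (
        key,
        weights_t,
        weight_ids_sharded_t,
        bucket_cnt_sharded_t,
        metadata_sharded_t,
    ) in emb_kernel.get_named_split_embedding_weights_snapshot():
        ...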
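
Reviewer note (illustrative, not part of the patch): in the embedding_configs.py changes, the new inference_* fields take their defaults from the class-level training values, which are evaluated once at class-definition time. Constructing a policy with a custom eviction_threshold therefore does not move inference_eviction_threshold along with it; it must be overridden explicitly when the two should differ. A standalone Python sketch of that dataclass behavior (it mirrors the pattern only and is not torchrec code):

    from dataclasses import dataclass

    @dataclass
    class _ExamplePolicy:
        # Same default-binding pattern as CountBasedEvictionPolicy in this diff.
        eviction_threshold: int = 15
        inference_eviction_threshold: int = eviction_threshold  # captured as 15

    p = _ExamplePolicy(eviction_threshold=10)
    assert p.eviction_threshold == 10
    assert p.inference_eviction_threshold == 15  # unchanged unless set explicitly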