diff --git a/torchrec/modules/hash_mc_modules.py b/torchrec/modules/hash_mc_modules.py index 30b0a8b8f..395b23e89 100644 --- a/torchrec/modules/hash_mc_modules.py +++ b/torchrec/modules/hash_mc_modules.py @@ -6,6 +6,7 @@ # LICENSE file in the root directory of this source tree. # pyre-strict +import dataclasses import logging import math from typing import Any, Dict, Iterator, List, Optional, Tuple @@ -41,6 +42,21 @@ def _tensor_may_to_device( return (src, src_device) +@dataclasses.dataclass +class TrainInputMapperConfig: + """ + Configuration for TrainInputMapper. + + Args: + inference_dispatch_div_train_world_size: the flag to control whether to divide input by + world_size https://fburl.com/code/c9x98073 + name: the name of the embedding table + """ + + inference_dispatch_div_train_world_size: bool = False + name: Optional[str] = None + + class TrainInputMapper(torch.nn.Module): """ Module used to generate sizes and offsets information corresponding to @@ -54,12 +70,13 @@ class TrainInputMapper(torch.nn.Module): total_num_buckets: the total number of buckets across all ranks at training time size_per_rank: the size of the identity tensor/embedding size per rank train_rank_offsets: the offset of the embedding table indices per rank - inference_dispatch_div_train_world_size: the flag to control whether to divide input by - world_size https://fburl.com/code/c9x98073 - name: the name of the embedding table + config: configuration object containing additional parameters Example:: - mapper = TrainInputMapper(...) + config = TrainInputMapperConfig(inference_dispatch_div_train_world_size=False) + mapper = TrainInputMapper(input_hash_size=1024, total_num_buckets=8, + size_per_rank=size_tensor, train_rank_offsets=offset_tensor, + config=config) mapper(values, output_offset) """ @@ -69,18 +86,26 @@ def __init__( total_num_buckets: int, size_per_rank: torch.Tensor, train_rank_offsets: torch.Tensor, - inference_dispatch_div_train_world_size: bool = False, - name: Optional[str] = None, + config: Optional[TrainInputMapperConfig] = None, ) -> None: + # Initialize config with defaults if not provided + if config is None: + config = TrainInputMapperConfig() + + # Extract values from config + inference_dispatch_div_train_world_size = ( + config.inference_dispatch_div_train_world_size + ) + name = config.name super().__init__() self._input_hash_size = input_hash_size assert total_num_buckets > 0, f"{total_num_buckets=} must be positive" self._buckets = total_num_buckets - self._inference_dispatch_div_train_world_size = ( + self._inference_dispatch_div_train_world_size: bool = ( inference_dispatch_div_train_world_size ) - self._name = name + self._name: Optional[str] = name self.register_buffer( "_zch_size_per_training_rank", size_per_rank, persistent=False ) @@ -152,15 +177,12 @@ def _get_device(hash_zch_identities: torch.Tensor) -> torch.device: return hash_zch_identities.device -class HashZchManagedCollisionModule(ManagedCollisionModule): +@dataclasses.dataclass +class HashZchConfig: """ - Module to manage multi-probe ZCH (MPZCH), including lookup (remapping), eviction, metrics collection, and required auxiliary tensors. + Configuration for HashZchManagedCollisionModule. Args: - zch_size: local size of the embedding table - device: the compute device - total_num_buckets: logical shard within each rank for resharding purpose, note that - 1) zch_size must be a multiple of total_num_buckets, and 2) total_num_buckets must be a multiple of world size max_probe: the number of times MPZCH kernel attempts to run linear search for lookup or insertion input_hash_size: the max size of input IDs (default to 0) output_segments: the index range of each bucket, which is computed before sharding and typically not provided by user @@ -175,9 +197,37 @@ class HashZchManagedCollisionModule(ManagedCollisionModule): end_bucket: end bucket of the current rank, typically not provided by user opt_in_prob: the probability of an ID to be opted in from a statistical aspect percent_reserved_slots: percentage of slots to be reserved when opt-in is enabled, the value must be in [0, 100) + """ + + max_probe: int = 128 + input_hash_size: int = 0 + output_segments: Optional[List[int]] = None + is_inference: bool = False + name: Optional[str] = None + tb_logging_frequency: int = 0 + eviction_policy_name: Optional[HashZchEvictionPolicyName] = None + eviction_config: Optional[HashZchEvictionConfig] = None + inference_dispatch_div_train_world_size: bool = False + start_bucket: int = 0 + end_bucket: Optional[int] = None + opt_in_prob: int = -1 + percent_reserved_slots: float = 0 + + +class HashZchManagedCollisionModule(ManagedCollisionModule): + """ + Module to manage multi-probe ZCH (MPZCH), including lookup (remapping), eviction, metrics collection, and required auxiliary tensors. + + Args: + zch_size: local size of the embedding table + device: the compute device + total_num_buckets: logical shard within each rank for resharding purpose, note that + 1) zch_size must be a multiple of total_num_buckets, and 2) total_num_buckets must be a multiple of world size + config: configuration object containing additional parameters Example:: - module = HashZchManagedCollisionModule(...) + config = HashZchConfig(max_probe=128, input_hash_size=0) + module = HashZchManagedCollisionModule(zch_size=1024, device=torch.device("cuda"), total_num_buckets=8, config=config) module(features) """ @@ -198,20 +248,28 @@ def __init__( zch_size: int, device: torch.device, total_num_buckets: int, - max_probe: int = 128, - input_hash_size: int = 0, - output_segments: Optional[List[int]] = None, - is_inference: bool = False, - name: Optional[str] = None, - tb_logging_frequency: int = 0, - eviction_policy_name: Optional[HashZchEvictionPolicyName] = None, - eviction_config: Optional[HashZchEvictionConfig] = None, - inference_dispatch_div_train_world_size: bool = False, - start_bucket: int = 0, - end_bucket: Optional[int] = None, - opt_in_prob: int = -1, - percent_reserved_slots: float = 0, + config: Optional[HashZchConfig] = None, ) -> None: + # Initialize config with defaults if not provided + if config is None: + config = HashZchConfig() + + # Extract values from config + max_probe = config.max_probe + input_hash_size = config.input_hash_size + output_segments = config.output_segments + is_inference = config.is_inference + name = config.name + tb_logging_frequency = config.tb_logging_frequency + eviction_policy_name = config.eviction_policy_name + eviction_config = config.eviction_config + inference_dispatch_div_train_world_size = ( + config.inference_dispatch_div_train_world_size + ) + start_bucket = config.start_bucket + end_bucket = config.end_bucket + opt_in_prob = config.opt_in_prob + percent_reserved_slots = config.percent_reserved_slots if output_segments is None: assert ( zch_size % total_num_buckets == 0 @@ -300,7 +358,7 @@ def __init__( self._hash_zch_identities = torch.nn.Parameter(identities, requires_grad=False) self.register_buffer(HashZchManagedCollisionModule.METADATA_BUFFER, metadata) - self._max_probe = max_probe + self._max_probe: int = max_probe self._buckets = total_num_buckets # Do not need to store in buffer since this is created and consumed # at each step https://fburl.com/code/axzimmbx @@ -311,6 +369,11 @@ def __init__( torch.tensor(self._output_segments, dtype=torch.int64) ) + mapper_config = TrainInputMapperConfig( + inference_dispatch_div_train_world_size=inference_dispatch_div_train_world_size, + name=self._name, + ) + self.input_mapper: torch.nn.Module = TrainInputMapper( input_hash_size=self._input_hash_size, total_num_buckets=total_num_buckets, @@ -318,9 +381,7 @@ def __init__( train_rank_offsets=torch.tensor( torch.ops.fbgemm.asynchronous_exclusive_cumsum(size_per_rank) ), - # be consistent with https://fburl.com/code/p4mj4mc1 - inference_dispatch_div_train_world_size=inference_dispatch_div_train_world_size, - name=self._name, + config=mapper_config, ) if self._is_inference is True: @@ -607,11 +668,8 @@ def rebuild_with_output_id_range( ) new_zch_size = output_id_range[1] - output_id_range[0] - return self.__class__( - zch_size=new_zch_size, - device=device or self.device, + config = HashZchConfig( max_probe=self._max_probe, - total_num_buckets=self._buckets, input_hash_size=self._input_hash_size, is_inference=self._is_inference, start_bucket=start_idx, @@ -625,6 +683,13 @@ def rebuild_with_output_id_range( percent_reserved_slots=self._percent_reserved_slots, ) + return self.__class__( + zch_size=new_zch_size, + device=device or self.device, + total_num_buckets=self._buckets, + config=config, + ) + @torch.fx.wrap def _append_eviction_indice(