Fix code quality score #3255

Open · wants to merge 1 commit into main
137 changes: 101 additions & 36 deletions torchrec/modules/hash_mc_modules.py
@@ -6,6 +6,7 @@
# LICENSE file in the root directory of this source tree.

# pyre-strict
import dataclasses
import logging
import math
from typing import Any, Dict, Iterator, List, Optional, Tuple
@@ -41,6 +42,21 @@ def _tensor_may_to_device(
return (src, src_device)


@dataclasses.dataclass
class TrainInputMapperConfig:
"""
Configuration for TrainInputMapper.

Args:
inference_dispatch_div_train_world_size: whether to divide input IDs by the training
world_size at inference dispatch (see https://fburl.com/code/c9x98073)
name: the name of the embedding table
"""

inference_dispatch_div_train_world_size: bool = False
name: Optional[str] = None


class TrainInputMapper(torch.nn.Module):
"""
Module used to generate sizes and offsets information corresponding to
@@ -54,12 +70,13 @@ class TrainInputMapper(torch.nn.Module):
total_num_buckets: the total number of buckets across all ranks at training time
size_per_rank: the size of the identity tensor/embedding size per rank
train_rank_offsets: the offset of the embedding table indices per rank
inference_dispatch_div_train_world_size: the flag to control whether to divide input by
world_size https://fburl.com/code/c9x98073
name: the name of the embedding table
config: optional TrainInputMapperConfig carrying the remaining parameters

Example::
mapper = TrainInputMapper(...)
config = TrainInputMapperConfig(inference_dispatch_div_train_world_size=False)
mapper = TrainInputMapper(input_hash_size=1024, total_num_buckets=8,
size_per_rank=size_tensor, train_rank_offsets=offset_tensor,
config=config)
mapper(values, output_offset)
"""

@@ -69,18 +86,26 @@ def __init__(
total_num_buckets: int,
size_per_rank: torch.Tensor,
train_rank_offsets: torch.Tensor,
inference_dispatch_div_train_world_size: bool = False,
name: Optional[str] = None,
config: Optional[TrainInputMapperConfig] = None,
) -> None:
# Initialize config with defaults if not provided
if config is None:
config = TrainInputMapperConfig()

# Extract values from config
inference_dispatch_div_train_world_size = (
config.inference_dispatch_div_train_world_size
)
name = config.name
super().__init__()

self._input_hash_size = input_hash_size
assert total_num_buckets > 0, f"{total_num_buckets=} must be positive"
self._buckets = total_num_buckets
self._inference_dispatch_div_train_world_size = (
self._inference_dispatch_div_train_world_size: bool = (
inference_dispatch_div_train_world_size
)
self._name = name
self._name: Optional[str] = name
self.register_buffer(
"_zch_size_per_training_rank", size_per_rank, persistent=False
)
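
For call sites, the migration is mechanical: the two trailing keyword arguments move into a TrainInputMapperConfig. A minimal before/after sketch; the sizes, tensors, and table name below are illustrative, not taken from this PR:

# before: flag and name passed directly to the module
mapper = TrainInputMapper(
    input_hash_size=1024,
    total_num_buckets=8,
    size_per_rank=size_per_rank,      # assumed precomputed int64 tensor of per-rank sizes
    train_rank_offsets=rank_offsets,  # assumed exclusive cumsum of size_per_rank
    inference_dispatch_div_train_world_size=False,
    name="user_id_table",             # hypothetical table name
)

# after: the same two values grouped into the config object
mapper = TrainInputMapper(
    input_hash_size=1024,
    total_num_buckets=8,
    size_per_rank=size_per_rank,
    train_rank_offsets=rank_offsets,
    config=TrainInputMapperConfig(
        inference_dispatch_div_train_world_size=False,
        name="user_id_table",
    ),
)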
@@ -152,15 +177,12 @@ def _get_device(hash_zch_identities: torch.Tensor) -> torch.device:
return hash_zch_identities.device


class HashZchManagedCollisionModule(ManagedCollisionModule):
@dataclasses.dataclass
class HashZchConfig:
"""
Module to manage multi-probe ZCH (MPZCH), including lookup (remapping), eviction, metrics collection, and required auxiliary tensors.
Configuration for HashZchManagedCollisionModule.

Args:
zch_size: local size of the embedding table
device: the compute device
total_num_buckets: logical shard within each rank for resharding purpose, note that
1) zch_size must be a multiple of total_num_buckets, and 2) total_num_buckets must be a multiple of world size
max_probe: the number of times MPZCH kernel attempts to run linear search for lookup or insertion
input_hash_size: the max size of input IDs (default to 0)
output_segments: the index range of each bucket, which is computed before sharding and typically not provided by user
Expand All @@ -175,9 +197,37 @@ class HashZchManagedCollisionModule(ManagedCollisionModule):
end_bucket: end bucket of the current rank, typically not provided by user
opt_in_prob: the statistical probability that an incoming ID is opted in
percent_reserved_slots: percentage of slots reserved when opt-in is enabled; the value must be in [0, 100)
"""

max_probe: int = 128
input_hash_size: int = 0
output_segments: Optional[List[int]] = None
is_inference: bool = False
name: Optional[str] = None
tb_logging_frequency: int = 0
eviction_policy_name: Optional[HashZchEvictionPolicyName] = None
eviction_config: Optional[HashZchEvictionConfig] = None
inference_dispatch_div_train_world_size: bool = False
start_bucket: int = 0
end_bucket: Optional[int] = None
opt_in_prob: int = -1
percent_reserved_slots: float = 0.0
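
As a rough illustration of the two opt-in fields: when opt_in_prob is non-negative, a share of the table is held back as reserved slots. This is a back-of-envelope sketch only; the kernel's exact rounding and per-bucket accounting may differ:

zch_size = 4096
percent_reserved_slots = 10.0
reserved = int(zch_size * percent_reserved_slots / 100)  # roughly 409 slots held back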


class HashZchManagedCollisionModule(ManagedCollisionModule):
"""
Module to manage multi-probe ZCH (MPZCH), including lookup (remapping), eviction, metrics collection, and required auxiliary tensors.

Args:
zch_size: local size of the embedding table
device: the compute device
total_num_buckets: the number of logical shards within each rank, used for resharding; note that
1) zch_size must be a multiple of total_num_buckets, and 2) total_num_buckets must be a multiple of the world size
config: optional HashZchConfig carrying the remaining parameters

Example::
module = HashZchManagedCollisionModule(...)
config = HashZchConfig(max_probe=128, input_hash_size=0)
module = HashZchManagedCollisionModule(zch_size=1024, device=torch.device("cuda"), total_num_buckets=8, config=config)
module(features)
"""

@@ -198,20 +248,28 @@ def __init__(
zch_size: int,
device: torch.device,
total_num_buckets: int,
max_probe: int = 128,
input_hash_size: int = 0,
output_segments: Optional[List[int]] = None,
is_inference: bool = False,
name: Optional[str] = None,
tb_logging_frequency: int = 0,
eviction_policy_name: Optional[HashZchEvictionPolicyName] = None,
eviction_config: Optional[HashZchEvictionConfig] = None,
inference_dispatch_div_train_world_size: bool = False,
start_bucket: int = 0,
end_bucket: Optional[int] = None,
opt_in_prob: int = -1,
percent_reserved_slots: float = 0,
config: Optional[HashZchConfig] = None,
) -> None:
# Initialize config with defaults if not provided
if config is None:
config = HashZchConfig()

# Extract values from config
max_probe = config.max_probe
input_hash_size = config.input_hash_size
output_segments = config.output_segments
is_inference = config.is_inference
name = config.name
tb_logging_frequency = config.tb_logging_frequency
eviction_policy_name = config.eviction_policy_name
eviction_config = config.eviction_config
inference_dispatch_div_train_world_size = (
config.inference_dispatch_div_train_world_size
)
start_bucket = config.start_bucket
end_bucket = config.end_bucket
opt_in_prob = config.opt_in_prob
percent_reserved_slots = config.percent_reserved_slots
if output_segments is None:
assert (
zch_size % total_num_buckets == 0
@@ -300,7 +358,7 @@ def __init__(
self._hash_zch_identities = torch.nn.Parameter(identities, requires_grad=False)
self.register_buffer(HashZchManagedCollisionModule.METADATA_BUFFER, metadata)

self._max_probe = max_probe
self._max_probe: int = max_probe
self._buckets = total_num_buckets
# No need to store this in a buffer since it is created and consumed
# at each step https://fburl.com/code/axzimmbx
@@ -311,16 +369,19 @@ def __init__(
torch.tensor(self._output_segments, dtype=torch.int64)
)

mapper_config = TrainInputMapperConfig(
inference_dispatch_div_train_world_size=inference_dispatch_div_train_world_size,
name=self._name,
)

self.input_mapper: torch.nn.Module = TrainInputMapper(
input_hash_size=self._input_hash_size,
total_num_buckets=total_num_buckets,
size_per_rank=size_per_rank,
train_rank_offsets=torch.tensor(
torch.ops.fbgemm.asynchronous_exclusive_cumsum(size_per_rank)
),
# be consistent with https://fburl.com/code/p4mj4mc1
inference_dispatch_div_train_world_size=inference_dispatch_div_train_world_size,
name=self._name,
config=mapper_config,
)

if self._is_inference is True:
@@ -607,11 +668,8 @@ def rebuild_with_output_id_range(
)
new_zch_size = output_id_range[1] - output_id_range[0]

return self.__class__(
zch_size=new_zch_size,
device=device or self.device,
config = HashZchConfig(
max_probe=self._max_probe,
total_num_buckets=self._buckets,
input_hash_size=self._input_hash_size,
is_inference=self._is_inference,
start_bucket=start_idx,
Expand All @@ -625,6 +683,13 @@ def rebuild_with_output_id_range(
percent_reserved_slots=self._percent_reserved_slots,
)

return self.__class__(
zch_size=new_zch_size,
device=device or self.device,
total_num_buckets=self._buckets,
config=config,
)
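
rebuild_with_output_id_range now assembles a HashZchConfig from the carried-over state before re-instantiating the class, so resharding callers are unaffected. A hedged sketch of such a call; the row range is illustrative:

# rebuild a shard that owns rows [0, 512) of the original table
shard = mc.rebuild_with_output_id_range(output_id_range=(0, 512))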


@torch.fx.wrap
def _append_eviction_indice(