diff --git a/torchrec/distributed/planner/shard_estimators.py b/torchrec/distributed/planner/shard_estimators.py
index dcb2dfc33..e86ac770e 100644
--- a/torchrec/distributed/planner/shard_estimators.py
+++ b/torchrec/distributed/planner/shard_estimators.py
@@ -1192,16 +1192,9 @@ def calculate_shard_storages(
     hbm_storage: int = tensor_storage.get("hbm", 0)
     ddr_storage: int = tensor_storage.get("ddr", 0)

-    table_cached: bool = False
-    if compute_kernel in {
-        EmbeddingComputeKernel.FUSED_UVM_CACHING.value,
-        EmbeddingComputeKernel.QUANT_UVM_CACHING.value,
-        EmbeddingComputeKernel.KEY_VALUE.value,
-        EmbeddingComputeKernel.SSD_VIRTUAL_TABLE.value,
-        EmbeddingComputeKernel.DRAM_VIRTUAL_TABLE.value,
-    }:
+    table_cached = _is_table_cached(compute_kernel)
+    if table_cached:
         hbm_storage = round(ddr_storage * caching_ratio)
-        table_cached = True

     optimizer_class = getattr(tensor, "_optimizer_classes", [None])[0]

@@ -1293,6 +1286,20 @@ def calculate_shard_storages(
     ]


+def _is_table_cached(
+    compute_kernel: str,
+) -> bool:
+    if compute_kernel in {
+        EmbeddingComputeKernel.FUSED_UVM_CACHING.value,
+        EmbeddingComputeKernel.QUANT_UVM_CACHING.value,
+        EmbeddingComputeKernel.KEY_VALUE.value,
+        EmbeddingComputeKernel.SSD_VIRTUAL_TABLE.value,
+        EmbeddingComputeKernel.DRAM_VIRTUAL_TABLE.value,
+    }:
+        return True
+    return False
+
+
 def _calculate_shard_io_sizes(
     sharding_type: str,
     batch_sizes: List[int],
@@ -1554,27 +1561,20 @@ def _calculate_storage_specific_sizes(
     is_inference: bool = False,
     clf: Optional[float] = None,
 ) -> List[int]:
-    tensor_sizes: List[int] = [
-        (
-            math.ceil(storage * prod(size) / prod(shape))
-            if sharding_type != ShardingType.DATA_PARALLEL.value
-            else storage
-        )
-        for size in shard_sizes
-    ]
-    optimizer_multipler: float = _get_optimizer_multipler(optimizer_class, shape)
-
-    optimizer_sizes: List[int] = [
-        math.ceil(tensor_size * optimizer_multipler) for tensor_size in tensor_sizes
-    ]
-
-    # If a table has turned on UVM caching (meaning clf is not None), there'll be
-    # 4x of table hash size and 16x of cache slot size HBM storage cost dedicated to
-    # cache aux state (note that this is not the cache content itself)
-    cache_aux_state_sizes: List[int] = (
-        [0] * len(shard_sizes)
-        if clf is None
-        else [math.ceil(size[0] * (4 + clf * 16)) for size in shard_sizes]
+    tensor_sizes: List[int] = _calculate_tensor_sizes(
+        storage,
+        shape,
+        shard_sizes,
+        sharding_type,
+    )
+    optimizer_sizes = _calculate_optimizer_sizes(
+        tensor_sizes,
+        optimizer_class,
+        shape,
+    )
+    cache_aux_state_sizes: List[int] = _calculate_cache_aux_state_sizes(
+        shard_sizes,
+        clf,
     )

     return [
@@ -1589,6 +1589,45 @@ def _calculate_storage_specific_sizes(
     ]


+def _calculate_tensor_sizes(
+    storage: int, shape: torch.Size, shard_sizes: List[List[int]], sharding_type: str
+) -> List[int]:
+    return [
+        (
+            math.ceil(storage * prod(size) / prod(shape))
+            if sharding_type != ShardingType.DATA_PARALLEL.value
+            else storage
+        )
+        for size in shard_sizes
+    ]
+
+
+# If a table has turned on UVM caching (meaning clf is not None), there'll be
+# 4x of table hash size and 16x of cache slot size HBM storage cost dedicated to
+# cache aux state (note that this is not the cache content itself)
+def _calculate_cache_aux_state_sizes(
+    shard_sizes: List[List[int]], clf: Optional[float]
+) -> List[int]:
+    if clf is None:
+        return [0] * len(shard_sizes)
+    return [math.ceil(size[0] * (4 + clf * 16)) for size in shard_sizes]
+
+
+def _calculate_optimizer_sizes(
+    tensor_sizes: List[int],
+    optimizer_class: Optional[Type[torch.optim.Optimizer]],
+    sharding_tensor_shape: torch.Size,
+) -> List[int]:
+    optimizer_multiplier: float = _get_optimizer_multipler(
+        optimizer_class,
+        sharding_tensor_shape,
+    )
+    optimizer_sizes: List[int] = [
+        math.ceil(tensor_size * optimizer_multiplier) for tensor_size in tensor_sizes
+    ]
+    return optimizer_sizes
+
+
 def _get_optimizer_multipler(
     optimizer_class: Optional[Type[torch.optim.Optimizer]],
     shape: torch.Size,
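
For context, below is a minimal standalone sketch of the cache aux state accounting that `_calculate_cache_aux_state_sizes` factors out. The formula (4 bytes per hashed row plus 16 bytes per cache slot, where the slot count is `clf * hash_size`) comes from the diff; the function name `cache_aux_state_bytes` and the concrete hash size and cache load factor are illustrative only, not part of the change.

```python
import math


def cache_aux_state_bytes(hash_size: int, clf: float) -> int:
    # Mirrors the per-shard expression in _calculate_cache_aux_state_sizes:
    # 4 bytes of aux state per hashed row plus 16 bytes per cache slot
    # (clf * hash_size slots). This is bookkeeping overhead only, not the
    # cached embedding rows themselves.
    return math.ceil(hash_size * (4 + clf * 16))


# Hypothetical shard: 1M hashed rows with a 0.2 cache load factor.
print(cache_aux_state_bytes(1_000_000, 0.2))  # 7_200_000 bytes, ~6.9 MiB of HBM
```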