
Commit e9661ec

Ahmed Shuaibi authored and facebook-github-bot committed
refactor: create functions for shard/tensor size calculations
Summary:
- refactor to create function for checking if table is cached
- refactor to create functions for tensor size calculations

Differential Revision: D79007077
1 parent 40d8fb0 commit e9661ec

File tree

1 file changed: +53 -25 lines changed


torchrec/distributed/planner/shard_estimators.py

Lines changed: 53 additions & 25 deletions
@@ -1192,16 +1192,9 @@ def calculate_shard_storages(
     hbm_storage: int = tensor_storage.get("hbm", 0)
     ddr_storage: int = tensor_storage.get("ddr", 0)
 
-    table_cached: bool = False
-    if compute_kernel in {
-        EmbeddingComputeKernel.FUSED_UVM_CACHING.value,
-        EmbeddingComputeKernel.QUANT_UVM_CACHING.value,
-        EmbeddingComputeKernel.KEY_VALUE.value,
-        EmbeddingComputeKernel.SSD_VIRTUAL_TABLE.value,
-        EmbeddingComputeKernel.DRAM_VIRTUAL_TABLE.value,
-    }:
+    table_cached = _is_table_cached(compute_kernel)
+    if table_cached:
         hbm_storage = round(ddr_storage * caching_ratio)
-        table_cached = True
 
     optimizer_class = getattr(tensor, "_optimizer_classes", [None])[0]
 
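For context on the adjusted branch: when _is_table_cached returns True, the full table is budgeted in DDR and only a caching_ratio fraction of that footprint is budgeted in HBM. A minimal sketch of that adjustment, with hypothetical numbers not taken from the diff:

    ddr_storage = 1_000_000   # full table footprint in DDR, in bytes (hypothetical)
    caching_ratio = 0.2       # fraction of the table resident in the HBM cache
    hbm_storage = round(ddr_storage * caching_ratio)  # -> 200_000 bytes of HBM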
@@ -1293,6 +1286,20 @@ def calculate_shard_storages(
     ]
 
 
+def _is_table_cached(
+    compute_kernel: str,
+) -> bool:
+    if compute_kernel in {
+        EmbeddingComputeKernel.FUSED_UVM_CACHING.value,
+        EmbeddingComputeKernel.QUANT_UVM_CACHING.value,
+        EmbeddingComputeKernel.KEY_VALUE.value,
+        EmbeddingComputeKernel.SSD_VIRTUAL_TABLE.value,
+        EmbeddingComputeKernel.DRAM_VIRTUAL_TABLE.value,
+    }:
+        return True
+    return False
+
+
 def _calculate_shard_io_sizes(
     sharding_type: str,
     batch_sizes: List[int],
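A quick usage sketch of the new helper, assuming EmbeddingComputeKernel is importable from torchrec.distributed.embedding_types and that the module-private helper can be imported directly (FUSED is a non-cached kernel):

    from torchrec.distributed.embedding_types import EmbeddingComputeKernel
    from torchrec.distributed.planner.shard_estimators import _is_table_cached

    assert _is_table_cached(EmbeddingComputeKernel.FUSED_UVM_CACHING.value)
    assert not _is_table_cached(EmbeddingComputeKernel.FUSED.value)

Note the body could also collapse to a single set-membership return (return compute_kernel in {...}); the explicit if/return form mirrors the original inline code.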
@@ -1554,27 +1561,24 @@ def _calculate_storage_specific_sizes(
     is_inference: bool = False,
     clf: Optional[float] = None,
 ) -> List[int]:
-    tensor_sizes: List[int] = [
-        (
-            math.ceil(storage * prod(size) / prod(shape))
-            if sharding_type != ShardingType.DATA_PARALLEL.value
-            else storage
-        )
-        for size in shard_sizes
-    ]
-    optimizer_multipler: float = _get_optimizer_multipler(optimizer_class, shape)
+    tensor_sizes: List[int] = _calculate_tensor_sizes(
+        storage,
+        shape,
+        shard_sizes,
+        sharding_type,
+    )
+    optimizer_multipler: float = _get_optimizer_multipler(
+        optimizer_class,
+        shape,
+    )
 
     optimizer_sizes: List[int] = [
         math.ceil(tensor_size * optimizer_multipler) for tensor_size in tensor_sizes
     ]
 
-    # If a table has turned on UVM caching (meaning clf is not None), there'll be
-    # 4x of table hash size and 16x of cache slot size HBM storage cost dedicated to
-    # cache aux state (note that this is not the cache content itself)
-    cache_aux_state_sizes: List[int] = (
-        [0] * len(shard_sizes)
-        if clf is None
-        else [math.ceil(size[0] * (4 + clf * 16)) for size in shard_sizes]
+    cache_aux_state_sizes: List[int] = _calculate_cache_aux_state_sizes(
+        shard_sizes,
+        clf,
     )
 
     return [
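The unchanged optimizer step above scales each shard's tensor bytes by a per-optimizer multiplier. A hedged illustration; the 2.0 here is hypothetical, standing in for an optimizer with two full-size states per parameter, while the real value comes from _get_optimizer_multipler:

    import math

    tensor_sizes = [500_000, 500_000]  # per-shard tensor bytes (hypothetical)
    optimizer_multipler = 2.0          # hypothetical; depends on optimizer class and shape
    optimizer_sizes = [math.ceil(t * optimizer_multipler) for t in tensor_sizes]
    assert optimizer_sizes == [1_000_000, 1_000_000]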
@@ -1589,6 +1593,30 @@ def _calculate_storage_specific_sizes(
     ]
 
 
+def _calculate_tensor_sizes(
+    storage: int, shape: torch.Size, shard_sizes: List[List[int]], sharding_type: str
+) -> List[int]:
+    return [
+        (
+            math.ceil(storage * prod(size) / prod(shape))
+            if sharding_type != ShardingType.DATA_PARALLEL.value
+            else storage
+        )
+        for size in shard_sizes
+    ]
+
+
+# If a table has turned on UVM caching (meaning clf is not None), there'll be
+# 4x of table hash size and 16x of cache slot size HBM storage cost dedicated to
+# cache aux state (note that this is not the cache content itself)
+def _calculate_cache_aux_state_sizes(
+    shard_sizes: List[List[int]], clf: Optional[float]
+) -> List[int]:
+    if clf is None:
+        return [0] * len(shard_sizes)
+    return [math.ceil(size[0] * (4 + clf * 16)) for size in shard_sizes]
+
+
 def _get_optimizer_multipler(
     optimizer_class: Optional[Type[torch.optim.Optimizer]],
     shape: torch.Size,
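A worked example of the two extracted helpers, with hypothetical numbers (a 1000 x 64 table occupying 1,000,000 bytes, split row-wise into two shards, clf = 0.2); the list comprehensions below restate the helpers' bodies inline:

    import math
    from math import prod

    storage = 1_000_000
    shape = (1000, 64)
    shard_sizes = [[500, 64], [500, 64]]

    # Each non-data-parallel shard is charged storage prorated by its share of
    # elements: ceil(1_000_000 * (500 * 64) / (1000 * 64)) = 500_000 bytes.
    tensor_sizes = [math.ceil(storage * prod(s) / prod(shape)) for s in shard_sizes]
    assert tensor_sizes == [500_000, 500_000]

    # Cache aux state: 4 bytes per hashed row plus 16 bytes per cache slot
    # (clf * rows), per the comment above: ceil(500 * (4 + 0.2 * 16)) = 3600.
    clf = 0.2
    aux = [math.ceil(s[0] * (4 + clf * 16)) for s in shard_sizes]
    assert aux == [3600, 3600]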
