
Commit 7fdfa3c

Ahmed Shuaibi authored and facebook-github-bot committed
refactor: create functions for shard/tensor size calculations (#3257)
Summary:
- refactor to create function for checking if table is cached
- refactor to create functions for tensor size calculations

Reviewed By: aporialiao

Differential Revision: D79007077
1 parent 3bdf9f3 commit 7fdfa3c

File tree

1 file changed: +69 -30 lines


torchrec/distributed/planner/shard_estimators.py

Lines changed: 69 additions & 30 deletions
@@ -1192,16 +1192,9 @@ def calculate_shard_storages(
     hbm_storage: int = tensor_storage.get("hbm", 0)
     ddr_storage: int = tensor_storage.get("ddr", 0)
 
-    table_cached: bool = False
-    if compute_kernel in {
-        EmbeddingComputeKernel.FUSED_UVM_CACHING.value,
-        EmbeddingComputeKernel.QUANT_UVM_CACHING.value,
-        EmbeddingComputeKernel.KEY_VALUE.value,
-        EmbeddingComputeKernel.SSD_VIRTUAL_TABLE.value,
-        EmbeddingComputeKernel.DRAM_VIRTUAL_TABLE.value,
-    }:
+    table_cached = _is_table_cached(compute_kernel)
+    if table_cached:
         hbm_storage = round(ddr_storage * caching_ratio)
-        table_cached = True
 
     optimizer_class = getattr(tensor, "_optimizer_classes", [None])[0]
 
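Note: for these cached kernels the planner charges HBM for the cache working set rather than the whole table; the estimate is the DDR footprint scaled by caching_ratio. A minimal sketch of the arithmetic with made-up numbers (real values come from the planner's topology and per-table constraints):

ddr_storage = 4_000_000_000  # hypothetical: 4 GB table resident in DDR
caching_ratio = 0.2          # hypothetical: 20% of rows kept hot in HBM

hbm_storage = round(ddr_storage * caching_ratio)
print(hbm_storage)  # 800000000 -> ~0.8 GB of HBM reserved for the cache
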
@@ -1293,6 +1286,20 @@ def calculate_shard_storages(
     ]
 
 
+def _is_table_cached(
+    compute_kernel: str,
+) -> bool:
+    if compute_kernel in {
+        EmbeddingComputeKernel.FUSED_UVM_CACHING.value,
+        EmbeddingComputeKernel.QUANT_UVM_CACHING.value,
+        EmbeddingComputeKernel.KEY_VALUE.value,
+        EmbeddingComputeKernel.SSD_VIRTUAL_TABLE.value,
+        EmbeddingComputeKernel.DRAM_VIRTUAL_TABLE.value,
+    }:
+        return True
+    return False
+
+
 def _calculate_shard_io_sizes(
     sharding_type: str,
     batch_sizes: List[int],
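Standalone, the new predicate is a set-membership check over the kernels that page against a cache or an out-of-HBM backing store. A self-contained sketch with a stand-in enum (the real EmbeddingComputeKernel lives in torchrec; the member values below are illustrative assumptions, not guaranteed to match):

from enum import Enum

class EmbeddingComputeKernel(Enum):
    # Stand-in subset; string values are assumed for illustration.
    DENSE = "dense"
    FUSED_UVM_CACHING = "fused_uvm_caching"
    QUANT_UVM_CACHING = "quant_uvm_caching"
    KEY_VALUE = "key_value"
    SSD_VIRTUAL_TABLE = "ssd_virtual_table"
    DRAM_VIRTUAL_TABLE = "dram_virtual_table"

def _is_table_cached(compute_kernel: str) -> bool:
    return compute_kernel in {
        EmbeddingComputeKernel.FUSED_UVM_CACHING.value,
        EmbeddingComputeKernel.QUANT_UVM_CACHING.value,
        EmbeddingComputeKernel.KEY_VALUE.value,
        EmbeddingComputeKernel.SSD_VIRTUAL_TABLE.value,
        EmbeddingComputeKernel.DRAM_VIRTUAL_TABLE.value,
    }

print(_is_table_cached("fused_uvm_caching"))  # True
print(_is_table_cached("dense"))              # False

(The committed version spells the branch out with an explicit if/return; collapsing it to a single membership expression is equivalent.)
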
@@ -1554,27 +1561,20 @@ def _calculate_storage_specific_sizes(
     is_inference: bool = False,
     clf: Optional[float] = None,
 ) -> List[int]:
-    tensor_sizes: List[int] = [
-        (
-            math.ceil(storage * prod(size) / prod(shape))
-            if sharding_type != ShardingType.DATA_PARALLEL.value
-            else storage
-        )
-        for size in shard_sizes
-    ]
-    optimizer_multipler: float = _get_optimizer_multipler(optimizer_class, shape)
-
-    optimizer_sizes: List[int] = [
-        math.ceil(tensor_size * optimizer_multipler) for tensor_size in tensor_sizes
-    ]
-
-    # If a table has turned on UVM caching (meaning clf is not None), there'll be
-    # 4x of table hash size and 16x of cache slot size HBM storage cost dedicated to
-    # cache aux state (note that this is not the cache content itself)
-    cache_aux_state_sizes: List[int] = (
-        [0] * len(shard_sizes)
-        if clf is None
-        else [math.ceil(size[0] * (4 + clf * 16)) for size in shard_sizes]
+    tensor_sizes: List[int] = _calculate_tensor_sizes(
+        storage,
+        shape,
+        shard_sizes,
+        sharding_type,
+    )
+    optimizer_sizes = _calculate_optimizer_sizes(
+        tensor_sizes,
+        optimizer_class,
+        shape,
+    )
+    cache_aux_state_sizes: List[int] = _calculate_cache_aux_state_sizes(
+        shard_sizes,
+        clf,
     )
 
     return [
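After the refactor the body reads as three named steps: scale the raw tensor bytes per shard, add optimizer state, and add cache auxiliary state. The per-shard scaling is proportional to the shard's share of the full tensor; a worked sketch with toy numbers (shapes and byte counts are made up):

import math
from math import prod

# Hypothetical 1M x 128 table occupying 512 MB, split row-wise into two
# equal shards; each shard should be charged half the bytes.
storage = 512 * 1024 * 1024                      # full-table bytes (assumed)
shape = (1_000_000, 128)                         # full tensor shape (assumed)
shard_sizes = [[500_000, 128], [500_000, 128]]   # row-wise halves (assumed)

tensor_sizes = [math.ceil(storage * prod(size) / prod(shape)) for size in shard_sizes]
print(tensor_sizes)  # [268435456, 268435456] -> 256 MB per shard

Under ShardingType.DATA_PARALLEL every rank holds the full tensor, so the helper returns storage unscaled for each shard.
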
@@ -1589,6 +1589,45 @@ def _calculate_storage_specific_sizes(
     ]
 
 
+def _calculate_tensor_sizes(
+    storage: int, shape: torch.Size, shard_sizes: List[List[int]], sharding_type: str
+) -> List[int]:
+    return [
+        (
+            math.ceil(storage * prod(size) / prod(shape))
+            if sharding_type != ShardingType.DATA_PARALLEL.value
+            else storage
+        )
+        for size in shard_sizes
+    ]
+
+
+# If a table has turned on UVM caching (meaning clf is not None), there'll be
+# 4x of table hash size and 16x of cache slot size HBM storage cost dedicated to
+# cache aux state (note that this is not the cache content itself)
+def _calculate_cache_aux_state_sizes(
+    shard_sizes: List[List[int]], clf: Optional[float]
+) -> List[int]:
+    if clf is None:
+        return [0] * len(shard_sizes)
+    return [math.ceil(size[0] * (4 + clf * 16)) for size in shard_sizes]
+
+
+def _calculate_optimizer_sizes(
+    tensor_sizes: List[int],
+    optimizer_class: Optional[Type[torch.optim.Optimizer]],
+    sharding_tensor_shape: torch.Size,
+) -> List[int]:
+    optimizer_multiplier: float = _get_optimizer_multipler(
+        optimizer_class,
+        sharding_tensor_shape,
+    )
+    optimizer_sizes: List[int] = [
+        math.ceil(tensor_size * optimizer_multiplier) for tensor_size in tensor_sizes
+    ]
+    return optimizer_sizes
+
+
 def _get_optimizer_multipler(
     optimizer_class: Optional[Type[torch.optim.Optimizer]],
     shape: torch.Size,
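The cache auxiliary-state helper encodes the rule in the comment literally: per shard, 4 bytes of hash state per row plus 16 bytes per cache slot, where the slot count is clf (cache load factor) times the shard's row count. A worked sketch with assumed values:

import math

# Hypothetical row-wise shards of 1M rows each, cache load factor 0.2
# (i.e. cache slots for 20% of the rows).
shard_sizes = [[1_000_000, 128], [1_000_000, 128]]
clf = 0.2

# size[0] is the shard's row count: 4 bytes/row + 16 bytes * clf per row.
aux = [math.ceil(size[0] * (4 + clf * 16)) for size in shard_sizes]
print(aux)  # [7200000, 7200000] -> ~7.2 MB of HBM aux state per shard
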

0 commit comments
