@@ -1192,16 +1192,9 @@ def calculate_shard_storages(
     hbm_storage: int = tensor_storage.get("hbm", 0)
     ddr_storage: int = tensor_storage.get("ddr", 0)
 
-    table_cached: bool = False
-    if compute_kernel in {
-        EmbeddingComputeKernel.FUSED_UVM_CACHING.value,
-        EmbeddingComputeKernel.QUANT_UVM_CACHING.value,
-        EmbeddingComputeKernel.KEY_VALUE.value,
-        EmbeddingComputeKernel.SSD_VIRTUAL_TABLE.value,
-        EmbeddingComputeKernel.DRAM_VIRTUAL_TABLE.value,
-    }:
+    table_cached = _is_table_cached(compute_kernel)
+    if table_cached:
         hbm_storage = round(ddr_storage * caching_ratio)
-        table_cached = True
 
 
     optimizer_class = getattr(tensor, "_optimizer_classes", [None])[0]
@@ -1293,6 +1286,20 @@ def calculate_shard_storages(
     ]
 
 
+def _is_table_cached(
+    compute_kernel: str,
+) -> bool:
+    if compute_kernel in {
+        EmbeddingComputeKernel.FUSED_UVM_CACHING.value,
+        EmbeddingComputeKernel.QUANT_UVM_CACHING.value,
+        EmbeddingComputeKernel.KEY_VALUE.value,
+        EmbeddingComputeKernel.SSD_VIRTUAL_TABLE.value,
+        EmbeddingComputeKernel.DRAM_VIRTUAL_TABLE.value,
+    }:
+        return True
+    return False
+
+
 def _calculate_shard_io_sizes(
     sharding_type: str,
     batch_sizes: List[int],
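As a rough illustration of what the extracted helper controls, here is a minimal, standalone sketch of the cached-table HBM estimate. The kernel names and numeric values below are toy stand-ins; only the set-membership check and the caching-ratio arithmetic mirror the diff.

# Toy stand-ins for the UVM-caching kernels listed in _is_table_cached.
CACHED_KERNELS = {"fused_uvm_caching", "quant_uvm_caching", "key_value"}


def is_table_cached(compute_kernel: str) -> bool:
    # A table counts as "cached" when its compute kernel keeps rows in DDR/UVM
    # with only a cache slice resident in HBM.
    return compute_kernel in CACHED_KERNELS


ddr_storage = 1_000_000_000  # bytes of the full table kept in DDR
caching_ratio = 0.2          # expected fraction of rows resident in the HBM cache

hbm_storage = 0
if is_table_cached("fused_uvm_caching"):
    # Cached tables charge only the cache slice against HBM.
    hbm_storage = round(ddr_storage * caching_ratio)

print(hbm_storage)  # 200000000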
@@ -1554,27 +1561,20 @@ def _calculate_storage_specific_sizes(
     is_inference: bool = False,
     clf: Optional[float] = None,
 ) -> List[int]:
-    tensor_sizes: List[int] = [
-        (
-            math.ceil(storage * prod(size) / prod(shape))
-            if sharding_type != ShardingType.DATA_PARALLEL.value
-            else storage
-        )
-        for size in shard_sizes
-    ]
-    optimizer_multipler: float = _get_optimizer_multipler(optimizer_class, shape)
-
-    optimizer_sizes: List[int] = [
-        math.ceil(tensor_size * optimizer_multipler) for tensor_size in tensor_sizes
-    ]
-
-    # If a table has turned on UVM caching (meaning clf is not None), there'll be
-    # 4x of table hash size and 16x of cache slot size HBM storage cost dedicated to
-    # cache aux state (note that this is not the cache content itself)
-    cache_aux_state_sizes: List[int] = (
-        [0] * len(shard_sizes)
-        if clf is None
-        else [math.ceil(size[0] * (4 + clf * 16)) for size in shard_sizes]
+    tensor_sizes: List[int] = _calculate_tensor_sizes(
+        storage,
+        shape,
+        shard_sizes,
+        sharding_type,
+    )
+    optimizer_sizes = _calculate_optimizer_sizes(
+        tensor_sizes,
+        optimizer_class,
+        shape,
+    )
+    cache_aux_state_sizes: List[int] = _calculate_cache_aux_state_sizes(
+        shard_sizes,
+        clf,
     )
 
     return [
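For context on the helpers introduced in the next hunk, here is a worked example of the proportional per-shard tensor-size formula. The numbers are made up; under DATA_PARALLEL the full storage would be charged to every shard instead.

import math

full_storage = 4_000_000                 # bytes of the unsharded table
full_shape = (1000, 128)                 # rows x embedding dim
shard_sizes = [[250, 128], [750, 128]]   # an uneven row-wise split

# Each shard is charged storage in proportion to its share of the elements.
tensor_sizes = [
    math.ceil(full_storage * math.prod(size) / math.prod(full_shape))
    for size in shard_sizes
]
print(tensor_sizes)  # [1000000, 3000000]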
@@ -1589,6 +1589,45 @@ def _calculate_storage_specific_sizes(
     ]
 
 
+def _calculate_tensor_sizes(
+    storage: int, shape: torch.Size, shard_sizes: List[List[int]], sharding_type: str
+) -> List[int]:
+    return [
+        (
+            math.ceil(storage * prod(size) / prod(shape))
+            if sharding_type != ShardingType.DATA_PARALLEL.value
+            else storage
+        )
+        for size in shard_sizes
+    ]
+
+
+# If a table has turned on UVM caching (meaning clf is not None), there'll be
+# 4x of table hash size and 16x of cache slot size HBM storage cost dedicated to
+# cache aux state (note that this is not the cache content itself)
+def _calculate_cache_aux_state_sizes(
+    shard_sizes: List[List[int]], clf: Optional[float]
+) -> List[int]:
+    if clf is None:
+        return [0] * len(shard_sizes)
+    return [math.ceil(size[0] * (4 + clf * 16)) for size in shard_sizes]
+
+
+def _calculate_optimizer_sizes(
+    tensor_sizes: List[int],
+    optimizer_class: Optional[Type[torch.optim.Optimizer]],
+    sharding_tensor_shape: torch.Size,
+) -> List[int]:
+    optimizer_multiplier: float = _get_optimizer_multipler(
+        optimizer_class,
+        sharding_tensor_shape,
+    )
+    optimizer_sizes: List[int] = [
+        math.ceil(tensor_size * optimizer_multiplier) for tensor_size in tensor_sizes
+    ]
+    return optimizer_sizes
+
+
 def _get_optimizer_multipler(
     optimizer_class: Optional[Type[torch.optim.Optimizer]],
     shape: torch.Size,
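Finally, a standalone sketch of how the three helper results could combine per shard. The 4 + clf * 16 bytes-per-row aux-state cost mirrors the comment in the diff; the optimizer multiplier of 1.0 and the final summation are assumptions for illustration, since the full return expression of _calculate_storage_specific_sizes is not shown in these hunks.

import math

clf = 0.2                                # cache load factor; None would disable caching
shard_sizes = [[250, 128], [750, 128]]
tensor_sizes = [1_000_000, 3_000_000]    # from the proportional formula above
optimizer_multiplier = 1.0               # assumed: one auxiliary state per element

# 4 bytes per hashed row plus 16 bytes per cache slot (clf * rows) of aux state,
# separate from the cached rows themselves.
cache_aux_state_sizes = [math.ceil(size[0] * (4 + clf * 16)) for size in shard_sizes]
optimizer_sizes = [math.ceil(t * optimizer_multiplier) for t in tensor_sizes]

# Illustrative per-shard sum; the real combination is defined outside this diff.
totals = [
    t + o + c
    for t, o, c in zip(tensor_sizes, optimizer_sizes, cache_aux_state_sizes)
]
print(cache_aux_state_sizes)  # [1800, 5400]
print(totals)                 # [2001800, 6005400]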