@@ -1192,16 +1192,9 @@ def calculate_shard_storages(
     hbm_storage: int = tensor_storage.get("hbm", 0)
     ddr_storage: int = tensor_storage.get("ddr", 0)
 
-    table_cached: bool = False
-    if compute_kernel in {
-        EmbeddingComputeKernel.FUSED_UVM_CACHING.value,
-        EmbeddingComputeKernel.QUANT_UVM_CACHING.value,
-        EmbeddingComputeKernel.KEY_VALUE.value,
-        EmbeddingComputeKernel.SSD_VIRTUAL_TABLE.value,
-        EmbeddingComputeKernel.DRAM_VIRTUAL_TABLE.value,
-    }:
+    table_cached = _is_table_cached(compute_kernel)
+    if table_cached:
         hbm_storage = round(ddr_storage * caching_ratio)
-        table_cached = True
 
     optimizer_class = getattr(tensor, "_optimizer_classes", [None])[0]
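For context, the branch above keeps the pre-existing estimate: when the kernel caches the table in UVM (or a virtual table backend), only the cached fraction of the DDR footprint is charged against HBM. A minimal sketch of that arithmetic, with invented numbers:

```python
# Invented numbers, for illustration only.
ddr_storage = 4_000_000_000  # full table footprint in DDR/UVM, in bytes (assumed)
caching_ratio = 0.2          # expected fraction of rows resident in the HBM cache

# Cached/virtual-table kernels only charge the cached fraction against HBM.
hbm_storage = round(ddr_storage * caching_ratio)
assert hbm_storage == 800_000_000
```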
@@ -1293,6 +1286,20 @@ def calculate_shard_storages(
     ]
 
 
+def _is_table_cached(
+    compute_kernel: str,
+) -> bool:
+    if compute_kernel in {
+        EmbeddingComputeKernel.FUSED_UVM_CACHING.value,
+        EmbeddingComputeKernel.QUANT_UVM_CACHING.value,
+        EmbeddingComputeKernel.KEY_VALUE.value,
+        EmbeddingComputeKernel.SSD_VIRTUAL_TABLE.value,
+        EmbeddingComputeKernel.DRAM_VIRTUAL_TABLE.value,
+    }:
+        return True
+    return False
+
+
 def _calculate_shard_io_sizes(
     sharding_type: str,
     batch_sizes: List[int],
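The extracted predicate can now be exercised in isolation. Below is a rough standalone mock, assuming the kernel arrives as the string produced by `EmbeddingComputeKernel.<KERNEL>.value`; the literal strings are stand-ins for those enum values, not taken from the patch:

```python
# Standalone mock of _is_table_cached, for illustration only; the real helper
# compares against EmbeddingComputeKernel.*.value, not hard-coded strings.
CACHED_OR_VIRTUAL_KERNELS = {
    "fused_uvm_caching",   # assumed value of FUSED_UVM_CACHING
    "quant_uvm_caching",   # assumed value of QUANT_UVM_CACHING
    "key_value",           # assumed value of KEY_VALUE
    "ssd_virtual_table",   # assumed value of SSD_VIRTUAL_TABLE
    "dram_virtual_table",  # assumed value of DRAM_VIRTUAL_TABLE
}


def is_table_cached(compute_kernel: str) -> bool:
    # Same membership test as the patch: these kernels keep the full table
    # off-HBM and only pay for the cached fraction in HBM.
    return compute_kernel in CACHED_OR_VIRTUAL_KERNELS


print(is_table_cached("fused_uvm_caching"))  # True
print(is_table_cached("fused"))              # False
```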
@@ -1554,27 +1561,24 @@ def _calculate_storage_specific_sizes(
     is_inference: bool = False,
     clf: Optional[float] = None,
 ) -> List[int]:
-    tensor_sizes: List[int] = [
-        (
-            math.ceil(storage * prod(size) / prod(shape))
-            if sharding_type != ShardingType.DATA_PARALLEL.value
-            else storage
-        )
-        for size in shard_sizes
-    ]
-    optimizer_multipler: float = _get_optimizer_multipler(optimizer_class, shape)
+    tensor_sizes: List[int] = _calculate_tensor_sizes(
+        storage,
+        shape,
+        shard_sizes,
+        sharding_type,
+    )
+    optimizer_multipler: float = _get_optimizer_multipler(
+        optimizer_class,
+        shape,
+    )
 
     optimizer_sizes: List[int] = [
         math.ceil(tensor_size * optimizer_multipler) for tensor_size in tensor_sizes
     ]
 
-    # If a table has turned on UVM caching (meaning clf is not None), there'll be
-    # 4x of table hash size and 16x of cache slot size HBM storage cost dedicated to
-    # cache aux state (note that this is not the cache content itself)
-    cache_aux_state_sizes: List[int] = (
-        [0] * len(shard_sizes)
-        if clf is None
-        else [math.ceil(size[0] * (4 + clf * 16)) for size in shard_sizes]
+    cache_aux_state_sizes: List[int] = _calculate_cache_aux_state_sizes(
+        shard_sizes,
+        clf,
     )
 
     return [
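To make the delegation concrete, here is a small worked example of the same math with the helper inlined (numbers invented, not from the patch): a non-data-parallel shard is charged `storage * prod(shard_size) / prod(full_shape)` bytes, and optimizer state is an optimizer-specific multiple of that.

```python
import math
from math import prod

# Invented numbers, for illustration only.
full_shape = (1_000_000, 128)                    # full table: rows x embedding dim
storage = 512_000_000                            # bytes for the unsharded tensor
shard_sizes = [[250_000, 128], [250_000, 128]]   # e.g. two row-wise shards

# Each shard is charged its share of the elements; DATA_PARALLEL shards would
# instead be charged the full `storage`.
tensor_sizes = [
    math.ceil(storage * prod(size) / prod(full_shape)) for size in shard_sizes
]
assert tensor_sizes == [128_000_000, 128_000_000]

# Optimizer state is a per-optimizer multiple of the tensor bytes; 1.0 here is
# a stand-in, the real value comes from _get_optimizer_multipler.
optimizer_multipler = 1.0
optimizer_sizes = [math.ceil(t * optimizer_multipler) for t in tensor_sizes]
assert optimizer_sizes == tensor_sizes
```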
@@ -1589,6 +1593,30 @@ def _calculate_storage_specific_sizes(
     ]
 
 
+def _calculate_tensor_sizes(
+    storage: int, shape: torch.Size, shard_sizes: List[List[int]], sharding_type: str
+) -> List[int]:
+    return [
+        (
+            math.ceil(storage * prod(size) / prod(shape))
+            if sharding_type != ShardingType.DATA_PARALLEL.value
+            else storage
+        )
+        for size in shard_sizes
+    ]
+
+
+# If a table has turned on UVM caching (meaning clf is not None), there'll be
+# 4x of table hash size and 16x of cache slot size HBM storage cost dedicated to
+# cache aux state (note that this is not the cache content itself)
+def _calculate_cache_aux_state_sizes(
+    shard_sizes: List[List[int]], clf: Optional[float]
+) -> List[int]:
+    if clf is None:
+        return [0] * len(shard_sizes)
+    return [math.ceil(size[0] * (4 + clf * 16)) for size in shard_sizes]
+
+
 def _get_optimizer_multipler(
     optimizer_class: Optional[Type[torch.optim.Optimizer]],
     shape: torch.Size,
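The relocated comment explains the `4 + clf * 16` factor: roughly 4 bytes of aux state per hash-table row plus 16 bytes per cache slot, where the slot count is `clf` (cache load factor) times the row count, and none of it is the cached rows themselves. A worked example with invented numbers:

```python
import math

# Invented numbers, for illustration only.
shard_sizes = [[1_000_000, 128]]  # one shard with 1M hash-table rows
clf = 0.2                         # cache load factor: slots as a fraction of rows

# ~4 bytes per row of hash metadata plus ~16 bytes per cache slot; this is the
# cache bookkeeping cost in HBM, not the cached embedding rows themselves.
cache_aux_state_sizes = [math.ceil(size[0] * (4 + clf * 16)) for size in shard_sizes]
assert cache_aux_state_sizes == [7_200_000]  # 1M * (4 + 0.2 * 16) bytes
```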