@@ -1113,35 +1113,12 @@ def get_kv_cache_config_from_groups(
            KVCacheTensor(size=page_size * num_blocks, shared_by=shared_by)
        )

-    kv_cache_config = KVCacheConfig(
+    return KVCacheConfig(
        num_blocks=num_blocks,
        kv_cache_tensors=kv_cache_tensors,
        kv_cache_groups=kv_cache_groups,
    )

-    min_block_size = min([group.kv_cache_spec.block_size for group in kv_cache_groups])
-
-    # Print the KV cache size and maximum concurrency.
-    num_tokens = num_blocks // len(kv_cache_groups) * min_block_size
-    if vllm_config.parallel_config.decode_context_parallel_size > 1:
-        num_tokens *= vllm_config.parallel_config.decode_context_parallel_size
-        logger.info(
-            "Multiplying the GPU KV cache size by the dcp_world_size %d.",
-            vllm_config.parallel_config.decode_context_parallel_size,
-        )
-    num_tokens_str = f"{num_tokens:,}"
-    logger.info("GPU KV cache size: %s tokens", num_tokens_str)
-    max_model_len_str = f"{vllm_config.model_config.max_model_len:,}"
-    max_concurrency = get_max_concurrency_for_kv_cache_config(
-        vllm_config, kv_cache_config
-    )
-    logger.info(
-        "Maximum concurrency for %s tokens per request: %.2fx",
-        max_model_len_str,
-        max_concurrency,
-    )
-    return kv_cache_config
-


def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]):
    """
@@ -1265,6 +1242,45 @@ def generate_scheduler_kv_cache_config(
    return cfg


+def _report_kv_cache_config(
+    vllm_config: VllmConfig, kv_cache_config: KVCacheConfig
+) -> None:
+    """
+    Log resolved KV cache configuration.
+
+    Args:
+        vllm_config: The global VllmConfig
+        kv_cache_config: The resolved KV cache configuration
+    """
+    min_block_size = min(
+        [group.kv_cache_spec.block_size for group in kv_cache_config.kv_cache_groups]
+    )
+
+    # Log the KV cache size and maximum concurrency.
+    num_tokens = (
+        kv_cache_config.num_blocks
+        // len(kv_cache_config.kv_cache_groups)
+        * min_block_size
+    )
+    if vllm_config.parallel_config.decode_context_parallel_size > 1:
+        num_tokens *= vllm_config.parallel_config.decode_context_parallel_size
+        logger.info(
+            "Multiplying the GPU KV cache size by the dcp_world_size %d.",
+            vllm_config.parallel_config.decode_context_parallel_size,
+        )
+    num_tokens_str = f"{num_tokens:,}"
+    logger.info("GPU KV cache size: %s tokens", num_tokens_str)
+    max_model_len_str = f"{vllm_config.model_config.max_model_len:,}"
+    max_concurrency = get_max_concurrency_for_kv_cache_config(
+        vllm_config, kv_cache_config
+    )
+    logger.info(
+        "Maximum concurrency for %s tokens per request: %.2fx",
+        max_model_len_str,
+        max_concurrency,
+    )
+
+
def get_kv_cache_configs(
    vllm_config: VllmConfig,
    kv_cache_specs: list[dict[str, KVCacheSpec]],
@@ -1284,7 +1300,8 @@ def get_kv_cache_configs(
    3. Generate the KV cache configs for each worker based on the KV cache
       grouping strategy. (This is reasonable because the layer ratio of
       different PP stages are similar.)
-    4. Change the num_blocks of each worker to the smallest among all workers.
+    4. Change the num_blocks of each worker to the smallest among all workers
+       and shrink tensor sizes proportionally to avoid allocating unused memory.

    Args:
        vllm_config: The global VllmConfig
@@ -1345,13 +1362,22 @@ def get_kv_cache_configs(
            )
        )

-    # Change the num_blocks of each rank to the smallest among all ranks. We
-    # do not need to shrink the tensor size because it is valid to only use the
-    # first `num_blocks` blocks of the tensor.
+    # Change the num_blocks of each rank to the smallest among all ranks.
+    # We also need to shrink the tensor size proportionally to avoid
+    # allocating unused memory.
    min_num_blocks = min(
        kv_cache_config.num_blocks for kv_cache_config in kv_cache_configs
    )
    for kv_cache_config in kv_cache_configs:
+        num_blocks_old = kv_cache_config.num_blocks
        kv_cache_config.num_blocks = min_num_blocks

+        # Shrink tensor size proportionally
+        for tensor in kv_cache_config.kv_cache_tensors:
+            assert tensor.size % num_blocks_old == 0
+            tensor.size = tensor.size // num_blocks_old * min_num_blocks
+
+        if len(kv_cache_config.kv_cache_groups) > 0:
+            _report_kv_cache_config(vllm_config, kv_cache_config)
+
    return kv_cache_configs
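
For reference, a minimal standalone sketch of the arithmetic these hunks introduce (not part of the diff; all values below are hypothetical): step 4 rescales each rank's tensor size by min_num_blocks / num_blocks_old, and _report_kv_cache_config then derives the reported token capacity from the unified block count.

# Hypothetical values, for illustration only.
page_size = 16 * 1024      # bytes per block in one KV cache tensor (assumed)
block_size = 16            # tokens per block (assumed)
num_groups = 2             # number of KV cache groups (assumed)
num_blocks_old = 12_000    # blocks this rank could fit before unification
min_num_blocks = 10_000    # smallest num_blocks across all ranks

# Step 4: shrink the tensor proportionally instead of keeping unused memory.
tensor_size = page_size * num_blocks_old
assert tensor_size % num_blocks_old == 0
tensor_size = tensor_size // num_blocks_old * min_num_blocks
assert tensor_size == page_size * min_num_blocks  # nothing reserved past min_num_blocks

# Token capacity as logged by the reporting helper.
num_tokens = min_num_blocks // num_groups * block_size  # 80,000 tokens
print(f"GPU KV cache size: {num_tokens:,} tokens")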