
Commit d100d78

Optimize KV cache distribution for asymmetric pipeline parallelism (#25164)
Signed-off-by: gholmes829 <[email protected]>
1 parent 7e4cd07 commit d100d78

5 files changed: +64 additions, -38 deletions

tests/v1/core/test_kv_cache_utils.py

Lines changed: 5 additions & 5 deletions
@@ -681,10 +681,10 @@ def test_get_kv_cache_configs_multiple_workers():
             num_blocks=10,
             kv_cache_tensors=[
                 KVCacheTensor(
-                    size=ref_kv_cache_spec.page_size_bytes * 20, shared_by=["layer1"]
+                    size=ref_kv_cache_spec.page_size_bytes * 10, shared_by=["layer1"]
                 ),
                 KVCacheTensor(
-                    size=ref_kv_cache_spec.page_size_bytes * 20, shared_by=["layer2"]
+                    size=ref_kv_cache_spec.page_size_bytes * 10, shared_by=["layer2"]
                 ),
             ],
             kv_cache_groups=[
@@ -718,7 +718,7 @@ def test_get_kv_cache_configs_multiple_workers():
             num_blocks=10,
             kv_cache_tensors=[
                 KVCacheTensor(
-                    size=ref_kv_cache_spec.page_size_bytes * 20, shared_by=["layer1"]
+                    size=ref_kv_cache_spec.page_size_bytes * 10, shared_by=["layer1"]
                 ),
             ],
             kv_cache_groups=[
@@ -802,7 +802,7 @@ def test_get_kv_cache_configs_multiple_workers():
             num_blocks=10,
             kv_cache_tensors=[
                 KVCacheTensor(
-                    size=ref_kv_cache_spec.page_size_bytes * 20, shared_by=["layer3"]
+                    size=ref_kv_cache_spec.page_size_bytes * 10, shared_by=["layer3"]
                 ),
             ],
             kv_cache_groups=[
@@ -813,7 +813,7 @@ def test_get_kv_cache_configs_multiple_workers():
             num_blocks=10,
             kv_cache_tensors=[
                 KVCacheTensor(
-                    size=ref_kv_cache_spec.page_size_bytes * 20, shared_by=["layer3"]
+                    size=ref_kv_cache_spec.page_size_bytes * 10, shared_by=["layer3"]
                 ),
             ],
             kv_cache_groups=[

vllm/config/cache.py

Lines changed: 1 addition & 1 deletion
@@ -124,7 +124,7 @@ class CacheConfig:
     gpu_memory_utilization. However, users may want to manually specify
     the kv cache memory size. kv_cache_memory_bytes allows more fine-grain
     control of how much memory gets used when compared with using
-    gpu_memory_memory_utilization. Note that kv_cache_memory_bytes
+    gpu_memory_utilization. Note that kv_cache_memory_bytes
     (when not-None) ignores gpu_memory_utilization"""

     def compute_hash(self) -> str:

vllm/entrypoints/llm.py

Lines changed: 1 addition & 1 deletion
@@ -143,7 +143,7 @@ class LLM:
             size based on gpu_memory_utilization. However, users may want to
             manually specify the kv cache memory size. kv_cache_memory_bytes
             allows more fine-grain control of how much memory gets used when
-            compared with using gpu_memory_memory_utilization. Note that
+            compared with using gpu_memory_utilization. Note that
             kv_cache_memory_bytes (when not-None) ignores
             gpu_memory_utilization
         swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
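
The docstring above describes kv_cache_memory_bytes as a manual alternative to gpu_memory_utilization for sizing the KV cache. A minimal usage sketch, assuming the keyword is passed directly to the LLM constructor as this docstring indicates; the model name and byte value here are illustrative only, not from the commit:

from vllm import LLM

# Reserve a fixed 8 GiB for the KV cache instead of deriving the size from
# gpu_memory_utilization. Per the docstring, when kv_cache_memory_bytes is
# not None, gpu_memory_utilization is ignored for KV cache sizing.
llm = LLM(
    model="facebook/opt-125m",          # illustrative model choice
    kv_cache_memory_bytes=8 * 1024**3,  # 8 GiB, illustrative value
)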

vllm/v1/core/kv_cache_utils.py

Lines changed: 54 additions & 28 deletions
@@ -1113,35 +1113,12 @@ def get_kv_cache_config_from_groups(
             KVCacheTensor(size=page_size * num_blocks, shared_by=shared_by)
         )

-    kv_cache_config = KVCacheConfig(
+    return KVCacheConfig(
         num_blocks=num_blocks,
         kv_cache_tensors=kv_cache_tensors,
         kv_cache_groups=kv_cache_groups,
     )

-    min_block_size = min([group.kv_cache_spec.block_size for group in kv_cache_groups])
-
-    # Print the KV cache size and maximum concurrency.
-    num_tokens = num_blocks // len(kv_cache_groups) * min_block_size
-    if vllm_config.parallel_config.decode_context_parallel_size > 1:
-        num_tokens *= vllm_config.parallel_config.decode_context_parallel_size
-        logger.info(
-            "Multiplying the GPU KV cache size by the dcp_world_size %d.",
-            vllm_config.parallel_config.decode_context_parallel_size,
-        )
-    num_tokens_str = f"{num_tokens:,}"
-    logger.info("GPU KV cache size: %s tokens", num_tokens_str)
-    max_model_len_str = f"{vllm_config.model_config.max_model_len:,}"
-    max_concurrency = get_max_concurrency_for_kv_cache_config(
-        vllm_config, kv_cache_config
-    )
-    logger.info(
-        "Maximum concurrency for %s tokens per request: %.2fx",
-        max_model_len_str,
-        max_concurrency,
-    )
-    return kv_cache_config
-

 def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]):
     """
@@ -1265,6 +1242,45 @@ def generate_scheduler_kv_cache_config(
     return cfg


+def _report_kv_cache_config(
+    vllm_config: VllmConfig, kv_cache_config: KVCacheConfig
+) -> None:
+    """
+    Log resolved KV cache configuration.
+
+    Args:
+        vllm_config: The global VllmConfig
+        kv_cache_config: The resolved KV cache configuration
+    """
+    min_block_size = min(
+        [group.kv_cache_spec.block_size for group in kv_cache_config.kv_cache_groups]
+    )
+
+    # Log the KV cache size and maximum concurrency.
+    num_tokens = (
+        kv_cache_config.num_blocks
+        // len(kv_cache_config.kv_cache_groups)
+        * min_block_size
+    )
+    if vllm_config.parallel_config.decode_context_parallel_size > 1:
+        num_tokens *= vllm_config.parallel_config.decode_context_parallel_size
+        logger.info(
+            "Multiplying the GPU KV cache size by the dcp_world_size %d.",
+            vllm_config.parallel_config.decode_context_parallel_size,
+        )
+    num_tokens_str = f"{num_tokens:,}"
+    logger.info("GPU KV cache size: %s tokens", num_tokens_str)
+    max_model_len_str = f"{vllm_config.model_config.max_model_len:,}"
+    max_concurrency = get_max_concurrency_for_kv_cache_config(
+        vllm_config, kv_cache_config
+    )
+    logger.info(
+        "Maximum concurrency for %s tokens per request: %.2fx",
+        max_model_len_str,
+        max_concurrency,
+    )
+
+
 def get_kv_cache_configs(
     vllm_config: VllmConfig,
     kv_cache_specs: list[dict[str, KVCacheSpec]],
@@ -1284,7 +1300,8 @@ def get_kv_cache_configs(
     3. Generate the KV cache configs for each worker based on the KV cache
        grouping strategy. (This is reasonable because the layer ratio of
        different PP stages are similar.)
-    4. Change the num_blocks of each worker to the smallest among all workers.
+    4. Change the num_blocks of each worker to the smallest among all workers
+       and shrink tensor sizes proportionally to avoid allocating unused memory.

     Args:
         vllm_config: The global VllmConfig
@@ -1345,13 +1362,22 @@ def get_kv_cache_configs(
         )
     )

-    # Change the num_blocks of each rank to the smallest among all ranks. We
-    # do not need to shrink the tensor size because it is valid to only use the
-    # first `num_blocks` blocks of the tensor.
+    # Change the num_blocks of each rank to the smallest among all ranks.
+    # We also need to shrink the tensor size proportionally to avoid
+    # allocating unused memory.
     min_num_blocks = min(
         kv_cache_config.num_blocks for kv_cache_config in kv_cache_configs
     )
     for kv_cache_config in kv_cache_configs:
+        num_blocks_old = kv_cache_config.num_blocks
         kv_cache_config.num_blocks = min_num_blocks

+        # Shrink tensor size proportionally
+        for tensor in kv_cache_config.kv_cache_tensors:
+            assert tensor.size % num_blocks_old == 0
+            tensor.size = tensor.size // num_blocks_old * min_num_blocks
+
+        if len(kv_cache_config.kv_cache_groups) > 0:
+            _report_kv_cache_config(vllm_config, kv_cache_config)
+
     return kv_cache_configs
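
The hunk above is the heart of the commit: with asymmetric pipeline parallelism, ranks can fit different numbers of blocks, so each rank's num_blocks is clamped to the minimum across ranks, and every KVCacheTensor is now shrunk proportionally instead of keeping its original, larger allocation. A standalone sketch of that arithmetic, using simplified stand-in classes rather than vLLM's real KVCacheTensor/KVCacheConfig types:

from dataclasses import dataclass, field

# Simplified stand-ins for vLLM's KVCacheTensor / KVCacheConfig, only to
# illustrate the shrink-to-min step added by this commit.
@dataclass
class Tensor:
    size: int  # bytes backing this tensor

@dataclass
class Config:
    num_blocks: int
    tensors: list[Tensor] = field(default_factory=list)

def shrink_to_min(configs: list[Config]) -> None:
    # All ranks must agree on num_blocks, so take the smallest across ranks.
    min_num_blocks = min(cfg.num_blocks for cfg in configs)
    for cfg in configs:
        num_blocks_old = cfg.num_blocks
        cfg.num_blocks = min_num_blocks
        # Shrink each tensor proportionally so no unused memory is allocated.
        for t in cfg.tensors:
            assert t.size % num_blocks_old == 0
            t.size = t.size // num_blocks_old * min_num_blocks

# Asymmetric PP example: rank 0 could fit 20 blocks, rank 1 only 10.
page = 4096
ranks = [
    Config(num_blocks=20, tensors=[Tensor(size=20 * page)]),
    Config(num_blocks=10, tensors=[Tensor(size=10 * page)]),
]
shrink_to_min(ranks)
assert ranks[0].tensors[0].size == 10 * page  # previously stayed at 20 * page

This mirrors the updated test expectations above, where the per-layer tensor sizes drop from page_size_bytes * 20 to page_size_bytes * 10 once num_blocks is reduced to 10.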

vllm/v1/worker/gpu_worker.py

Lines changed: 3 additions & 3 deletions
@@ -253,10 +253,10 @@ def determine_available_memory(self) -> int:
             self.model_runner.profile_run()

             msg = (
-                f"Initial free memory {GiB(self.init_snapshot.free_memory)} "
-                f"GiB, reserved {GiB(kv_cache_memory_bytes):.2f}GiB memory for "
+                f"Initial free memory {GiB(self.init_snapshot.free_memory):.2f} "
+                f"GiB, reserved {GiB(kv_cache_memory_bytes):.2f} GiB memory for "
                 "KV Cache as specified by kv_cache_memory_bytes config and "
-                "skipped memory profiling. This does does not respect the "
+                "skipped memory profiling. This does not respect the "
                 "gpu_memory_utilization config. Only use kv_cache_memory_bytes "
                 "config when you want manual control of KV cache memory "
                 "size. If OOM'ed, check the difference of initial free "
