
Commit 9a9f48d

sdavidbd and David Ben-David authored
[V1] [P/D] Add Support for KV Load Failure Recovery (#19330)
Signed-off-by: David Ben-David <[email protected]>
Co-authored-by: David Ben-David <[email protected]>
1 parent 67f3fb0 commit 9a9f48d

24 files changed (+1035, -82 lines)

`README.md` (30 additions, 0 deletions)

# KV Load Failure Recovery Test

This example builds upon the `disaggregated-prefill-v1` example in `examples/offline_inference`.

It demonstrates vLLM's ability to recover from KV load failures in both synchronous and asynchronous loading modes. The goal is to verify that vLLM correctly identifies invalid KV blocks, reschedules the affected requests, and still produces correct, consistent output.

## Files

- `prefill_example.py` – performs the prefill stage and saves KV data (same as in `disaggregated-prefill-v1`).
- `decode_example.py` – performs the decode stage. Accepts:
  - `--simulate-failure`: simulates KV load failure using a custom connector.
  - `--async-load`: enables asynchronous KV loading mode.
- `rogue_shared_storage_connector.py` – defines `RogueSharedStorageConnector`, a subclass of `SharedStorageConnector` that simulates missing or corrupted external KV blocks by failing to load the blocks of the first decode request (see the simplified sketch after this list).
- `run.sh` – orchestrates the test: runs the prefill stage, then three decode stages:
  1. Normal decode (baseline).
  2. Decode with simulated sync KV load failure.
  3. Decode with simulated async KV load failure.

  Finally, it compares the baseline output with the recovered outputs to verify correctness.

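The failure simulation hinges on a single hook: the connector reports the IDs of blocks whose KV data failed to load, and vLLM reschedules the requests that depend on them. The snippet below is a simplified, hypothetical sketch of that hook (the class name is made up); the complete logic, including how the failed blocks are selected, is in `rogue_shared_storage_connector.py` further down.

```python
# Simplified, hypothetical sketch; the real implementation is
# RogueSharedStorageConnector in rogue_shared_storage_connector.py.
from vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector import (
    SharedStorageConnector,
)


class FailingConnectorSketch(SharedStorageConnector):
    def __init__(self, vllm_config, role):
        super().__init__(vllm_config=vllm_config, role=role)
        self._invalid_block_ids: set[int] = set()

    def get_block_ids_with_load_errors(self) -> set[int]:
        # Block ids returned here are treated as failed KV loads; vLLM
        # reschedules the affected requests so they are recomputed.
        return self._invalid_block_ids
```
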
## How It Works

- The test dynamically loads `RogueSharedStorageConnector` via `KVTransferConfig.kv_connector_module_path`, enabling controlled simulation of load failures without modifying the original connector (a minimal configuration sketch follows this list).
- The decode stages that simulate failure are expected to trigger vLLM's recovery logic and produce the same output as the baseline decode.
- If recovery fails, the script prints a unified diff of the output mismatch and exits with an error.

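As a concrete illustration of the first point above, this is roughly how the failure path of `decode_example.py` (shown in full below) selects the test connector; treat it as a sketch rather than a standalone script:

```python
from vllm.config import KVTransferConfig

# Sketch: select the test connector by class name and module path, so the
# stock SharedStorageConnector is left untouched.
ktc = KVTransferConfig(
    kv_connector="RogueSharedStorageConnector",
    kv_role="kv_both",
    kv_connector_extra_config={
        "shared_storage_path": "local_storage",
        "async_load": False,  # True exercises the async loading path
    },
    kv_connector_module_path="rogue_shared_storage_connector",
)
```
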
## Usage

```bash
./run.sh
```

`decode_example.py` (85 additions, 0 deletions)

```python
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import argparse

from vllm import LLM, SamplingParams
from vllm.config import KVTransferConfig


def read_prompts():
    """Read prompts from prefill_output.txt"""
    prompts = []
    try:
        with open("prefill_output.txt") as f:
            for line in f:
                prompts.append(line.strip())
        print(f"Loaded {len(prompts)} prompts from prefill_output.txt")
        return prompts
    except FileNotFoundError:
        print("Error: prefill_output.txt file not found")
        exit(-1)


def main():
    prompts = read_prompts()
    sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=10)

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--simulate-failure", action="store_true", help="Simulate KV load failure."
    )
    parser.add_argument(
        "--async-load", action="store_true", help="Simulate async KV load"
    )
    args = parser.parse_args()

    if args.simulate_failure:
        ktc = KVTransferConfig(
            kv_connector="RogueSharedStorageConnector",
            kv_role="kv_both",
            kv_connector_extra_config={
                "shared_storage_path": "local_storage",
                "async_load": args.async_load,
            },
            kv_connector_module_path="rogue_shared_storage_connector",
        )
        out_file = (
            "async_decode_recovered_output.txt"
            if args.async_load
            else "sync_decode_recovered_output.txt"
        )
    else:
        ktc = KVTransferConfig(
            kv_connector="SharedStorageConnector",
            kv_role="kv_both",
            kv_connector_extra_config={
                "shared_storage_path": "local_storage",
            },
        )
        out_file = "decode_output.txt"

    llm = LLM(
        model="meta-llama/Llama-3.2-1B-Instruct",
        enforce_eager=True,
        gpu_memory_utilization=0.8,
        max_num_batched_tokens=64,
        max_num_seqs=16,
        kv_transfer_config=ktc,
    )

    outputs = llm.generate(prompts, sampling_params)

    sep_str = "-" * 30
    with open(out_file, "w", encoding="utf-8") as f:
        for output in outputs:
            prompt = output.prompt
            generated_text = output.outputs[0].text
            out_str = f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}"
            print(out_str)
            print(sep_str)
            f.write(out_str)
            f.write(sep_str)


if __name__ == "__main__":
    main()
```

`prefill_example.py` (58 additions, 0 deletions)

```python
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from vllm import LLM, SamplingParams
from vllm.config import KVTransferConfig


def read_prompts():
    context = "Hi " * 1000
    context2 = "Hey " * 500
    return [
        context + "Hello, my name is",
        context + "The capital of France is",
        context2 + "Your name is",
        context2 + "The capital of China is",
    ]


def main():
    prompts = read_prompts()

    sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=1)

    llm = LLM(
        model="meta-llama/Llama-3.2-1B-Instruct",
        enforce_eager=True,
        gpu_memory_utilization=0.8,
        kv_transfer_config=KVTransferConfig(
            kv_connector="SharedStorageConnector",
            kv_role="kv_both",
            kv_connector_extra_config={"shared_storage_path": "local_storage"},
        ),
    )  # , max_model_len=2048, max_num_batched_tokens=2048)

    # 1ST generation (prefill instance)
    outputs = llm.generate(
        prompts,
        sampling_params,
    )

    new_prompts = []
    print("-" * 30)
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        new_prompts.append(prompt + generated_text)
        print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
        print("-" * 30)

    # Write new_prompts to prefill_output.txt
    with open("prefill_output.txt", "w") as f:
        for prompt in new_prompts:
            f.write(prompt + "\n")
    print(f"Saved {len(new_prompts)} prompts to prefill_output.txt")


if __name__ == "__main__":
    main()
```

`rogue_shared_storage_connector.py` (145 additions, 0 deletions)

```python
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa: E501
import logging
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Optional

from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
    KVConnectorMetadata,
    KVConnectorRole,
)
from vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector import (
    SharedStorageConnector,
    SharedStorageConnectorMetadata,
)
from vllm.forward_context import ForwardContext
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.request import Request

if TYPE_CHECKING:
    from vllm.v1.core.sched.output import SchedulerOutput

logger = logging.getLogger()
logging.basicConfig(level=logging.INFO)


@dataclass
class RogueSharedStorageConnectorMetadata(SharedStorageConnectorMetadata):
    req_to_block_ids: dict[str, set[int]] = field(default_factory=dict)

    @classmethod
    def from_base(cls, base: SharedStorageConnectorMetadata):
        return cls(requests=base.requests)


class RogueSharedStorageConnector(SharedStorageConnector):
    def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
        super().__init__(vllm_config=vllm_config, role=role)
        self._async_load = vllm_config.kv_transfer_config.get_from_extra_config(
            "async_load", False
        )
        self._invalid_block_ids: set = None
        self._seen_requests: set = set()
        self._req_to_block_ids: dict[str, list[int]] = dict()

    def bind_connector_metadata(self, connector_metadata: KVConnectorMetadata) -> None:
        assert isinstance(connector_metadata, RogueSharedStorageConnectorMetadata)
        # Find the first load (non-store) request in this step, if any.
        index, failed_request = next(
            (
                (i, x)
                for i, x in enumerate(connector_metadata.requests)
                if not x.is_store
            ),
            (None, None),
        )
        if index is not None:
            # Drop the request from the metadata so its KV is never loaded,
            # and record its block ids as failed loads.
            del connector_metadata.requests[index]
            self._invalid_block_ids = set(
                (
                    failed_request.slot_mapping[:: self._block_size] // self._block_size
                ).tolist()
            )
            logger.info(
                "Simulating failure to load all KV blocks for the "
                "first load request. Total blocks: %d",
                len(self._invalid_block_ids),
            )
        super().bind_connector_metadata(connector_metadata)

    def clear_connector_metadata(self) -> None:
        self._invalid_block_ids = None
        super().clear_connector_metadata()

    def start_load_kv(self, forward_context: ForwardContext, **kwargs) -> None:
        if self._async_load and forward_context.attn_metadata is None:
            # Bypass sanity check in super().start_load_kv
            forward_context.attn_metadata = "None"

        super().start_load_kv(forward_context, **kwargs)

    def get_finished(
        self, finished_req_ids: set[str]
    ) -> tuple[Optional[set[str]], Optional[set[str]]]:
        if self._async_load:
            meta = self._get_connector_metadata()
            assert isinstance(meta, RogueSharedStorageConnectorMetadata)
            if meta.req_to_block_ids:
                # Report all load requests in this step as finished receiving.
                return None, set(meta.req_to_block_ids)

        return None, None

    def get_block_ids_with_load_errors(self) -> set[int]:
        return self._invalid_block_ids

    def get_num_new_matched_tokens(
        self,
        request: Request,
        num_computed_tokens: int,
    ) -> tuple[int, bool]:
        if request.request_id in self._seen_requests:
            # A request seen before (e.g. rescheduled after the simulated
            # failure) gets no external tokens, so it is recomputed normally.
            return 0, False

        self._seen_requests.add(request.request_id)

        num_tokens, _ = super().get_num_new_matched_tokens(request, num_computed_tokens)
        return num_tokens, self._async_load and num_tokens > 0

    def update_state_after_alloc(
        self, request: Request, blocks: KVCacheBlocks, num_external_tokens: int
    ):
        """
        Update KVConnector state after block allocation.

        If blocks were allocated, add to _requests_need_load,
        such that we load the KVs in the next forward pass.
        """
        super().update_state_after_alloc(request, blocks, num_external_tokens)

        if num_external_tokens > 0:
            self._req_to_block_ids[request.request_id] = blocks.get_block_ids()[0]

    def build_connector_meta(
        self,
        scheduler_output: "SchedulerOutput",
    ) -> KVConnectorMetadata:
        if not self._async_load:
            base = super().build_connector_meta(scheduler_output)
            meta = RogueSharedStorageConnectorMetadata.from_base(base)
        else:
            meta = RogueSharedStorageConnectorMetadata()
            if self._requests_need_load:
                for req_id, request in self._requests_need_load.items():
                    meta.add_request(
                        token_ids=request.prompt_token_ids,
                        block_ids=self._req_to_block_ids[req_id],
                        block_size=self._block_size,
                        is_store=False,
                        mm_hashes=[],
                    )
                # Clear state
                self._requests_need_load.clear()
        meta.req_to_block_ids = self._req_to_block_ids
        self._req_to_block_ids = dict()
        return meta
```
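
To make the block-id computation in `bind_connector_metadata` concrete, here is a small self-contained illustration with hypothetical numbers (the real slot mapping comes from the request being dropped):

```python
# Illustration only: taking every block_size-th entry of the slot mapping and
# integer-dividing by block_size yields the block ids the request would load.
import torch

block_size = 16
slot_mapping = torch.arange(64)  # hypothetical: 64 tokens mapped to slots 0..63
invalid_block_ids = set((slot_mapping[::block_size] // block_size).tolist())
print(invalid_block_ids)  # {0, 1, 2, 3}
```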

`run.sh` (33 additions, 0 deletions)

```bash
#!/bin/bash

# Constants
SHARED_STORAGE_DIR="local_storage"
PREFILL_OUTPUT="prefill_output.txt"
DECODE_OUTPUT="decode_output.txt"
SYNC_DECODE_RECOVERED_OUTPUT="sync_decode_recovered_output.txt"
ASYNC_DECODE_RECOVERED_OUTPUT="async_decode_recovered_output.txt"

# Cleanup
rm -rf "$SHARED_STORAGE_DIR"
rm -f "$PREFILL_OUTPUT" "$DECODE_OUTPUT" "$SYNC_DECODE_RECOVERED_OUTPUT" "$ASYNC_DECODE_RECOVERED_OUTPUT"

# Run inference examples
VLLM_ENABLE_V1_MULTIPROCESSING=0 CUDA_VISIBLE_DEVICES=0 python3 prefill_example.py
VLLM_ENABLE_V1_MULTIPROCESSING=0 CUDA_VISIBLE_DEVICES=0 python3 decode_example.py
VLLM_ENABLE_V1_MULTIPROCESSING=0 CUDA_VISIBLE_DEVICES=0 python3 decode_example.py --simulate-failure
VLLM_ENABLE_V1_MULTIPROCESSING=0 CUDA_VISIBLE_DEVICES=0 python3 decode_example.py --simulate-failure --async-load

# Compare outputs
if ! cmp -s "$DECODE_OUTPUT" "$SYNC_DECODE_RECOVERED_OUTPUT"; then
    echo "❌ Outputs differ: sync recovery failed."
    diff -u "$DECODE_OUTPUT" "$SYNC_DECODE_RECOVERED_OUTPUT"
    exit 1
fi

if ! cmp -s "$DECODE_OUTPUT" "$ASYNC_DECODE_RECOVERED_OUTPUT"; then
    echo "❌ Outputs differ: async recovery failed."
    diff -u "$DECODE_OUTPUT" "$ASYNC_DECODE_RECOVERED_OUTPUT"
    exit 1
fi

echo "✅ Outputs match: recovery successful."
```
