Skip to content

Commit 49afe4a

Browse files
committed
fix: use kv_transfer_params instead of disagg_prefill_resp
- Add kv_transfer_params to prefill request to enable disaggregated mode - Extract kv_transfer_params from prefill response and forward to decode - Set remote_host to prefill endpoint for KV cache retrieval
1 parent 2e4d2d2 commit 49afe4a

File tree

2 files changed

+39
-12
lines changed

2 files changed

+39
-12
lines changed

src/vllm_router/routers/routing_logic.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -523,13 +523,15 @@ class DisaggregatedPrefillOrchestratedRouter(RoutingInterface):
523523
Unlike DisaggregatedPrefillRouter (which requires 2 separate client requests),
524524
this router handles the entire flow internally:
525525
1. Receives request from client
526-
2. Forwards to Prefill endpoint
527-
3. Gets prefill response with KV cache metadata
528-
4. Adds disagg_prefill_resp to request and forwards to Decode
526+
2. Forwards to Prefill endpoint with kv_transfer_params to enable disaggregated mode
527+
3. Gets prefill response with kv_transfer_params containing KV cache metadata
528+
4. Extracts kv_transfer_params, sets remote_host, and forwards to Decode
529529
5. Streams decode response back to client
530530
531531
This is designed for NxDI (Neuronx Distributed Inference) on AWS Trainium,
532-
similar to NxDI's toy_proxy_server.py pattern.
532+
following NxDI's toy_proxy_server.py pattern.
533+
534+
Reference: NxDI/examples/vllm/disaggregated_inference/toy_proxy_server.py
533535
534536
Load balancing: Uses round-robin across available prefill and decode pods.
535537
"""

src/vllm_router/services/request_service/request.py

Lines changed: 33 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -384,12 +384,14 @@ async def route_orchestrated_disaggregated_request(
384384
"""
385385
Orchestrated disaggregated inference following NxDI's toy_proxy_server pattern.
386386
387-
Flow:
388-
1. Send request to Prefill endpoint (with max_tokens=1)
389-
2. Get response with KV cache metadata
390-
3. Add disagg_prefill_resp to request
391-
4. Send to Decode endpoint
392-
5. Stream response back to client
387+
Flow (matches NxDI toy_proxy_server.py):
388+
1. Send request to Prefill endpoint with kv_transfer_params and max_tokens=1
389+
2. Get response containing kv_transfer_params with KV cache metadata
390+
3. Extract kv_transfer_params, set remote_host to prefill endpoint
391+
4. Forward kv_transfer_params to Decode endpoint
392+
5. Stream decode response back to client
393+
394+
Reference: NxDI/examples/vllm/disaggregated_inference/toy_proxy_server.py
393395
"""
394396
in_router_time = time.time()
395397
request_id = request.headers.get("X-Request-Id") or str(uuid.uuid4())
@@ -427,10 +429,25 @@ async def route_orchestrated_disaggregated_request(
427429
logger.info(f"[{request_id}] Sending prefill request to {prefill_api_url}")
428430

429431
# Create prefill request with max_tokens=1 to optimize prefill step
432+
# Also add kv_transfer_params to enable disaggregated mode on prefill
433+
# Reference: NxDI toy_proxy_server.py
430434
prefill_request_json = request_json.copy()
431435
prefill_request_json["max_tokens"] = 1
432436
if "max_completion_tokens" in prefill_request_json:
433437
prefill_request_json["max_completion_tokens"] = 1
438+
# Enable disaggregated inference mode - prefill will return kv_transfer_params
439+
prefill_request_json["kv_transfer_params"] = {
440+
"do_remote_decode": True,
441+
"do_remote_prefill": False,
442+
"remote_engine_id": None,
443+
"remote_block_ids": None,
444+
"remote_host": None,
445+
"remote_port": None
446+
}
447+
# Disable streaming for prefill to get full response with kv_transfer_params
448+
prefill_request_json["stream"] = False
449+
if "stream_options" in prefill_request_json:
450+
del prefill_request_json["stream_options"]
434451

435452
st = time.time()
436453
is_streaming = request_json.get("stream", False)
@@ -463,9 +480,17 @@ async def route_orchestrated_disaggregated_request(
463480
logger.info(f"[{request_id}] Prefill completed in {et - st:.4f}s (TTFT)")
464481
logger.debug(f"[{request_id}] Prefill response keys: {prefill_data.keys() if isinstance(prefill_data, dict) else 'not a dict'}")
465482

466-
# Step 2: Add prefill metadata and send to Decode
483+
# Step 2: Extract kv_transfer_params and send to Decode
484+
# kv_transfer_params is the vLLM/NxDI-supported field for KV cache handoff
485+
# Reference: NxDI toy_proxy_server.py
467486
decode_request = request_json.copy()
468-
decode_request["disagg_prefill_resp"] = prefill_data
487+
kv_transfer_params = prefill_data.get("kv_transfer_params", {})
488+
if kv_transfer_params:
489+
# Set remote_host to prefill endpoint for KV cache retrieval
490+
kv_transfer_params["remote_host"] = prefill_url.split("://")[1].split(":")[0]
491+
decode_request["kv_transfer_params"] = kv_transfer_params
492+
else:
493+
logger.warning(f"[{request_id}] Prefill response did not contain kv_transfer_params")
469494

470495
decode_api_url = f"{decode_url}{endpoint}"
471496
logger.info(f"[{request_id}] Sending decode request to {decode_api_url}")

0 commit comments

Comments
 (0)