@@ -384,12 +384,14 @@ async def route_orchestrated_disaggregated_request(
384384 """
385385 Orchestrated disaggregated inference following NxDI's toy_proxy_server pattern.
386386
387- Flow:
388- 1. Send request to Prefill endpoint (with max_tokens=1)
389- 2. Get response with KV cache metadata
390- 3. Add disagg_prefill_resp to request
391- 4. Send to Decode endpoint
392- 5. Stream response back to client
387+ Flow (matches NxDI toy_proxy_server.py):
388+ 1. Send request to Prefill endpoint with kv_transfer_params and max_tokens=1
389+ 2. Get response containing kv_transfer_params with KV cache metadata
390+ 3. Extract kv_transfer_params, set remote_host to prefill endpoint
391+ 4. Forward kv_transfer_params to Decode endpoint
392+ 5. Stream decode response back to client
393+
394+ Reference: NxDI/examples/vllm/disaggregated_inference/toy_proxy_server.py
393395 """
394396 in_router_time = time .time ()
395397 request_id = request .headers .get ("X-Request-Id" ) or str (uuid .uuid4 ())
@@ -427,10 +429,25 @@ async def route_orchestrated_disaggregated_request(
427429 logger .info (f"[{ request_id } ] Sending prefill request to { prefill_api_url } " )
428430
429431 # Create prefill request with max_tokens=1 to optimize prefill step
432+ # Also add kv_transfer_params to enable disaggregated mode on prefill
433+ # Reference: NxDI toy_proxy_server.py
430434 prefill_request_json = request_json .copy ()
431435 prefill_request_json ["max_tokens" ] = 1
432436 if "max_completion_tokens" in prefill_request_json :
433437 prefill_request_json ["max_completion_tokens" ] = 1
438+ # Enable disaggregated inference mode - prefill will return kv_transfer_params
439+ prefill_request_json ["kv_transfer_params" ] = {
440+ "do_remote_decode" : True ,
441+ "do_remote_prefill" : False ,
442+ "remote_engine_id" : None ,
443+ "remote_block_ids" : None ,
444+ "remote_host" : None ,
445+ "remote_port" : None
446+ }
447+ # Disable streaming for prefill to get full response with kv_transfer_params
448+ prefill_request_json ["stream" ] = False
449+ if "stream_options" in prefill_request_json :
450+ del prefill_request_json ["stream_options" ]
434451
435452 st = time .time ()
436453 is_streaming = request_json .get ("stream" , False )
@@ -463,9 +480,17 @@ async def route_orchestrated_disaggregated_request(
463480 logger .info (f"[{ request_id } ] Prefill completed in { et - st :.4f} s (TTFT)" )
464481 logger .debug (f"[{ request_id } ] Prefill response keys: { prefill_data .keys () if isinstance (prefill_data , dict ) else 'not a dict' } " )
465482
466- # Step 2: Add prefill metadata and send to Decode
483+ # Step 2: Extract kv_transfer_params and send to Decode
484+ # kv_transfer_params is the vLLM/NxDI-supported field for KV cache handoff
485+ # Reference: NxDI toy_proxy_server.py
467486 decode_request = request_json .copy ()
468- decode_request ["disagg_prefill_resp" ] = prefill_data
487+ kv_transfer_params = prefill_data .get ("kv_transfer_params" , {})
488+ if kv_transfer_params :
489+ # Set remote_host to prefill endpoint for KV cache retrieval
490+ kv_transfer_params ["remote_host" ] = prefill_url .split ("://" )[1 ].split (":" )[0 ]
491+ decode_request ["kv_transfer_params" ] = kv_transfer_params
492+ else :
493+ logger .warning (f"[{ request_id } ] Prefill response did not contain kv_transfer_params" )
469494
470495 decode_api_url = f"{ decode_url } { endpoint } "
471496 logger .info (f"[{ request_id } ] Sending decode request to { decode_api_url } " )
0 commit comments