diff --git a/proposals/disaggregated-prefill-orchestrated-routing.md b/proposals/disaggregated-prefill-orchestrated-routing.md
new file mode 100644
index 000000000..72f0eed08
--- /dev/null
+++ b/proposals/disaggregated-prefill-orchestrated-routing.md
@@ -0,0 +1,168 @@
+# Disaggregated Prefill Orchestrated Routing
+
+## Table of Contents
+
+- [Summary](#summary)
+- [Motivation](#motivation)
+- [Proposal](#proposal)
+
+## Summary
+
+This proposal adds a new routing algorithm, `disaggregated_prefill_orchestrated`, to the vLLM Production Stack router. It enables prefill/decode disaggregation in which the router orchestrates the request flow between dedicated prefill and decode pods, forwarding KV cache transfer metadata between them. This complements LMCache-based disaggregated inference by supporting backends with custom `kv_connector` implementations (e.g., NIXL, NCCL).
+
+## Motivation
+
+Disaggregated inference separates the compute-bound prefill phase from the memory-bound decode phase. This architectural pattern is increasingly important for:
+
+- **Independent scaling** - Prefill and decode pods can scale on different metrics (prompt throughput vs. generation throughput)
+- **Heterogeneous hardware** - Prefill and decode can run on different hardware profiles optimized for their workloads
+- **Better resource utilization** - Under high concurrency, avoiding co-located P/D reduces resource contention
+
+### Goals
+
+- Add `disaggregated_prefill_orchestrated` as a new routing logic option
+- Enable the router to identify and route to prefill vs. decode pods via labels
+- Orchestrate the P→D request flow, extracting and forwarding KV transfer metadata
+- Leverage existing K8s service discovery infrastructure
+- Support streaming responses from the decode phase
+
+### Non-Goals
+
+- Modifying LMCache-based disaggregated inference
+- Implementing the underlying KV cache transfer mechanism (handled by vLLM backends)
+- Autoscaling logic (handled by KEDA with vLLM metrics)
+- Supporting non-Kubernetes deployments in this initial implementation
+
+## Proposal
+
+### Two Disaggregated Inference Approaches
+
+| Approach | KV Transfer | Router Role | Use Case |
+|----------|-------------|-------------|----------|
+| **LMCache-based DI** | LMCache + NIXL | Transparent routing | GPU clusters with LMCache |
+| **Router-orchestrated DI** (this proposal) | vLLM native `kv_transfer_config` | Orchestrates P→D flow | Any backend with a `kv_connector` |
+
+### Proposed Changes
+
+**Architecture:**
+
+```
+ ┌──────────┐      ①      ┌─────────────────────────────────────┐
+ │  Client  │────────────▶│  Router (disaggregated_prefill_     │
+ │  Request │             │          orchestrated)              │
+ └──────────┘             └──────────────────┬──────────────────┘
+                                             │
+                     ②                       │        ③
+      ┌──────────────┐                       │        ┌──────────────┐
+      │   Prefill    │◀──────────────────────┼───────▶│    Decode    │
+      │     Pod      │                       │        │     Pod      │
+      │              │──────── KV ID ────────┼───────▶│              │
+      │  (producer)  │                       │        │  (consumer)  │
+      └──────────────┘                       │        └──────────────┘
+                                             │
+ ┌──────────┐      ④                         │
+ │  Stream  │◀───────────────────────────────┘
+ │ Response │
+ └──────────┘
+```
+
+**Request Flow:**
+1. Client sends `/v1/chat/completions` to the Router
+2. Router forwards the request to a Prefill pod with `max_tokens=1` and producer-side `kv_transfer_params`
+3. Prefill returns KV transfer metadata in the `kv_transfer_params` response field
+4. Router forwards the request to a Decode pod with the original `max_tokens` plus the transfer metadata
+5. Decode streams the response back to the client
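+
+As a condensed illustration of this flow, the sketch below chains the two calls with plain `aiohttp` (non-streaming for brevity; the pod URLs, payload values, and the `orchestrate` helper are hypothetical — the actual implementation in `request.py` adds service discovery, streaming, and error handling):
+
+```python
+import asyncio
+
+import aiohttp
+
+PREFILL_URL = "http://prefill-pod:8000"  # illustrative addresses
+DECODE_URL = "http://decode-pod:8000"
+
+
+async def orchestrate(request_json: dict) -> dict:
+    async with aiohttp.ClientSession() as session:
+        # Step 1: prefill with max_tokens=1; kv_transfer_params marks this
+        # request as the producer side of a disaggregated transfer.
+        prefill_req = {
+            **request_json,
+            "max_tokens": 1,
+            "stream": False,
+            "kv_transfer_params": {"do_remote_decode": True, "do_remote_prefill": False},
+        }
+        async with session.post(f"{PREFILL_URL}/v1/chat/completions", json=prefill_req) as resp:
+            prefill_data = await resp.json()
+
+        # Step 2: forward the returned KV metadata to the decode pod, pointing
+        # remote_host at the prefill pod that now holds the KV cache.
+        kv_params = prefill_data.get("kv_transfer_params", {})
+        kv_params["remote_host"] = "prefill-pod"
+        decode_req = {**request_json, "kv_transfer_params": kv_params}
+        async with session.post(f"{DECODE_URL}/v1/chat/completions", json=decode_req) as resp:
+            return await resp.json()
+
+
+# Example: asyncio.run(orchestrate({"model": "my-model", "max_tokens": 64,
+#                                   "messages": [{"role": "user", "content": "Hi"}]}))
+```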
+
+### Implementation Details/Notes/Constraints
+
+**Architecture / Components:**
+- `src/vllm_router/routers/routing_logic.py` - New `DisaggregatedPrefillOrchestratedRouter` class
+- `src/vllm_router/parsers/parser.py` - New CLI arguments for prefill/decode labels
+- `src/vllm_router/services/request_service/request.py` - New `route_orchestrated_disaggregated_request()` function
+
+**Interface Changes:**
+
+New CLI arguments:
+
+| Argument | Description |
+|----------|-------------|
+| `--routing-logic=disaggregated_prefill_orchestrated` | Enable orchestrated disaggregated routing |
+| `--prefill-model-labels=prefill` | Model label(s) identifying prefill pods |
+| `--decode-model-labels=decode` | Model label(s) identifying decode pods |
+
+Pod labels required (see the sketch below for how the router consumes them):
+
+```yaml
+# Prefill deployment
+metadata:
+  labels:
+    app: prefill
+    model: prefill
+
+# Decode deployment
+metadata:
+  labels:
+    app: decode
+    model: decode
+```
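+
+A minimal, self-contained sketch of the label-based partitioning and round-robin selection described above (the `Endpoint` dataclass is a simplified stand-in for the router's `EndpointInfo`; URLs are illustrative):
+
+```python
+from dataclasses import dataclass
+from typing import List, Tuple
+
+
+@dataclass
+class Endpoint:
+    url: str
+    model_label: str  # populated from the pod's `model` label
+
+
+def split_endpoints(
+    endpoints: List[Endpoint], prefill_labels: List[str], decode_labels: List[str]
+) -> Tuple[List[Endpoint], List[Endpoint]]:
+    """Partition discovered pods into prefill and decode groups by label."""
+    prefill = [e for e in endpoints if e.model_label in prefill_labels]
+    decode = [e for e in endpoints if e.model_label in decode_labels]
+    if not prefill or not decode:
+        raise ValueError("need at least one prefill and one decode endpoint")
+    return prefill, decode
+
+
+endpoints = [
+    Endpoint("http://10.0.0.1:8000", "prefill"),
+    Endpoint("http://10.0.0.2:8000", "decode"),
+    Endpoint("http://10.0.0.3:8000", "decode"),
+]
+prefill, decode = split_endpoints(endpoints, ["prefill"], ["decode"])
+
+# Round-robin over the sorted decode pods, as the router does per request:
+for i in range(3):
+    pods = sorted(decode, key=lambda e: e.url)
+    print(pods[i % len(pods)].url)  # .2, .3, .2
+```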
+
+**Performance Considerations:**
+- Adds one HTTP round-trip (router → prefill) before decode streaming begins
+- The prefill request is non-streaming (`max_tokens=1`) so the KV transfer metadata can be read from the full response
+- The decode request streams normally
+- No additional memory overhead in the router
+
+**Resource Constraints:**
+- Minimal CPU overhead for JSON parsing of the prefill response
+- No GPU resources required by the router
+
+### Test Plans
+
+**Unit Tests:**
+- Test `DisaggregatedPrefillOrchestratedRouter.route_request()` returns correct endpoints
+- Test prefill/decode endpoint filtering based on model labels
+- Test KV transfer params extraction from the prefill response
+
+**Integration/E2E Tests:**
+- Deploy prefill + decode + router pods
+- Send a chat completion request
+- Verify the response contains decode output
+- Verify logs show the correct P→D flow
+
+**Negative Tests:**
+- No prefill endpoints available → 503 Service Unavailable
+- No decode endpoints available → 503 Service Unavailable
+- Prefill response missing `kv_transfer_params` → logged warning and graceful handling
+
+## Drawbacks
+
+- **Added latency** - One additional HTTP round-trip for the prefill phase
+- **Complexity** - Users must configure prefill/decode pods with the correct labels
+- **Backend dependency** - Requires vLLM backends to support `kv_transfer_config`
+
+## Alternatives
+
+1. **Do nothing** - Users would need a separate proxy (e.g., toy_proxy_server.py) outside production-stack
+2. **Transparent routing only** - Let LMCache handle everything, but this doesn't support custom kv_connectors
+3. **gRPC between P/D** - More complex and requires protocol changes
+
+This proposal is the best approach because it:
+- Leverages existing router infrastructure
+- Follows established routing_logic patterns
+- Supports any kv_connector backend
+- Enables KEDA-based independent scaling
+
+## Implementation Timeline / Phases
+
+**Phase 1 (Complete):** Core implementation
+- DisaggregatedPrefillOrchestratedRouter class
+- CLI argument parsing
+- Orchestrated request flow
+
+**Phase 2 (TODO):** Testing & Documentation
+- Unit tests
+- E2E tests
+- Documentation updates
+
+## References
+
+- [2025 Q1 Roadmap - Support for disaggregated prefill](https://github.com/vllm-project/production-stack/issues/26)
+- [vLLM Disaggregated Prefill](https://docs.vllm.ai/en/latest/serving/distributed_serving.html)
diff --git a/pyproject.toml b/pyproject.toml
index 066096a13..bd58fc93d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -3,7 +3,7 @@ name = "vllm-router"
 dynamic = ["version"]
 description = "The router for vLLM"
 readme = "README.md"
-requires-python = ">=3.12"
+requires-python = ">=3.10"
 license = {text = "Apache-2.0"}
 classifiers = [
     "Operating System :: OS Independent",
diff --git a/src/vllm_router/parsers/parser.py b/src/vllm_router/parsers/parser.py
index 8b12cf983..6e19c9ace 100644
--- a/src/vllm_router/parsers/parser.py
+++ b/src/vllm_router/parsers/parser.py
@@ -209,6 +209,7 @@ def parse_args():
             "kvaware",
             "prefixaware",
             "disaggregated_prefill",
+            "disaggregated_prefill_orchestrated",
         ],
         help="The routing logic to use",
     )
diff --git a/src/vllm_router/routers/routing_logic.py b/src/vllm_router/routers/routing_logic.py
index 1e0e4e7d3..5a0c47df2 100644
--- a/src/vllm_router/routers/routing_logic.py
+++ b/src/vllm_router/routers/routing_logic.py
@@ -55,6 +55,7 @@ class RoutingLogic(str, enum.Enum):
     KVAWARE = "kvaware"
     PREFIXAWARE = "prefixaware"
     DISAGGREGATED_PREFILL = "disaggregated_prefill"
+    DISAGGREGATED_PREFILL_ORCHESTRATED = "disaggregated_prefill_orchestrated"
 
 
 class RoutingInterface(metaclass=SingletonABCMeta):
@@ -515,6 +516,101 @@ def route_request(
         return decoder_endpoints[0].url
 
 
+class DisaggregatedPrefillOrchestratedRouter(RoutingInterface):
+    """
+    Orchestrates disaggregated inference in a single request by chaining Prefill → Decode.
+
+    Unlike DisaggregatedPrefillRouter (which requires two separate client requests),
+    this router handles the entire flow internally:
+    1. Receives the request from the client
+    2. Forwards it to a Prefill endpoint with kv_transfer_params to enable disaggregated mode
+    3. Gets the prefill response whose kv_transfer_params carry the KV cache metadata
+    4. Extracts kv_transfer_params, sets remote_host, and forwards to Decode
+    5. Streams the decode response back to the client
+
+    This is designed for NxDI (NeuronX Distributed Inference) on AWS Trainium,
+    following NxDI's toy_proxy_server.py pattern.
+
+    Reference: NxDI/examples/vllm/disaggregated_inference/toy_proxy_server.py
+
+    Load balancing: Uses round-robin across available prefill and decode pods.
+ """ + + def __init__(self, prefill_model_labels: List[str], decode_model_labels: List[str]): + if hasattr(self, "_initialized"): + return + self.prefill_model_labels = prefill_model_labels or [] + self.decode_model_labels = decode_model_labels or [] + # Round-robin counters for load balancing across xPyD pods + self.prefill_idx = 0 + self.decode_idx = 0 + self._initialized = True + logger.info( + f"Initialized DisaggregatedPrefillOrchestratedRouter with " + f"prefill_labels={self.prefill_model_labels}, " + f"decode_labels={self.decode_model_labels}" + ) + + def _find_endpoints(self, endpoints: List[EndpointInfo]): + """Find prefill and decode endpoints based on model labels.""" + prefiller_endpoints = [ + e for e in endpoints if e.model_label in self.prefill_model_labels + ] + decoder_endpoints = [ + e for e in endpoints if e.model_label in self.decode_model_labels + ] + + if not prefiller_endpoints: + raise ValueError( + f"No prefill endpoints found with labels {self.prefill_model_labels}. " + f"Available endpoints: {[(e.url, e.model_label) for e in endpoints]}" + ) + if not decoder_endpoints: + raise ValueError( + f"No decode endpoints found with labels {self.decode_model_labels}. " + f"Available endpoints: {[(e.url, e.model_label) for e in endpoints]}" + ) + + return prefiller_endpoints, decoder_endpoints + + def select_prefill_endpoint(self, prefiller_endpoints: List[EndpointInfo]) -> EndpointInfo: + """Select prefill endpoint using round-robin load balancing.""" + if not prefiller_endpoints: + raise ValueError("No prefill endpoints available") + # Sort for consistency across requests + sorted_endpoints = sorted(prefiller_endpoints, key=lambda e: e.url) + selected = sorted_endpoints[self.prefill_idx % len(sorted_endpoints)] + self.prefill_idx += 1 + return selected + + def select_decode_endpoint(self, decoder_endpoints: List[EndpointInfo]) -> EndpointInfo: + """Select decode endpoint using round-robin load balancing.""" + if not decoder_endpoints: + raise ValueError("No decode endpoints available") + # Sort for consistency across requests + sorted_endpoints = sorted(decoder_endpoints, key=lambda e: e.url) + selected = sorted_endpoints[self.decode_idx % len(sorted_endpoints)] + self.decode_idx += 1 + return selected + + async def route_request( + self, + endpoints: List[EndpointInfo], + engine_stats: Dict[str, EngineStats], + request_stats: Dict[str, RequestStats], + request: Request, + request_json: Dict, + ) -> str: + """ + This method is called by the router framework but for orchestrated routing, + we need to handle the full flow differently. This returns the prefill URL + as a placeholder - the actual orchestration happens in route_orchestrated_disaggregated_request. 
+ """ + prefiller_endpoints, _ = self._find_endpoints(endpoints) + # Return prefill URL - actual orchestration is done in request.py + return prefiller_endpoints[0].url + + # Instead of managing a global _global_router, we can define the initialization functions as: def initialize_routing_logic( routing_logic: RoutingLogic, *args, **kwargs @@ -542,6 +638,11 @@ def initialize_routing_logic( return DisaggregatedPrefillRouter( kwargs.get("prefill_model_labels"), kwargs.get("decode_model_labels") ) + elif routing_logic == RoutingLogic.DISAGGREGATED_PREFILL_ORCHESTRATED: + logger.info("Initializing disaggregated prefill orchestrated routing logic (NxDI)") + return DisaggregatedPrefillOrchestratedRouter( + kwargs.get("prefill_model_labels"), kwargs.get("decode_model_labels") + ) else: raise ValueError(f"Invalid routing logic {routing_logic}") @@ -562,6 +663,7 @@ def get_routing_logic() -> RoutingInterface: KvawareRouter, PrefixAwareRouter, DisaggregatedPrefillRouter, + DisaggregatedPrefillOrchestratedRouter, ): if cls in SingletonABCMeta._instances: return cls() @@ -576,6 +678,7 @@ def cleanup_routing_logic(): KvawareRouter, PrefixAwareRouter, DisaggregatedPrefillRouter, + DisaggregatedPrefillOrchestratedRouter, ): if cls in SingletonABCMeta._instances: instance = cls() diff --git a/src/vllm_router/services/request_service/request.py b/src/vllm_router/services/request_service/request.py index be3b0b983..4cda67907 100644 --- a/src/vllm_router/services/request_service/request.py +++ b/src/vllm_router/services/request_service/request.py @@ -27,6 +27,7 @@ from vllm_router.log import init_logger from vllm_router.routers.routing_logic import ( DisaggregatedPrefillRouter, + DisaggregatedPrefillOrchestratedRouter, KvawareRouter, PrefixAwareRouter, SessionRouter, @@ -179,6 +180,13 @@ async def route_general_request( request, endpoint, background_tasks ) return response + + # Handle orchestrated disaggregated inference (NxDI pattern) + if isinstance(request.app.state.router, DisaggregatedPrefillOrchestratedRouter): + response = await route_orchestrated_disaggregated_request( + request, endpoint, background_tasks + ) + return response in_router_time = time.time() # Same as vllm, Get request_id from X-Request-Id header if available request_id = request.headers.get("X-Request-Id") or str(uuid.uuid4()) @@ -368,6 +376,191 @@ async def send_request_to_decode( yield chunk +async def route_orchestrated_disaggregated_request( + request: Request, + endpoint: str, + background_tasks: BackgroundTasks, +): + """ + Orchestrated disaggregated inference following NxDI's toy_proxy_server pattern. + + Flow (matches NxDI toy_proxy_server.py): + 1. Send request to Prefill endpoint with kv_transfer_params and max_tokens=1 + 2. Get response containing kv_transfer_params with KV cache metadata + 3. Extract kv_transfer_params, set remote_host to prefill endpoint + 4. Forward kv_transfer_params to Decode endpoint + 5. 
Stream decode response back to client
+
+    Reference: NxDI/examples/vllm/disaggregated_inference/toy_proxy_server.py
+    """
+    in_router_time = time.time()
+    request_id = request.headers.get("X-Request-Id") or str(uuid.uuid4())
+    request_json = await request.json()
+
+    logger.info(f"[{request_id}] Starting orchestrated disaggregated inference")
+
+    # Get endpoints from service discovery
+    service_discovery = get_service_discovery()
+    endpoints = service_discovery.get_endpoint_info()
+
+    # Use the router's _find_endpoints method to get prefill and decode endpoints
+    router = request.app.state.router
+    try:
+        prefiller_endpoints, decoder_endpoints = router._find_endpoints(endpoints)
+    except ValueError as e:
+        logger.error(f"[{request_id}] Endpoint discovery failed: {e}")
+        return JSONResponse(
+            status_code=503,
+            content={"error": str(e)},
+            headers={"X-Request-Id": request_id},
+        )
+
+    # Use round-robin load balancing to select prefill and decode endpoints
+    prefill_endpoint = router.select_prefill_endpoint(prefiller_endpoints)
+    decode_endpoint = router.select_decode_endpoint(decoder_endpoints)
+    prefill_url = prefill_endpoint.url
+    decode_url = decode_endpoint.url
+
+    logger.info(f"[{request_id}] Prefill endpoint: {prefill_url}")
+    logger.info(f"[{request_id}] Decode endpoint: {decode_url}")
+
+    # Step 1: Send to Prefill with max_tokens=1
+    prefill_api_url = f"{prefill_url}{endpoint}"
+    logger.info(f"[{request_id}] Sending prefill request to {prefill_api_url}")
+
+    # Create the prefill request with max_tokens=1 to optimize the prefill step.
+    # Also add kv_transfer_params to enable disaggregated mode on prefill.
+    # Reference: NxDI toy_proxy_server.py
+    prefill_request_json = request_json.copy()
+    prefill_request_json["max_tokens"] = 1
+    if "max_completion_tokens" in prefill_request_json:
+        prefill_request_json["max_completion_tokens"] = 1
+    # Enable disaggregated inference mode - prefill will return kv_transfer_params
+    prefill_request_json["kv_transfer_params"] = {
+        "do_remote_decode": True,
+        "do_remote_prefill": False,
+        "remote_engine_id": None,
+        "remote_block_ids": None,
+        "remote_host": None,
+        "remote_port": None,
+    }
+    # Disable streaming for prefill to get the full response with kv_transfer_params
+    prefill_request_json["stream"] = False
+    if "stream_options" in prefill_request_json:
+        del prefill_request_json["stream_options"]
+
+    st = time.time()
+    is_streaming = request_json.get("stream", False)
+
+    try:
+        # Use the shared aiohttp client from app state
+        client = request.app.state.aiohttp_client_wrapper()
+
+        # Send to Prefill
+        async with client.post(
+            prefill_api_url,
+            json=prefill_request_json,
+            headers={
+                "Content-Type": "application/json",
+                "X-Request-Id": request_id,
+            },
+            timeout=aiohttp.ClientTimeout(total=300),
+        ) as prefill_resp:
+            if prefill_resp.status != 200:
+                error_text = await prefill_resp.text()
+                logger.error(
+                    f"[{request_id}] Prefill failed with status {prefill_resp.status}: {error_text}"
+                )
+                return JSONResponse(
+                    status_code=prefill_resp.status,
+                    content={"error": f"Prefill failed: {error_text}"},
+                    headers={"X-Request-Id": request_id},
+                )
+
+            prefill_data = await prefill_resp.json()
+            et = time.time()
+            logger.info(f"[{request_id}] Prefill completed in {et - st:.4f}s (TTFT)")
+            logger.debug(
+                f"[{request_id}] Prefill response keys: "
+                f"{prefill_data.keys() if isinstance(prefill_data, dict) else 'not a dict'}"
+            )
+
+            # Step 2: Extract kv_transfer_params and send to Decode.
+            # kv_transfer_params is the vLLM/NxDI-supported field for KV cache handoff.
+            # Reference: NxDI toy_proxy_server.py
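+            # For illustration, the prefill response's kv_transfer_params is
+            # expected to mirror the request fields populated above, along the
+            # lines of (values hypothetical):
+            #   {"do_remote_prefill": True, "do_remote_decode": False,
+            #    "remote_engine_id": "<engine-id>", "remote_block_ids": [...],
+            #    "remote_host": None, "remote_port": <port>}
+            # The router only fills in remote_host below; the kv_connector on
+            # the decode side consumes the rest unchanged.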
+            decode_request = request_json.copy()
+            kv_transfer_params = prefill_data.get("kv_transfer_params", {})
+            if kv_transfer_params:
+                # Set remote_host to the prefill host for KV cache retrieval
+                # (assumes a scheme://host:port endpoint URL)
+                kv_transfer_params["remote_host"] = prefill_url.split("://")[1].split(":")[0]
+                decode_request["kv_transfer_params"] = kv_transfer_params
+            else:
+                logger.warning(f"[{request_id}] Prefill response did not contain kv_transfer_params")
+
+            decode_api_url = f"{decode_url}{endpoint}"
+            logger.info(f"[{request_id}] Sending decode request to {decode_api_url}")
+
+            # Send to Decode. The response is fetched without a context manager:
+            # for streaming requests it must stay open after this function
+            # returns, until the client has consumed the stream. An `async with`
+            # here would close the connection before the StreamingResponse
+            # generator ever ran.
+            decode_resp = await client.post(
+                decode_api_url,
+                json=decode_request,
+                headers={
+                    "Content-Type": "application/json",
+                    "X-Request-Id": request_id,
+                },
+                timeout=aiohttp.ClientTimeout(total=600),
+            )
+            if decode_resp.status != 200:
+                error_text = await decode_resp.text()
+                decode_resp.release()
+                logger.error(f"[{request_id}] Decode failed with status {decode_resp.status}: {error_text}")
+                return JSONResponse(
+                    status_code=decode_resp.status,
+                    content={"error": f"Decode failed: {error_text}"},
+                    headers={"X-Request-Id": request_id},
+                )
+
+            if is_streaming:
+                # For streaming, yield chunks as they arrive (true streaming)
+                async def generate_stream():
+                    try:
+                        async for chunk in decode_resp.content.iter_any():
+                            if chunk:
+                                yield chunk
+                    finally:
+                        # Release the connection only once the stream is drained
+                        decode_resp.release()
+                        curr_time = time.time()
+                        logger.info(
+                            f"[{request_id}] Orchestrated streaming request completed, "
+                            f"total time = {curr_time - in_router_time:.4f}s"
+                        )
+
+                return StreamingResponse(
+                    generate_stream(),
+                    media_type="text/event-stream",
+                    headers={"X-Request-Id": request_id},
+                )
+            else:
+                # For non-streaming, read the full response
+                response_data = await decode_resp.read()
+                decode_resp.release()
+
+                curr_time = time.time()
+                logger.info(
+                    f"[{request_id}] Orchestrated request completed, "
+                    f"total time = {curr_time - in_router_time:.4f}s"
+                )
+
+                return JSONResponse(
+                    content=json.loads(response_data),
+                    headers={"X-Request-Id": request_id},
+                )
+
+    except aiohttp.ClientError as e:
+        logger.error(f"[{request_id}] HTTP error during orchestrated request: {e}", exc_info=True)
+        return JSONResponse(
+            status_code=503,
+            content={"error": f"HTTP error: {str(e)}"},
+            headers={"X-Request-Id": request_id},
+        )
+    except Exception as e:
+        logger.error(f"[{request_id}] Unexpected error during orchestrated request: {e}", exc_info=True)
+        return JSONResponse(
+            status_code=500,
+            content={"error": f"Unexpected error: {str(e)}"},
+            headers={"X-Request-Id": request_id},
+        )
+
+
 async def route_disaggregated_prefill_request(
     request: Request,
     endpoint: str,