| 1 | +""" |
| 2 | +This file serves as a documentation example and CI test for data parallel + prefill-decode disaggregation. |
| 3 | +
|
| 4 | +Structure: |
| 5 | +1. Monkeypatch setup: Ensures serve.run is non-blocking and removes accelerator requirements for CI testing. |
| 6 | +2. Docs example (between __dp_pd_example_start/end__): Embedded in Sphinx docs via literalinclude. |
| 7 | +3. Test validation (deployment status polling + cleanup) |
| 8 | +""" |
| 9 | + |
| 10 | +import time |
| 11 | +from ray import serve |
| 12 | +from ray.serve.schema import ApplicationStatus |
| 13 | +from ray.serve._private.constants import SERVE_DEFAULT_APP_NAME |
| 14 | +from ray.serve import llm |
| 15 | +from ray.serve.llm.deployment import PDProxyServer |
| 16 | +from ray.serve.llm.ingress import OpenAiIngress, make_fastapi_ingress |

# Check if NIXL is available (required for the NixlConnector KV transfer backend)
try:
    import nixl  # noqa: F401
    NIXL_AVAILABLE = True
except ImportError:
    NIXL_AVAILABLE = False

if not NIXL_AVAILABLE:
    raise ImportError(
        "NIXL is required for this example but is not installed. "
        "Install it with `pip install nixl` or `uv pip install nixl`."
    )

_original_serve_run = serve.run
_original_build_dp_deployment = llm.build_dp_deployment


def _non_blocking_serve_run(app, **kwargs):
    """Force blocking=False so the CI test can continue past serve.run."""
    kwargs["blocking"] = False
    return _original_serve_run(app, **kwargs)


def _testing_build_dp_deployment(llm_config, **kwargs):
    """Drop accelerator requirements so the example runs on CPU-only CI."""
    if llm_config.accelerator_type is not None:
        llm_config.accelerator_type = None
    return _original_build_dp_deployment(llm_config, **kwargs)


# Patch the public APIs that the docs example below uses so it can run in CI.
serve.run = _non_blocking_serve_run
llm.build_dp_deployment = _testing_build_dp_deployment

# __dp_pd_example_start__
from ray import serve
from ray.serve.llm import LLMConfig, build_dp_deployment
from ray.serve.llm.deployment import PDProxyServer
from ray.serve.llm.ingress import OpenAiIngress, make_fastapi_ingress

# Configure prefill with data parallel attention
prefill_config = LLMConfig(
    model_loading_config={
        "model_id": "Qwen/Qwen2.5-0.5B-Instruct",
    },
    engine_kwargs={
        "data_parallel_size": 2,  # 2 DP replicas for prefill
        "tensor_parallel_size": 1,
        "kv_transfer_config": {
            "kv_connector": "NixlConnector",
            "kv_role": "kv_both",
        },
    },
    experimental_configs={
        "dp_size_per_node": 2,
    },
)

# Configure decode with data parallel attention
decode_config = LLMConfig(
    model_loading_config={
        "model_id": "Qwen/Qwen2.5-0.5B-Instruct",
    },
    engine_kwargs={
        "data_parallel_size": 2,  # 2 DP replicas for decode (4 GPUs total across prefill + decode)
        "tensor_parallel_size": 1,
        "kv_transfer_config": {
            "kv_connector": "NixlConnector",
            "kv_role": "kv_both",
        },
    },
    experimental_configs={
        "dp_size_per_node": 2,
    },
)

# Build prefill and decode deployments with DP
prefill_deployment = build_dp_deployment(prefill_config, name_prefix="Prefill:")
decode_deployment = build_dp_deployment(decode_config, name_prefix="Decode:")

# Create PDProxyServer to coordinate between prefill and decode
proxy_options = PDProxyServer.get_deployment_options(prefill_config, decode_config)
proxy_deployment = serve.deployment(PDProxyServer).options(**proxy_options).bind(
    prefill_server=prefill_deployment,
    decode_server=decode_deployment,
)

# Create OpenAI-compatible ingress
ingress_options = OpenAiIngress.get_deployment_options([prefill_config, decode_config])
ingress_cls = make_fastapi_ingress(OpenAiIngress)
ingress_deployment = serve.deployment(ingress_cls).options(**ingress_options).bind(
    llm_deployments=[proxy_deployment]
)

# Deploy the application
serve.run(ingress_deployment, blocking=True)
# __dp_pd_example_end__
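
# The helper below is a hedged sketch and is never called by this CI test: it shows
# how a client could query the OpenAI-compatible endpoint once the application is
# RUNNING. It assumes the `openai` client package is installed and that Serve
# listens on the default port 8000; the helper name is illustrative only.
def _example_openai_request():
    # Illustrative only; not executed as part of this test.
    from openai import OpenAI

    client = OpenAI(base_url="http://localhost:8000/v1", api_key="fake-key")
    response = client.chat.completions.create(
        model="Qwen/Qwen2.5-0.5B-Instruct",
        messages=[{"role": "user", "content": "Hello!"}],
    )
    print(response.choices[0].message.content)
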

status = ApplicationStatus.NOT_STARTED
timeout_seconds = 300  # Longer timeout for the DP + PD setup
start_time = time.time()

while (
    status != ApplicationStatus.RUNNING and time.time() - start_time < timeout_seconds
):
    status = serve.status().applications[SERVE_DEFAULT_APP_NAME].status

    if status in [ApplicationStatus.DEPLOY_FAILED, ApplicationStatus.UNHEALTHY]:
        raise AssertionError(f"Deployment failed with status: {status}")

    time.sleep(1)

if status != ApplicationStatus.RUNNING:
    raise AssertionError(
        f"Deployment failed to reach RUNNING status within {timeout_seconds}s. "
        f"Current status: {status}"
    )

serve.shutdown()