diff --git a/.github/values-10-disagg-prefill.yaml b/.github/values-10-disagg-prefill.yaml index 548d284f..8bfb1e25 100644 --- a/.github/values-10-disagg-prefill.yaml +++ b/.github/values-10-disagg-prefill.yaml @@ -1,88 +1,90 @@ # Unified configuration for disaggregated prefill setup servingEngineSpec: - strategy: - type: Recreate enableEngine: true - runtimeClassName: "" + runtimeClassName: "nvidia" containerPort: 8000 modelSpec: # Prefill node configuration - - name: "opt125m-prefill" + - name: "llama-prefill" repository: "lmcache/vllm-openai" - tag: "2025-05-27-v1" - modelURL: "facebook/opt-125m" + tag: "nightly-2025-09-04" + modelURL: "Qwen/Qwen3-8B" replicaCount: 1 requestCPU: 8 requestMemory: "30Gi" # requestGPU: 1 pvcStorage: "50Gi" vllmConfig: - enablePrefixCaching: true - maxModelLen: 1024 - v1: 1 - gpuMemoryUtilization: 0.6 + enablePrefixCaching: false + # maxModelLen: 2048 + extraArgs: + - "--enforce-eager" + - "--disable-log-requests" lmcacheConfig: cudaVisibleDevices: "0" enabled: true kvRole: "kv_producer" + localCpu: true + maxLocalCpuSize: 5 + maxLocalDiskSize: 0 enableNixl: true + enableXpyd: true nixlRole: "sender" - nixlPeerHost: "vllm-opt125m-decode-engine-service" - nixlPeerPort: "55555" - nixlBufferSize: "1073741824" # 1GB + nixlProxyHost: "vllm-router-service" + nixlProxyPort: 7500 + nixlBufferSize: "3774873600" nixlBufferDevice: "cuda" - nixlEnableGc: true enablePD: true - cpuOffloadingBufferSize: 0 + rpcPort: "producer1" labels: - model: "opt125m-prefill" - chatTemplate: "chat.jinja2" - chatTemplateConfigMap: |- - {% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if (loop.last and add_generation_prompt) or not loop.last %}{{ '<|im_end|>' + '\n'}}{% endif %}{% endfor %} - {% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{ '<|im_start|>assistant\n' }}{% endif %} + model: "llama-prefill" + # hf_token: # Decode node configuration - - name: "opt125m-decode" + - name: "llama-decode" repository: "lmcache/vllm-openai" - tag: "2025-05-27-v1" - modelURL: "facebook/opt-125m" + tag: "nightly-2025-09-04" + modelURL: "Qwen/Qwen3-8B" replicaCount: 1 requestCPU: 8 requestMemory: "30Gi" # requestGPU: 1 pvcStorage: "50Gi" vllmConfig: - enablePrefixCaching: true - maxModelLen: 1024 - v1: 1 + enablePrefixCaching: false + # maxModelLen: 2048 + extraArgs: + - "--enforce-eager" + - "--disable-log-requests" lmcacheConfig: cudaVisibleDevices: "1" enabled: true kvRole: "kv_consumer" # Set decode node as consumer + localCpu: false + maxLocalCpuSize: 0 enableNixl: true + enableXpyd: true nixlRole: "receiver" nixlPeerHost: "0.0.0.0" - nixlPeerPort: "55555" - nixlBufferSize: "1073741824" # 1GB + nixlPeerInitPort: 7300 + nixlPeerAllocPort: 7400 + nixlBufferSize: "3774873600" nixlBufferDevice: "cuda" - nixlEnableGc: true + # nixlBackends: ["UCX"] enablePD: true + rpcPort: "consumer1" + skipLastNTokens: 1 + # hf_token: labels: - model: "opt125m-decode" - chatTemplate: "chat.jinja2" - chatTemplateConfigMap: |- - {% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if (loop.last and add_generation_prompt) or not loop.last %}{{ '<|im_end|>' + '\n'}}{% endif %}{% endfor %} - {% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{ '<|im_start|>assistant\n' }}{% endif %} + model: "llama-decode" containerSecurityContext: capabilities: add: - SYS_PTRACE - routerSpec: enableRouter: true - repository: "git-act-router" - imagePullPolicy: "IfNotPresent" - strategy: - type: Recreate + 
repository: "xiaokunchen/vllm-router" + tag: "08-27-v8" + imagePullPolicy: "Always" replicaCount: 1 containerPort: 8000 servicePort: 80 @@ -102,6 +104,11 @@ routerSpec: release: "router" extraArgs: - "--prefill-model-labels" - - "opt125m-prefill" + - "llama-prefill" - "--decode-model-labels" - - "opt125m-decode" + - "llama-decode" + nixlPeerHost: "vllm-llama-decode-engine-service" + nixlPeerInitPort: 7300 + nixlPeerAllocPort: 7400 + nixlProxyHost: "0.0.0.0" + nixlProxyPort: 7500 diff --git a/.github/workflows/functionality-helm-chart.yml b/.github/workflows/functionality-helm-chart.yml index 42bb9541..d3f4c995 100644 --- a/.github/workflows/functionality-helm-chart.yml +++ b/.github/workflows/functionality-helm-chart.yml @@ -58,7 +58,7 @@ jobs: run: | cd ${{ github.workspace }} kubectl config use-context minikube - sudo docker build --build-arg INSTALL_OPTIONAL_DEP=default -t localhost:5000/git-act-router -f docker/Dockerfile . + sudo docker build --build-arg INSTALL_OPTIONAL_DEP=default -t localhost:5000/git-act-router -f docker/Dockerfile.pd . sudo docker push localhost:5000/git-act-router sudo sysctl fs.protected_regular=0 minikube image load localhost:5000/git-act-router diff --git a/.github/workflows/router-e2e-test.yml b/.github/workflows/router-e2e-test.yml index 85f7c4c4..1aaec1b8 100644 --- a/.github/workflows/router-e2e-test.yml +++ b/.github/workflows/router-e2e-test.yml @@ -135,7 +135,7 @@ jobs: echo "🔨 Building router docker image" cd ${{ github.workspace }} eval "$(minikube docker-env)" - docker build --build-arg INSTALL_OPTIONAL_DEP=default -t git-act-router -f docker/Dockerfile.kvaware . + docker build --build-arg INSTALL_OPTIONAL_DEP=default -t git-act-router -f docker/Dockerfile.pd . - name: Run all k8s discovery routing tests run: | diff --git a/docker/Dockerfile.pd b/docker/Dockerfile.pd new file mode 100644 index 00000000..fbf41f03 --- /dev/null +++ b/docker/Dockerfile.pd @@ -0,0 +1,31 @@ +FROM python:3.12-slim + +WORKDIR /app + +# hadolint ignore=DL3008 +RUN --mount=type=cache,target=/var/lib/apt --mount=type=cache,target=/var/cache/apt \ + apt-get update && \ + apt-get install -y --no-install-recommends git curl && \ + rm -rf /var/lib/apt/lists/* && \ + curl -LsSf https://astral.sh/uv/install.sh | sh && \ + /root/.local/bin/uv venv /opt/venv + +# Copy the pyproject.toml and the git metadata first (leverage Docker layer caching) +COPY pyproject.toml . +COPY .git/ .git/ + +# Copy the rest of the application code +COPY src/ src/ + +ARG INSTALL_OPTIONAL_DEP=semantic_cache,lmcache +ENV INSTALL_OPTIONAL_DEP=${INSTALL_OPTIONAL_DEP} + +# hadolint ignore=SC1091 +RUN . /opt/venv/bin/activate && \ + /root/.local/bin/uv pip install --upgrade --no-cache-dir pip setuptools_scm && \ + /root/.local/bin/uv pip install --no-cache-dir .[$INSTALL_OPTIONAL_DEP] && \ + /root/.local/bin/uv pip install zmq msgspec + +# Set the entrypoint +ENTRYPOINT ["/opt/venv/bin/vllm-router"] +CMD [] diff --git a/docs/source/developer_guide/docker.rst b/docs/source/developer_guide/docker.rst index b035397f..19ffa56d 100644 --- a/docs/source/developer_guide/docker.rst +++ b/docs/source/developer_guide/docker.rst @@ -10,4 +10,4 @@ Run this command from the root folder path of the project: .. code-block:: bash - docker build -t : -f docker/Dockerfile . + docker build -t : -f docker/Dockerfile.pd . 
diff --git a/helm/templates/deployment-router.yaml b/helm/templates/deployment-router.yaml index cef01271..b1ad9f01 100644 --- a/helm/templates/deployment-router.yaml +++ b/helm/templates/deployment-router.yaml @@ -113,6 +113,26 @@ spec: - "--lmcache-controller-port" - "{{ .Values.routerSpec.lmcacheControllerPort }}" {{- end }} + {{- if .Values.routerSpec.nixlPeerHost }} + - "--nixl-peer-host" + - "{{ .Values.routerSpec.nixlPeerHost }}" + {{- end }} + {{- if .Values.routerSpec.nixlPeerInitPort }} + - "--nixl-peer-init-port" + - "{{ .Values.routerSpec.nixlPeerInitPort }}" + {{- end }} + {{- if .Values.routerSpec.nixlPeerAllocPort }} + - "--nixl-peer-alloc-port" + - "{{ .Values.routerSpec.nixlPeerAllocPort }}" + {{- end }} + {{- if .Values.routerSpec.nixlProxyHost }} + - "--nixl-proxy-host" + - "{{ .Values.routerSpec.nixlProxyHost }}" + {{- end }} + {{- if .Values.routerSpec.nixlProxyPort }} + - "--nixl-proxy-port" + - "{{ .Values.routerSpec.nixlProxyPort }}" + {{- end }} {{- if .Values.routerSpec.resources }} resources: {{- if .Values.routerSpec.resources.requests }} @@ -135,6 +155,16 @@ spec: containerPort: {{ .Values.routerSpec.containerPort }} - name: "lmcache-port" containerPort: 9000 + - name: pd-port-1 + containerPort: 7100 + - name: pd-port-2 + containerPort: 7200 + - name: pd-port-3 + containerPort: 7300 + - name: pd-port-4 + containerPort: 7400 + - name: pd-port-5 + containerPort: 7500 livenessProbe: initialDelaySeconds: 30 periodSeconds: 5 diff --git a/helm/templates/deployment-vllm-multi.yaml b/helm/templates/deployment-vllm-multi.yaml index f162822d..bdc7e228 100644 --- a/helm/templates/deployment-vllm-multi.yaml +++ b/helm/templates/deployment-vllm-multi.yaml @@ -183,7 +183,11 @@ spec: {{- if $modelSpec.lmcacheConfig.enabled }} {{- if hasKey $modelSpec.lmcacheConfig "enablePD" }} - "--kv-transfer-config" - - '{"kv_connector":"LMCacheConnectorV1","kv_role":"{{ $kv_role }}","kv_connector_extra_config":{"discard_partial_chunks": false, "lmcache_rpc_port": {{ $modelSpec.lmcacheConfig.nixlRole | quote }}}}' + {{- if eq $kv_role "kv_producer" }} + - '{"kv_connector":"LMCacheConnectorV1","kv_role":"{{ $kv_role }}","kv_connector_extra_config":{"discard_partial_chunks": false, "lmcache_rpc_port": "{{ $modelSpec.lmcacheConfig.rpcPort | default "producer1" }}"}}' + {{- else }} + - '{"kv_connector":"LMCacheConnectorV1","kv_role":"{{ $kv_role }}","kv_connector_extra_config":{"discard_partial_chunks": false, "lmcache_rpc_port": "{{ $modelSpec.lmcacheConfig.rpcPort | default "consumer1" }}", "skip_last_n_tokens": {{ $modelSpec.lmcacheConfig.skipLastNTokens | default 1 }}}}' + {{- end }} {{- else if and (hasKey $modelSpec.vllmConfig "v0") (eq (toString $modelSpec.vllmConfig.v0) "1") }} - "--kv-transfer-config" - '{"kv_connector":"LMCacheConnector","kv_role":"{{ $kv_role }}"}' @@ -259,16 +263,18 @@ spec: value: "True" - name: VLLM_RPC_TIMEOUT value: "1000000" + - name: PYTHONHASHSEED + value: "0" + - name: VLLM_ENABLE_V1_MULTIPROCESSING + value: "1" + - name: VLLM_WORKER_MULTIPROC_METHOD + value: "spawn" {{- end }} {{- if hasKey $modelSpec.lmcacheConfig "cudaVisibleDevices" }} - name: CUDA_VISIBLE_DEVICES value: {{ $modelSpec.lmcacheConfig.cudaVisibleDevices | quote }} {{- end }} {{- if and (hasKey $modelSpec.lmcacheConfig "enablePD") ($modelSpec.lmcacheConfig.enablePD) }} - - name: LMCACHE_LOCAL_CPU - value: "False" - - name: LMCACHE_MAX_LOCAL_CPU_SIZE - value: "0" - name: LMCACHE_REMOTE_SERDE value: "NULL" - name: UCX_TLS @@ -281,14 +287,29 @@ spec: - name: LMCACHE_NIXL_ROLE value: {{ 
$modelSpec.lmcacheConfig.nixlRole | quote }} {{- end }} + {{- if hasKey $modelSpec.lmcacheConfig "enableXpyd" }} + - name: LMCACHE_ENABLE_XPYD + value: {{ ternary "True" "False" $modelSpec.lmcacheConfig.enableXpyd | quote }} + {{- end }} + {{- if hasKey $modelSpec.lmcacheConfig "nixlProxyHost" }} + - name: LMCACHE_NIXL_PROXY_HOST + value: {{ $modelSpec.lmcacheConfig.nixlProxyHost | quote }} + {{- end }} + {{- if hasKey $modelSpec.lmcacheConfig "nixlProxyPort" }} + - name: LMCACHE_NIXL_PROXY_PORT + value: {{ $modelSpec.lmcacheConfig.nixlProxyPort | quote }} + {{- end }} {{- if hasKey $modelSpec.lmcacheConfig "nixlPeerHost" }} - - name: LMCACHE_NIXL_RECEIVER_HOST - # value: "0.0.0.0" + - name: LMCACHE_NIXL_PEER_HOST value: {{ $modelSpec.lmcacheConfig.nixlPeerHost | quote }} {{- end }} - {{- if hasKey $modelSpec.lmcacheConfig "nixlPeerPort" }} - - name: LMCACHE_NIXL_RECEIVER_PORT - value: {{ $modelSpec.lmcacheConfig.nixlPeerPort | quote }} + {{- if hasKey $modelSpec.lmcacheConfig "nixlPeerInitPort" }} + - name: LMCACHE_NIXL_PEER_INIT_PORT + value: {{ $modelSpec.lmcacheConfig.nixlPeerInitPort | quote }} + {{- end }} + {{- if hasKey $modelSpec.lmcacheConfig "nixlPeerAllocPort" }} + - name: LMCACHE_NIXL_PEER_ALLOC_PORT + value: {{ $modelSpec.lmcacheConfig.nixlPeerAllocPort | quote }} {{- end }} {{- if hasKey $modelSpec.lmcacheConfig "nixlBufferSize" }} - name: LMCACHE_NIXL_BUFFER_SIZE @@ -298,22 +319,26 @@ spec: - name: LMCACHE_NIXL_BUFFER_DEVICE value: {{ $modelSpec.lmcacheConfig.nixlBufferDevice | quote }} {{- end }} + {{- if hasKey $modelSpec.lmcacheConfig "nixlBackends" }} + - name: LMCACHE_NIXL_BACKENDS + value: {{ $modelSpec.lmcacheConfig.nixlBackends | toJson | quote }} + {{- end }} {{- if hasKey $modelSpec.lmcacheConfig "nixlEnableGc" }} - name: LMCACHE_NIXL_ENABLE_GC value: {{ ternary "True" "False" $modelSpec.lmcacheConfig.nixlEnableGc | quote }} {{- end }} {{- end }} - {{- if hasKey $modelSpec.lmcacheConfig "cpuOffloadingBufferSize" }} - {{- if gt (int $modelSpec.lmcacheConfig.cpuOffloadingBufferSize) 0 }} + {{- if hasKey $modelSpec.lmcacheConfig "localCpu" }} - name: LMCACHE_LOCAL_CPU - value: "True" + value: {{ ternary "True" "False" $modelSpec.lmcacheConfig.localCpu | quote }} + {{- end }} + {{- if hasKey $modelSpec.lmcacheConfig "maxLocalCpuSize" }} - name: LMCACHE_MAX_LOCAL_CPU_SIZE - value: "{{ $modelSpec.lmcacheConfig.cpuOffloadingBufferSize }}" - {{- end}} + value: {{ $modelSpec.lmcacheConfig.maxLocalCpuSize | quote }} {{- end }} - {{- if hasKey $modelSpec.lmcacheConfig "diskOffloadingBufferSize" }} + {{- if hasKey $modelSpec.lmcacheConfig "maxLocalDiskSize" }} - name: LMCACHE_MAX_LOCAL_DISK_SIZE - value: "{{ $modelSpec.lmcacheConfig.diskOffloadingBufferSize }}" + value: {{ $modelSpec.lmcacheConfig.maxLocalDiskSize | quote }} {{- end }} {{- if .Values.cacheserverSpec }} - name: LMCACHE_REMOTE_URL @@ -360,6 +385,16 @@ spec: containerPort: 55555 - name: ucx-port containerPort: 9999 + - name: pd-port-1 + containerPort: 7100 + - name: pd-port-2 + containerPort: 7200 + - name: pd-port-3 + containerPort: 7300 + - name: pd-port-4 + containerPort: 7400 + - name: pd-port-5 + containerPort: 7500 {{- include "chart.probes" . 
| indent 10 }} resources: {{- include "chart.resources" $modelSpec | nindent 12 }} {{- if or (hasKey $modelSpec "pvcStorage") (and $modelSpec.vllmConfig (hasKey $modelSpec.vllmConfig "tensorParallelSize")) (hasKey $modelSpec "chatTemplate") (hasKey $modelSpec "extraVolumeMounts") }} diff --git a/helm/templates/service-router.yaml b/helm/templates/service-router.yaml index 1340eaf3..1aa83151 100644 --- a/helm/templates/service-router.yaml +++ b/helm/templates/service-router.yaml @@ -20,6 +20,26 @@ spec: port: 9000 targetPort: lmcache-port protocol: TCP + - name: pd-port-1 + port: 7100 + targetPort: pd-port-1 + protocol: TCP + - name: pd-port-2 + port: 7200 + targetPort: pd-port-2 + protocol: TCP + - name: pd-port-3 + port: 7300 + targetPort: pd-port-3 + protocol: TCP + - name: pd-port-4 + port: 7400 + targetPort: pd-port-4 + protocol: TCP + - name: pd-port-5 + port: 7500 + targetPort: pd-port-5 + protocol: TCP selector: {{- include "chart.routerLabels" . | nindent 4 }} {{- end }} diff --git a/helm/templates/service-vllm.yaml b/helm/templates/service-vllm.yaml index e9220d38..bd6a7b85 100644 --- a/helm/templates/service-vllm.yaml +++ b/helm/templates/service-vllm.yaml @@ -23,6 +23,26 @@ spec: port: 9999 targetPort: ucx-port protocol: TCP + - name: pd-port-1 + port: 7100 + targetPort: pd-port-1 + protocol: TCP + - name: pd-port-2 + port: 7200 + targetPort: pd-port-2 + protocol: TCP + - name: pd-port-3 + port: 7300 + targetPort: pd-port-3 + protocol: TCP + - name: pd-port-4 + port: 7400 + targetPort: pd-port-4 + protocol: TCP + - name: pd-port-5 + port: 7500 + targetPort: pd-port-5 + protocol: TCP selector: model: "{{ $modelSpec.name }}" helm-release-name: "{{ $.Release.Name }}" diff --git a/src/vllm_router/app.py b/src/vllm_router/app.py index 0713e9c0..bc1995b7 100644 --- a/src/vllm_router/app.py +++ b/src/vllm_router/app.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+import asyncio import logging import threading from contextlib import asynccontextmanager @@ -32,6 +33,7 @@ from vllm_router.routers.main_router import main_router from vllm_router.routers.metrics_router import metrics_router from vllm_router.routers.routing_logic import ( + DisaggregatedPrefillRouter, get_routing_logic, initialize_routing_logic, ) @@ -43,6 +45,10 @@ from vllm_router.services.batch_service import initialize_batch_processor from vllm_router.services.callbacks_service.callbacks import configure_custom_callbacks from vllm_router.services.files_service import initialize_storage +from vllm_router.services.request_service.request import ( + start_zmq_task, + stop_zmq_task, +) from vllm_router.services.request_service.rewriter import ( get_request_rewriter, ) @@ -90,7 +96,23 @@ async def lifespan(app: FastAPI): if hasattr(service_discovery, "initialize_client_sessions"): await service_discovery.initialize_client_sessions() - yield + app.state.event_loop = asyncio.get_event_loop() + + # only start the ZMQ task if the routing logic is RoutingLogic.DISAGGREGATED_PREFILL + if isinstance(app.state.router, DisaggregatedPrefillRouter): + logger.info( + "Starting ZMQ task because the routing logic is RoutingLogic.DISAGGREGATED_PREFILL" + ) + # Start the ZMQ task + await start_zmq_task() + + yield + + # Stop the ZMQ task + await stop_zmq_task() + else: + yield + await app.state.aiohttp_client_wrapper.stop() # Close the threaded-components @@ -270,6 +292,7 @@ def initialize_all(app: FastAPI, args): app.state.request_stats_monitor = get_request_stats_monitor() app.state.router = get_routing_logic() app.state.request_rewriter = get_request_rewriter() + app.state.args = args app = FastAPI(lifespan=lifespan) diff --git a/src/vllm_router/parsers/parser.py b/src/vllm_router/parsers/parser.py index 8b12cf98..14778056 100644 --- a/src/vllm_router/parsers/parser.py +++ b/src/vllm_router/parsers/parser.py @@ -379,6 +379,35 @@ def parse_args(): help="The threshold for kv-aware routing.", ) + parser.add_argument( + "--nixl-peer-host", + type=str, + help="The hostname or IP address of the NIXL peer service. Only use for DisaggregatedPrefillRouter.", + ) + parser.add_argument( + "--nixl-peer-init-port", + type=int, + default=7300, + help="The initialization port for the NIXL peer service. Only use for DisaggregatedPrefillRouter.", + ) + parser.add_argument( + "--nixl-peer-alloc-port", + type=int, + default=7400, + help="The allocation port for the NIXL peer service. Only use for DisaggregatedPrefillRouter.", + ) + parser.add_argument( + "--nixl-proxy-host", + type=str, + help="The hostname or IP address for the NIXL proxy server. Only use for DisaggregatedPrefillRouter.", + ) + parser.add_argument( + "--nixl-proxy-port", + type=int, + default=7500, + help="The port for the NIXL proxy server. 
Only use for DisaggregatedPrefillRouter.", + ) + args = parser.parse_args() args = load_initial_config_from_config_file_if_required(parser, args) diff --git a/src/vllm_router/service_discovery.py b/src/vllm_router/service_discovery.py index 8ee8f089..d4bf3b22 100644 --- a/src/vllm_router/service_discovery.py +++ b/src/vllm_router/service_discovery.py @@ -330,11 +330,13 @@ async def initialize_client_sessions(self) -> None: endpoint_infos = self.get_endpoint_info() for endpoint_info in endpoint_infos: if endpoint_info.model_label in self.prefill_model_labels: + # TODO: fix Unclosed client session self.app.state.prefill_client = aiohttp.ClientSession( base_url=endpoint_info.url, timeout=aiohttp.ClientTimeout(total=None), ) elif endpoint_info.model_label in self.decode_model_labels: + # TODO: fix Unclosed client session self.app.state.decode_client = aiohttp.ClientSession( base_url=endpoint_info.url, timeout=aiohttp.ClientTimeout(total=None), @@ -662,6 +664,14 @@ def _add_engine( # Store model information in the endpoint info self.available_engines[engine_name].model_info = model_info + try: + fut = asyncio.run_coroutine_threadsafe( + self.initialize_client_sessions(), self.app.state.event_loop + ) + fut.result() + except Exception as e: + logger.error(f"Error initializing client sessions: {e}") + def _delete_engine(self, engine_name: str): logger.info(f"Serving engine {engine_name} is deleted") with self.available_engines_lock: @@ -748,11 +758,14 @@ async def initialize_client_sessions(self) -> None: endpoint_infos = self.get_endpoint_info() for endpoint_info in endpoint_infos: if endpoint_info.model_label in self.prefill_model_labels: + # TODO: fix Unclosed client session self.app.state.prefill_client = aiohttp.ClientSession( base_url=endpoint_info.url, timeout=aiohttp.ClientTimeout(total=None), ) + elif endpoint_info.model_label in self.decode_model_labels: + # TODO: fix Unclosed client session self.app.state.decode_client = aiohttp.ClientSession( base_url=endpoint_info.url, timeout=aiohttp.ClientTimeout(total=None), @@ -1080,6 +1093,14 @@ def _add_engine(self, engine_name: str, model_names: List[str], model_label: str # Store model information in the endpoint info self.available_engines[engine_name].model_info = model_info + try: + fut = asyncio.run_coroutine_threadsafe( + self.initialize_client_sessions(), self.app.state.event_loop + ) + fut.result() + except Exception as e: + logger.error(f"Error initializing client sessions: {e}") + def _delete_engine(self, engine_name: str): logger.info(f"Serving engine {engine_name} is deleted") with self.available_engines_lock: @@ -1165,11 +1186,13 @@ async def initialize_client_sessions(self) -> None: endpoint_infos = self.get_endpoint_info() for endpoint_info in endpoint_infos: if endpoint_info.model_label in self.prefill_model_labels: + # TODO: fix Unclosed client session self.app.state.prefill_client = aiohttp.ClientSession( base_url=endpoint_info.url, timeout=aiohttp.ClientTimeout(total=None), ) elif endpoint_info.model_label in self.decode_model_labels: + # TODO: fix Unclosed client session self.app.state.decode_client = aiohttp.ClientSession( base_url=endpoint_info.url, timeout=aiohttp.ClientTimeout(total=None), diff --git a/src/vllm_router/services/request_service/request.py b/src/vllm_router/services/request_service/request.py index 0c500571..9986d097 100644 --- a/src/vllm_router/services/request_service/request.py +++ b/src/vllm_router/services/request_service/request.py @@ -12,6 +12,8 @@ # See the License for the specific language 
governing permissions and # limitations under the License. +import asyncio + # --- Request Processing & Routing --- import json import os @@ -20,8 +22,18 @@ from typing import Optional import aiohttp +import msgspec +import zmq +import zmq.asyncio from fastapi import BackgroundTasks, HTTPException, Request, UploadFile from fastapi.responses import JSONResponse, StreamingResponse + +try: + from lmcache.v1.storage_backend.connector.nixl_connector_v3 import ( + NixlMsg, + ) +except ImportError: + pass from requests import JSONDecodeError from vllm_router.log import init_logger @@ -50,6 +62,74 @@ logger = init_logger(__name__) +finished_reqs = set() +run_proxy = True +zmq_ctx = zmq.asyncio.Context() + + +async def zmq_pull_server(): + try: + socket = zmq_ctx.socket(zmq.PULL) + try: + from vllm_router.app import app + + proxy_host = app.state.args.nixl_proxy_host + proxy_port = app.state.args.nixl_proxy_port + except Exception as e: + logger.error(f"Failed to get proxy host and port from app state: {e}") + proxy_url = f"{proxy_host}:{proxy_port}" + socket.bind(f"tcp://{proxy_url}") + logger.info(f"ZMQ proxy server started on {proxy_url}") + except Exception as e: + logger.error(f"Failed to bind ZMQ socket to {proxy_url}: {e}") + socket.close() + return + + while run_proxy: + try: + msg_bytes = await socket.recv() + msg = msgspec.msgpack.decode(msg_bytes, type=NixlMsg) + req_id = msg.req_id + finished_reqs.add(req_id) + logger.info(f"Prefill of req {req_id} done.") + except zmq.Again: + await asyncio.sleep(0.01) # Avoid busy loop + except Exception as e: + logger.error(f"ZMQ Error in message processing: {e}") + break + + socket.close() + logger.info("ZMQ PULL server stopped.") + + +# ZMQ task will be created in the FastAPI lifespan manager +zmq_task = None + + +async def start_zmq_task(): + """Start the ZMQ pull server task.""" + global zmq_task + if zmq_task is None: + zmq_task = asyncio.create_task(zmq_pull_server()) + logger.info("ZMQ task started") + + # Add a small delay to allow the task to start and potentially log any errors + await asyncio.sleep(0.1) + + +async def stop_zmq_task(): + """Stop the ZMQ pull server task.""" + global zmq_task, run_proxy + if zmq_task is not None: + run_proxy = False + zmq_task.cancel() + try: + await zmq_task + except asyncio.CancelledError: + pass + zmq_task = None + logger.info("ZMQ task stopped") + # TODO: (Brian) check if request is json beforehand async def process_request( @@ -164,7 +244,7 @@ async def route_general_request( # Same as vllm, Get request_id from X-Request-Id header if available request_id = request.headers.get("X-Request-Id") or str(uuid.uuid4()) request_body = await request.body() - request_json = json.loads(request_body) + request_json = await request.json() # TODO (ApostaC): merge two awaits into one if request.query_params: request_endpoint = request.query_params.get("id") @@ -204,6 +284,7 @@ async def route_general_request( status_code=400, detail="Request body is not JSON parsable." 
) + # TODO (ApostaC): merge two awaits into one service_discovery = get_service_discovery() endpoints = service_discovery.get_endpoint_info() @@ -302,6 +383,7 @@ async def route_general_request( ) +# TODO: Combine with send_request_to_tokenizer and send_request_to_decode async def send_request_to_prefiller( client: aiohttp.ClientSession, endpoint: str, req_data: dict, request_id: str ): @@ -321,6 +403,22 @@ async def send_request_to_prefiller( return await response.json() +async def send_request_to_tokenizer( + client: aiohttp.ClientSession, endpoint: str, req_data: dict, request_id: str +): + """ + Send a request to a tokenizer service. + """ + headers = { + "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", + "X-Request-Id": request_id, + } + + async with client.post(endpoint, json=req_data, headers=headers) as response: + response.raise_for_status() + return await response.json() + + async def send_request_to_decode( client: aiohttp.ClientSession, endpoint: str, req_data: dict, request_id: str ): @@ -336,6 +434,13 @@ async def send_request_to_decode( yield chunk +async def wait_decode_kv_ready(req_id: str): + while req_id not in finished_reqs: + await asyncio.sleep(0.0001) # sleep for 0.1 ms + logger.debug(f"Prefill node signaled kv ready for req {req_id}") + finished_reqs.remove(req_id) + + async def route_disaggregated_prefill_request( request: Request, endpoint: str, @@ -347,10 +452,76 @@ async def route_disaggregated_prefill_request( request_json = await request.json() orig_max_tokens = request_json.get("max_tokens", 0) - request_json["max_tokens"] = 1 + stream_options = request_json.pop("stream_options", None) + + # # Check if client sessions are initialized, if not, try to initialize them + # if not hasattr(request.app.state, 'prefill_client') or request.app.state.prefill_client is None: + # logger.warning("prefill_client not initialized, attempting to initialize client sessions") + # try: + # from vllm_router.service_discovery import get_service_discovery + # service_discovery = get_service_discovery() + # if hasattr(service_discovery, '_reinitialize_client_sessions'): + # logger.info("In route_disaggregated_prefill_request: Calling _reinitialize_client_sessions") + # await service_discovery._reinitialize_client_sessions() + # logger.info("Successfully initialized client sessions") + # else: + # logger.error("Service discovery does not have _reinitialize_client_sessions method") + # except Exception as e: + # logger.error(f"Failed to initialize client sessions: {e}") + # return JSONResponse( + # status_code=500, + # content={ + # "error": { + # "message": "Failed to initialize client sessions", + # "type": "initialization_error", + # "code": 500, + # } + # }, + # headers={"X-Request-Id": request_id}, + # ) + st = time.time() try: - await send_request_to_prefiller( + # Step 1: Tokenize the prompt + # request_json {'model': 'facebook/opt-125m', 'prompt': 'What date is today?', 'max_tokens': 20, 'temperature': 0.0} + # # print every key-value pair in prefill_client + # for key, value in request.app.state.prefill_client.__dict__.items(): + # print(f"{key}: {value}") + + tokenize_output = await send_request_to_tokenizer( + request.app.state.prefill_client, + "/tokenize", + {"prompt": request_json["prompt"]}, + request_id, + ) + # tokenize_output {'count': 6, 'max_model_len': 2048, 'tokens': [2, 2264, 1248, 16, 452, 116], 'token_strs': None} + + # Update request with tokenized prompt + request_json["prompt"] = tokenize_output["tokens"] + request_json["max_tokens"] = 1 + + # 
Step 2: Create disagg_spec for KV transfer + disagg_spec = { + "req_id": request_id, + "receiver_host": request.app.state.args.nixl_peer_host, + "receiver_init_port": [request.app.state.args.nixl_peer_init_port], + "receiver_alloc_port": [request.app.state.args.nixl_peer_alloc_port], + } + # disagg_spec = { + # "req_id": request_id, + # "receiver_host": "0.0.0.0", + # "receiver_init_port": [7300], + # "receiver_alloc_port": [7400], + # } + + request_json["kv_transfer_params"] = { + "ret_first_tok": True, + "disagg_spec": disagg_spec, + } + request_json["stream"] = False + + # Step 3: Send to prefiller + prefill_output = await send_request_to_prefiller( request.app.state.prefill_client, endpoint, request_json, request_id ) et = time.time() @@ -358,7 +529,15 @@ async def route_disaggregated_prefill_request( logger.info( f"Routing request {request_id} with session id None to {request.app.state.prefill_client._base_url} at {et}, process time = {et - in_router_time:.4f}" ) - request_json["max_tokens"] = orig_max_tokens + + # Step 4: Prepare decode request + request_json["max_tokens"] = orig_max_tokens - 1 + request_json["prompt"].append(prefill_output["kv_transfer_params"]["first_tok"]) + request_json.pop("kv_transfer_params") + request_json["stream"] = True + if stream_options is not None: + request_json["stream_options"] = stream_options + except aiohttp.ClientResponseError as e: logger.error(f"HTTP error in prefiller: {e}", exc_info=True) return JSONResponse( @@ -388,6 +567,30 @@ async def route_disaggregated_prefill_request( async def generate_stream(): try: + # Yield initial chunk with prefill data + head_chunk = { + "id": prefill_output["id"], + "object": "text_completion", + "created": prefill_output["created"], + "model": prefill_output["model"], + "choices": [ + { + "index": 0, + "text": prefill_output["choices"][0]["text"], + "logprobs": None, + "finish_reason": None, + "stop_reason": None, + } + ], + "usage": None, + } + yield ( + "data: " + json.dumps(head_chunk, separators=(",", ":")) + "\n\n" + ).encode() + + await wait_decode_kv_ready(request_id) + + # Stream the rest from decode service async for chunk in send_request_to_decode( request.app.state.decode_client, endpoint, request_json, request_id ): diff --git a/tutorials/assets/values-16-disagg-prefill.yaml b/tutorials/assets/values-16-disagg-prefill.yaml index 35bcf410..8bfb1e25 100644 --- a/tutorials/assets/values-16-disagg-prefill.yaml +++ b/tutorials/assets/values-16-disagg-prefill.yaml @@ -1,64 +1,79 @@ # Unified configuration for disaggregated prefill setup servingEngineSpec: enableEngine: true - runtimeClassName: "" + runtimeClassName: "nvidia" containerPort: 8000 modelSpec: # Prefill node configuration - name: "llama-prefill" repository: "lmcache/vllm-openai" - tag: "2025-05-27-v1" - modelURL: "meta-llama/Llama-3.1-8B-Instruct" + tag: "nightly-2025-09-04" + modelURL: "Qwen/Qwen3-8B" replicaCount: 1 requestCPU: 8 requestMemory: "30Gi" # requestGPU: 1 pvcStorage: "50Gi" vllmConfig: - enablePrefixCaching: true - maxModelLen: 32000 + enablePrefixCaching: false + # maxModelLen: 2048 + extraArgs: + - "--enforce-eager" + - "--disable-log-requests" lmcacheConfig: cudaVisibleDevices: "0" enabled: true kvRole: "kv_producer" + localCpu: true + maxLocalCpuSize: 5 + maxLocalDiskSize: 0 enableNixl: true + enableXpyd: true nixlRole: "sender" - nixlPeerHost: "vllm-llama-decode-engine-service" - nixlPeerPort: "55555" - nixlBufferSize: "1073741824" # 1GB + nixlProxyHost: "vllm-router-service" + nixlProxyPort: 7500 + nixlBufferSize: 
"3774873600" nixlBufferDevice: "cuda" - nixlEnableGc: true enablePD: true - cpuOffloadingBufferSize: 0 - hf_token: + rpcPort: "producer1" labels: model: "llama-prefill" + # hf_token: # Decode node configuration - name: "llama-decode" repository: "lmcache/vllm-openai" - tag: "2025-05-27-v1" - modelURL: "meta-llama/Llama-3.1-8B-Instruct" + tag: "nightly-2025-09-04" + modelURL: "Qwen/Qwen3-8B" replicaCount: 1 requestCPU: 8 requestMemory: "30Gi" # requestGPU: 1 pvcStorage: "50Gi" vllmConfig: - enablePrefixCaching: true - maxModelLen: 32000 + enablePrefixCaching: false + # maxModelLen: 2048 + extraArgs: + - "--enforce-eager" + - "--disable-log-requests" lmcacheConfig: cudaVisibleDevices: "1" enabled: true kvRole: "kv_consumer" # Set decode node as consumer + localCpu: false + maxLocalCpuSize: 0 enableNixl: true + enableXpyd: true nixlRole: "receiver" nixlPeerHost: "0.0.0.0" - nixlPeerPort: "55555" - nixlBufferSize: "1073741824" # 1GB + nixlPeerInitPort: 7300 + nixlPeerAllocPort: 7400 + nixlBufferSize: "3774873600" nixlBufferDevice: "cuda" - nixlEnableGc: true + # nixlBackends: ["UCX"] enablePD: true - hf_token: + rpcPort: "consumer1" + skipLastNTokens: 1 + # hf_token: labels: model: "llama-decode" containerSecurityContext: @@ -67,8 +82,9 @@ servingEngineSpec: - SYS_PTRACE routerSpec: enableRouter: true - repository: "lmcache/lmstack-router" - tag: "pd" + repository: "xiaokunchen/vllm-router" + tag: "08-27-v8" + imagePullPolicy: "Always" replicaCount: 1 containerPort: 8000 servicePort: 80 @@ -91,3 +107,8 @@ routerSpec: - "llama-prefill" - "--decode-model-labels" - "llama-decode" + nixlPeerHost: "vllm-llama-decode-engine-service" + nixlPeerInitPort: 7300 + nixlPeerAllocPort: 7400 + nixlProxyHost: "0.0.0.0" + nixlProxyPort: 7500