Commit 81f3b9c
[0.9.1-DEV][BUGFIX] BugFix: Resolve the issue of waiting queue accumulation when requests are canceled. (#2502)
### What this PR does / why we need it?
Resolve the issue of waiting queue accumulation when requests are canceled.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
By CI.

---------
Signed-off-by: zouyida <[email protected]>
Co-authored-by: zouyida <[email protected]>
1 parent 763ed69 commit 81f3b9c

8 files changed: +158 −25 lines

docs/source/user_guide/release_notes.md

Lines changed: 1 addition & 1 deletion

@@ -19,7 +19,7 @@ This is the 3rd release candidate of v0.9.1 for vLLM Ascend. Please follow the [
 - Fix header include issue in rope [#2398](https://github.com/vllm-project/vllm-ascend/pull/2398)
 - Fix mtp config bug [#2412](https://github.com/vllm-project/vllm-ascend/pull/2412)
 - Fix error info and adapt `attn_metedata` refactor [#2402](https://github.com/vllm-project/vllm-ascend/pull/2402)
-- Fix torchair runtime errror caused by configuration mismtaches and `.kv_cache_bytes` file missing [#2312](https://github.com/vllm-project/vllm-ascend/pull/2312)
+- Fix torchair runtime error caused by configuration mismtaches and `.kv_cache_bytes` file missing [#2312](https://github.com/vllm-project/vllm-ascend/pull/2312)
 - Move `with_prefill` allreduce from cpu to npu [#2230](https://github.com/vllm-project/vllm-ascend/pull/2230)

 ### Docs

examples/disaggregate_prefill_v1/README.md

Lines changed: 1 addition & 1 deletion

@@ -205,7 +205,7 @@ vllm serve /models/deepseek_r1_w8a8 \
 Run proxy server on the first node:
 ```shell
 cd /vllm-workspace/vllm-ascend/examples/disaggregate_prefill_v1
-python toy_proxy_server.py --host 172.19.32.175 --port 1025 --prefiller-hosts 172.19.241.49 --prefiller-port 20002 --decoder-hosts 172.19.123.51 --decoder-ports 20002
+python load_balance_proxy_server_example.py --host 172.19.32.175 --port 1025 --prefiller-hosts 172.19.32.175 --prefiller-port 20002 --decoder-hosts 172.19.123.51 --decoder-ports 20002
 ```

 Verification

examples/disaggregate_prefill_v1/load_balance_proxy_server_example.py

Lines changed: 31 additions & 3 deletions

@@ -4,9 +4,11 @@

 import argparse
 import asyncio
+import functools
 import heapq
 import os
 import sys
+import uuid
 from contextlib import asynccontextmanager
 from typing import List

@@ -54,7 +56,6 @@ def __init__(self, prefiller_instances, decoder_instances):
         ]
         self.req_to_prefiller = {}
         self.req_id_lock = asyncio.Lock()
-        self.req_id_counter = 0
         # Removed selection locks - no longer needed for synchronous methods

         # Initialize priority queues for efficient server selection
@@ -110,8 +111,7 @@ def aquire_aborted_prefiller_requests(

     async def next_req_id(self):
         async with self.req_id_lock:
-            self.req_id_counter += 1
-            return str(self.req_id_counter)
+            return str(uuid.uuid4())

     def select_prefiller(self, token_count):  # Changed to synchronous
         # No lock needed - entire function is atomic
@@ -230,6 +230,32 @@ async def lifespan(app: FastAPI):
         await d.client.aclose()


+async def listen_for_disconnect(request: Request) -> None:
+    """Return if a disconnect message is received"""
+    while True:
+        message = await request.receive()
+        if message["type"] == "http.disconnect":
+            break
+
+
+def with_cancellation(handler_func):
+
+    @functools.wraps(handler_func)
+    async def wrapper(*args, **kwargs):
+        request = kwargs["request"]
+        handler_task = asyncio.create_task(handler_func(*args, **kwargs))
+        cancellation_task = asyncio.create_task(listen_for_disconnect(request))
+        done, pending = await asyncio.wait([handler_task, cancellation_task],
+                                           return_when=asyncio.FIRST_COMPLETED)
+        for task in pending:
+            task.cancel()
+        if handler_task in done:
+            return handler_task.result()
+        return None
+
+    return wrapper
+
+
 app = FastAPI(lifespan=lifespan)


@@ -410,11 +436,13 @@ async def generate_stream():


 @app.post("/v1/completions")
+@with_cancellation
 async def handle_completions(request: Request):
     return await _handle_completions("/completions", request)


 @app.post("/v1/chat/completions")
+@with_cancellation
 async def handle_chat_completions(request: Request):
     return await _handle_completions("/chat/completions", request)

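The `with_cancellation` decorator added here races each handler against a client-disconnect watcher, so an aborted HTTP request cancels the in-flight proxy call instead of leaving it queued. A minimal sketch of the same pattern that runs without a live server (the `FakeRequest` stub and the timings are illustrative, not part of this commit):

```python
import asyncio
import functools


async def listen_for_disconnect(request) -> None:
    """Return once the client sends an http.disconnect message."""
    while True:
        message = await request.receive()
        if message["type"] == "http.disconnect":
            break


def with_cancellation(handler_func):
    """Race the handler against a disconnect watcher; cancel the loser."""

    @functools.wraps(handler_func)
    async def wrapper(*args, **kwargs):
        request = kwargs["request"]
        handler_task = asyncio.create_task(handler_func(*args, **kwargs))
        cancellation_task = asyncio.create_task(listen_for_disconnect(request))
        done, pending = await asyncio.wait([handler_task, cancellation_task],
                                           return_when=asyncio.FIRST_COMPLETED)
        for task in pending:
            task.cancel()
        if handler_task in done:
            return handler_task.result()
        return None  # client disconnected first; drop the response

    return wrapper


class FakeRequest:
    """Hypothetical stand-in for fastapi.Request: disconnects after 0.1 s."""

    async def receive(self):
        await asyncio.sleep(0.1)
        return {"type": "http.disconnect"}


@with_cancellation
async def slow_handler(*, request):
    await asyncio.sleep(10)  # simulated long generation
    return "done"


# The handler is cancelled ~0.1 s in, long before its 10 s sleep finishes.
print(asyncio.run(slow_handler(request=FakeRequest())))  # -> None
```

Because `asyncio.wait(..., return_when=FIRST_COMPLETED)` returns as soon as either task finishes, the losing task is always cancelled, which is what frees the backend request when the client goes away.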

vllm_ascend/distributed/llmdatadist_c_mgr_connector.py

Lines changed: 1 addition & 1 deletion

@@ -884,7 +884,7 @@ def get_finished(
             if now < expires:
                 break
             logger.warning(
-                "Some requests in prefill node fail to receive KV Cache transfer done signal. "
+                f"Some requests in prefill node fail to receive KV Cache transfer done signal in {envs.VLLM_LLMDD_ABORT_REQUEST_TIMEOUT}s. "
                 "If a greater mean TTFT is acceptable, you can 'export VLLM_LLMDD_ABORT_REQUEST_TIMEOUT=600' (10 minutes) to relax the timeout condition. "
             )
             if req_id in self.reqs_to_send:

vllm_ascend/envs.py

Lines changed: 1 addition & 1 deletion

@@ -138,7 +138,7 @@
     # `VLLM_LLMDD_ABORT_REQUEST_TIMEOUT` is only applicable when using LLMDataDistCMgrConnector in a
     # disaggregated decode-prefill setup.
     "VLLM_LLMDD_ABORT_REQUEST_TIMEOUT":
-    lambda: int(os.getenv("VLLM_LLMDD_ABORT_REQUEST_TIMEOUT", 300)),
+    lambda: int(os.getenv("VLLM_LLMDD_ABORT_REQUEST_TIMEOUT", 120)),
     # Whether to enable mla_pa for deepseek mla decode, this flag will be removed after its available torch_npu is public accessible
     # and the mla_pa will be the default path of deepseek decode path.
     "VLLM_ASCEND_MLA_PA":

vllm_ascend/patch/platform/__init__.py

Lines changed: 2 additions & 8 deletions

@@ -14,12 +14,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from vllm_ascend.utils import vllm_version_is
-
 # Import specific patches for different versions
-if vllm_version_is("0.9.1"):
-    from vllm_ascend.patch.platform import patch_0_9_1  # noqa: F401
-    from vllm_ascend.patch.platform import patch_common  # noqa: F401
-else:
-    from vllm_ascend.patch.platform import patch_common  # noqa: F401
-    from vllm_ascend.patch.platform import patch_main  # noqa: F401
+from vllm_ascend.patch.platform import patch_0_9_1  # noqa: F401
+from vllm_ascend.patch.platform import patch_common  # noqa: F401

vllm_ascend/patch/platform/patch_0_9_1/patch_core.py

Lines changed: 119 additions & 2 deletions

@@ -1,12 +1,16 @@
 import os
 import signal
-from typing import Optional
+import types
+from collections.abc import Iterable
+from typing import Optional, Union

 from vllm.config import ParallelConfig, VllmConfig
 from vllm.logger import init_logger
 from vllm.transformers_utils.config import \
     maybe_register_config_serialize_by_value
 from vllm.v1.engine.core import DPEngineCoreProc, EngineCoreProc
+from vllm.v1.outputs import ModelRunnerOutput
+from vllm.v1.request import RequestStatus

 import vllm_ascend.envs as vllm_ascend_envs

@@ -77,7 +81,10 @@ def run_busy_loop(self):
             self.execute_dummy_batch()


-def run_engine_core(*args, dp_rank: int = 0, local_dp_rank: int = 0, **kwargs):
+def run_engine_core_dplb(*args,
+                         dp_rank: int = 0,
+                         local_dp_rank: int = 0,
+                         **kwargs):
     """Launch EngineCore busy loop in background process."""

     # Signal handler used for graceful termination.
@@ -108,7 +115,115 @@ def signal_handler(signum, frame):
             engine_core = ExternealDPEngineCoreProc(*args, **kwargs)
         else:
             engine_core = EngineCoreProc(*args, **kwargs)
+        engine_core.scheduler.finish_requests = types.MethodType(
+            finish_requests, engine_core.scheduler)
+        engine_core.scheduler._update_from_kv_xfer_finished = types.MethodType(
+            _update_from_kv_xfer_finished, engine_core.scheduler)
+        engine_core.run_busy_loop()
+
+    except SystemExit:
+        logger.debug("EngineCore exiting.")
+        raise
+    except Exception as e:
+        if engine_core is None:
+            logger.exception("EngineCore failed to start.")
+        else:
+            logger.exception("EngineCore encountered a fatal error.")
+            engine_core._send_engine_dead()
+        raise e
+    finally:
+        if engine_core is not None:
+            engine_core.shutdown()
+
+
+def finish_requests(
+    self,
+    request_ids: Union[str, Iterable[str]],
+    finished_status: RequestStatus,
+) -> None:
+    """Handles the finish signal from outside the scheduler.
+    For example, the API server can abort a request when the client
+    disconnects.
+    """
+    assert RequestStatus.is_finished(finished_status)
+    if isinstance(request_ids, str):
+        request_ids = (request_ids, )
+    else:
+        request_ids = set(request_ids)
+
+    for req_id in request_ids:
+        request = self.requests.get(req_id)
+        if request is None:
+            # Invalid request ID.
+            continue
+        if request in self.waiting or request in self.running:
+            if request.status == RequestStatus.RUNNING:
+                self.running.remove(request)
+            else:
+                self.waiting.remove(request)
+            request.status = finished_status
+            self._free_request(request)
+
+
+def _update_from_kv_xfer_finished(self,
+                                  model_runner_output: ModelRunnerOutput):
+    """
+    KV Connector: update the scheduler state based on the output.
+    The Worker side connectors add finished_recving and
+    finished_sending reqs to the output.
+    * if finished_sending: free the blocks
+    # if finished_recving: add to state so we can
+    scheduler the request during the next step.
+    """
+    # KV Connector:: update recv and send status from last step.
+    for req_id in (model_runner_output.finished_recving or ()):
+        logger.debug("Finished recving KV transfer for request %s", req_id)
+        self.finished_recving_kv_req_ids.add(req_id)
+    for req_id in (model_runner_output.finished_sending or ()):
+        logger.debug("Finished sending KV transfer for request %s", req_id)
+        if req_id in self.requests:
+            self._free_blocks(self.requests[req_id])
+        else:
+            logger.debug("cannot find the req_id it may have been aborted.%s",
+                         req_id)
+
+
+def run_engine_core(*args, dp_rank: int = 0, local_dp_rank: int = 0, **kwargs):
+    """Launch EngineCore busy loop in background process."""
+
+    # Signal handler used for graceful termination.
+    # SystemExit exception is only raised once to allow this and worker
+    # processes to terminate without error
+    shutdown_requested = False
+
+    # Ensure we can serialize transformer config after spawning
+    maybe_register_config_serialize_by_value()
+
+    def signal_handler(signum, frame):
+        nonlocal shutdown_requested
+        if not shutdown_requested:
+            shutdown_requested = True
+            raise SystemExit()
+
+    # Either SIGTERM or SIGINT will terminate the engine_core
+    signal.signal(signal.SIGTERM, signal_handler)
+    signal.signal(signal.SIGINT, signal_handler)
+
+    engine_core: Optional[EngineCoreProc] = None
+    try:
+        parallel_config: ParallelConfig = kwargs["vllm_config"].parallel_config
+        if parallel_config.data_parallel_size > 1 or dp_rank > 0:
+            # Set data parallel rank for this engine process.
+            parallel_config.data_parallel_rank = dp_rank
+            parallel_config.data_parallel_rank_local = local_dp_rank
+            engine_core = DPEngineCoreProc(*args, **kwargs)
+        else:
+            engine_core = EngineCoreProc(*args, **kwargs)

+        engine_core.scheduler.finish_requests = types.MethodType(
+            finish_requests, engine_core.scheduler)
+        engine_core.scheduler._update_from_kv_xfer_finished = types.MethodType(
+            _update_from_kv_xfer_finished, engine_core.scheduler)
         engine_core.run_busy_loop()

     except SystemExit:
@@ -129,4 +244,6 @@ def signal_handler(signum, frame):
 # Apply this patch only if the external data parallelism is enabled
 if vllm_ascend_envs.VLLM_ASCEND_EXTERNAL_DP_LB_ENABLED:
     # Patch the EngineCoreClient to use the custom make_async_mp_client
+    EngineCoreProc.run_engine_core = run_engine_core_dplb  # type: ignore[attr-defined]
+else:
     EngineCoreProc.run_engine_core = run_engine_core  # type: ignore[attr-defined]
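Both launch paths swap `finish_requests` and `_update_from_kv_xfer_finished` onto the live scheduler with `types.MethodType`, so only that instance changes and no subclass is needed. A minimal sketch of this binding technique (the `Scheduler` stub and its queue handling below are illustrative, not vLLM's real scheduler):

```python
import types


class Scheduler:
    """Hypothetical stand-in for vLLM's scheduler (illustrative only)."""

    def __init__(self):
        self.waiting = ["req-1", "req-2"]  # queued requests
        self.running = []


def finish_requests(self, request_ids):
    # Replacement behavior: drop finished/canceled requests from both
    # queues, so a canceled request cannot accumulate in `waiting`.
    self.waiting = [r for r in self.waiting if r not in request_ids]
    self.running = [r for r in self.running if r not in request_ids]


scheduler = Scheduler()

# Bind the free function as a method of this one instance; the first
# argument (`self`) is fixed to `scheduler`, and the class is untouched.
scheduler.finish_requests = types.MethodType(finish_requests, scheduler)

scheduler.finish_requests({"req-1"})
print(scheduler.waiting)  # -> ['req-2']
```

Binding at the instance level keeps the patch scoped to the scheduler object created inside the engine process, instead of altering the scheduler class globally.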

vllm_ascend/patch/worker/__init__.py

Lines changed: 2 additions & 8 deletions

@@ -15,12 +15,6 @@
 # limitations under the License.
 #

-from vllm_ascend.utils import vllm_version_is
-
 # Import specific patches for different versions
-if vllm_version_is("0.9.1"):
-    from vllm_ascend.patch.worker import patch_0_9_1  # noqa: F401
-    from vllm_ascend.patch.worker import patch_common  # noqa: F401
-else:
-    from vllm_ascend.patch.worker import patch_common  # noqa: F401
-    from vllm_ascend.patch.worker import patch_main  # noqa: F401
+from vllm_ascend.patch.worker import patch_0_9_1  # noqa: F401
+from vllm_ascend.patch.worker import patch_common  # noqa: F401
