# Dispatch of an ABORT-status message received from the api server inside
# _insert_zmq_task_to_scheduler. Two message shapes are accepted:
#   * batch abort  : {"status": ABORT, "abort_all": bool, "req_ids": [...]}
#     (sent by EngineClient.abort_reqs for /v1/abort_requests)
#   * legacy single: {"status": ABORT, "request_id": "..."}
# Anything else is logged and dropped.
status_value = data.get("status", None)
if status_value is not None and status_value == RequestStatus.ABORT.value:
    if not envs.ENABLE_V1_KVCACHE_SCHEDULER:
        self.llm_logger.info("abort requests only supported in ENABLE_V1_KVCACHE_SCHEDULER")
    else:
        abort_all = data.get("abort_all", False)
        req_ids = data.get("req_ids", [])
        if abort_all or req_ids:
            # Batch path: map client ids onto live request ids (handles the
            # internal "_0" suffix) before handing them to the abort pipeline.
            target_req_ids = self._resolve_abort_targets(abort_all, req_ids)
            self.llm_logger.info(
                f"Receive abort_reqs, abort_all={abort_all}, "
                f"input={len(req_ids)}, resolved={len(target_req_ids)}"
            )
            self.resource_manager.add_abort_req_ids(target_req_ids)
        else:
            # Legacy single-request path.
            req_id = data.get("request_id", None)
            if not req_id:
                # FIX: previously this warning fell through and still called
                # add_abort_req_ids(req_id) with an empty/None id; an invalid
                # abort message must be dropped, not forwarded.
                self.llm_logger.warning(
                    "Receive abort request without request_id, skip invalid abort message"
                )
            else:
                self.llm_logger.info(f"Receive abort request, req_id: {req_id}")
                self.resource_manager.add_abort_req_ids(req_id)
request, req_id: {req_id}") + self.resource_manager.add_abort_req_ids(req_id) continue err_msg = None try: @@ -1325,7 +1340,7 @@ def _insert_zmq_task_to_scheduler(self): trace_print(LoggingEventName.REQUEST_QUEUE_START, data["request_id"], data.get("user", "")) self.llm_logger.debug(f"Receive request from api server: {request}") - if self.is_paused: + if self.is_paused or self._rejecting_new_requests: self.llm_logger.warning(f"Engine is paused, drop request: {request}") self._send_error_response( request.request_id, @@ -1445,39 +1460,19 @@ def _control_pause(self, control_request: ControlRequest): if self.is_paused: self.llm_logger.info("Engine is already paused, no need to pause again.") return - self.is_paused = True - - self.llm_logger.info("Abort running requests.") - - self.resource_manager.log_status() - # preempted all running reqs. preempted reqs will be append to ResourceManager.waiting queue - timeout, count = 60, 0 - while self.engine_worker_queue.exist_tasks(): - time.sleep(0.001) - count += 1 - if count >= timeout * 1000: - break - if count >= timeout * 1000: - error_msg = f"Emptying engine worker queue timed out after {timeout} seconds, worker may hanged!" 
# --- tail of _control_pause: reject new work, abort everything, then mark paused ---
self._rejecting_new_requests = True  # blocks new requests during abort drain
self.resource_manager.log_status()
try:
    # Abort every request currently known to either the resource manager or
    # the scheduler; the abort pipeline drains them through the worker.
    all_req_ids = list(set(self.resource_manager.requests.keys()) | set(self.scheduler.requests.keys()))
    self.llm_logger.info(f"Pause: aborting {len(all_req_ids)} total requests.")
    if all_req_ids:
        self.resource_manager.add_abort_req_ids(all_req_ids)
    self._wait_inflight_drained()
except Exception:
    # FIX: if the abort/drain fails before is_paused is set, _control_resume()
    # early-returns ("Engine is not paused") and never clears
    # _rejecting_new_requests — the engine would be wedged rejecting all new
    # requests forever. Re-open intake before propagating the error.
    self._rejecting_new_requests = False
    raise
with self._pause_cond:
    self.is_paused = True

self.resource_manager.log_status()
# NOTE(review): _control_resume should also clear _rejecting_new_requests
# *before* its not-paused early return, as a belt-and-braces recovery path
# for the same failure mode handled by the try/except above.

def _wait_inflight_drained(self, log_interval: float = 10.0):
    """
    Block until both resource_manager.requests and scheduler.requests are empty.

    No timeout — the abort pipeline is expected to complete. Aligned with
    SGLang's poll-until-drained behavior. Every ``log_interval`` seconds the
    remaining counts are logged so a stuck drain is visible in the logs
    (previously the loop spun silently with no way to observe progress).

    Args:
        log_interval: seconds between progress log lines (new, defaulted —
            backward compatible with the original zero-argument call sites).
    """
    start_time = time.time()
    next_report = start_time + log_interval
    while self.resource_manager.requests or self.scheduler.requests:
        now = time.time()
        if now >= next_report:
            self.llm_logger.info(
                f"Draining inflight requests: resource_manager={len(self.resource_manager.requests)}, "
                f"scheduler={len(self.scheduler.requests)}, elapsed={now - start_time:.1f}s"
            )
            next_report = now + log_interval
        time.sleep(0.005)
    self.llm_logger.info(f"All inflight requests drained, takes {time.time() - start_time:.1f} seconds.")
scheduled_req.raw - - partial_token_ids = list(request.output_token_ids) - - # Construct finished response with partial results - now = time.time() - abort_metrics = RequestMetrics( - arrival_time=request.metrics.arrival_time if request.metrics else now, - inference_start_time=request.metrics.inference_start_time if request.metrics else now, - engine_recv_latest_token_time=now, - engine_recv_first_token_time=request.metrics.engine_recv_first_token_time if request.metrics else now, - request_start_time=request.metrics.arrival_time if request.metrics else now, - ) - eos_token_ids = getattr(request, "eos_token_ids", [0]) - result = RequestOutput( - request_id=req_id, - finished=True, - outputs=CompletionOutput( - index=0, - send_idx=len(partial_token_ids), - token_ids=[eos_token_ids[0]], - ), - metrics=abort_metrics, - error_code=200, - error_msg="Aborted", - ) - results.append(result) - aborted_info.append( - { - "request_id": req_id, - "output_token_count": len(partial_token_ids), - } - ) - - # Step 3: Execute abort — add all requests to waiting_abort_req_id_set - if envs.ENABLE_V1_KVCACHE_SCHEDULER: - for req_id in target_req_ids: - self.resource_manager.add_abort_req_ids(req_id) - time.sleep(0.0001) - if self.cfg.scheduler_config.splitwise_role != "prefill": - self._wait_abort_complete(target_req_ids) - - # Add results to scheduler, engine will have a thread calling get_results, - # then cleanup and call send_response to send to client. 
- # When client disconnects, send_response will automatically ignore - if self.cfg.scheduler_config.splitwise_role != "prefill": - try: - # self.send_response_server.send_response(req_id, [result]) - self.scheduler.put_results(results) - except Exception: - pass # client may have disconnected - - not_found = [rid for rid in req_ids if rid not in matched_input_ids] if not abort_all else [] - - return {"aborted": aborted_info, "not_found": not_found} - - def _wait_abort_complete(self, target_req_ids, stall_timeout=1): - """ - Wait for all abort requests to complete. - - Keep monitoring as long as remaining is not empty, which means cleanup is not done yet - - If no progress within stall_timeout seconds, force cleanup requests stuck in to_be_aborted_req_id_set, - reset progress state if any, then continue monitoring - """ - target_set = set(target_req_ids) - target_set = target_set & (set(self.resource_manager.requests.keys()) | set(self.scheduler.requests.keys())) - prev_remaining_count = len(target_set) - last_progress_time = time.time() - remaining = target_set & self.resource_manager.get_reqs_in_aborting() - while remaining: - alive_reqs = set(self.resource_manager.requests.keys()) | set(self.scheduler.requests.keys()) - finished_reqs = target_set - alive_reqs - if finished_reqs: - self.llm_logger.info(f"abort targets already finished, skip: {finished_reqs}") - for req_id in finished_reqs: - self.resource_manager.waiting_abort_req_id_set.discard(req_id) - self.resource_manager.to_be_aborted_req_id_set.discard(req_id) - target_set -= finished_reqs - remaining = target_set & self.resource_manager.get_reqs_in_aborting() - if not remaining: - self.llm_logger.info(f"all {len(target_set)} abort reqs cleaned") - return - self.llm_logger.debug(f"remaining:{remaining}") - - current_count = len(remaining) - if current_count < prev_remaining_count: - # progress made: recycle_abort_task was called - self.llm_logger.info(f"abort progress: {prev_remaining_count} -> 
def _resolve_abort_targets(self, abort_all, req_ids):
    """
    Map client-supplied abort ids onto currently live request ids.

    Live ids are the union of resource_manager.requests and
    scheduler.requests. A client id may have been suffixed with "_0"
    internally, so both the exact id and its "_0" variant are checked.
    With ``abort_all`` every live request id is returned and ``req_ids``
    is ignored. Unknown ids are silently dropped from the result.
    """
    live = set(self.resource_manager.requests.keys()) | set(self.scheduler.requests.keys())
    self.llm_logger.debug(f"now_reqs: {live}")

    if abort_all:
        return list(live)

    resolved = []
    for raw_id in req_ids:
        if raw_id in live:
            resolved.append(raw_id)
        else:
            suffixed = f"{raw_id}_0"
            if suffixed in live:
                resolved.append(suffixed)
    return resolved
async def abort_reqs(self, req_ids=None, abort_all=False):
    """
    Fire-and-forget: abort multiple requests in one ZMQ message.
    Used by /v1/abort_requests API.
    """
    # Single message carrying either an explicit id list or the abort-all
    # flag; the engine side resolves the actual targets.
    payload = {
        "status": RequestStatus.ABORT.value,
        "abort_all": abort_all,
        "req_ids": req_ids or [],
    }
    self._send_task(payload)
self.accumulate_token_ids(request_output) token_ids = request_output["outputs"]["token_ids"] - if token_ids[-1] == self.eos_token_id: + if token_ids and token_ids[-1] == self.eos_token_id: multipart = [] num_image_tokens = 0 for part in self._multipart_buffer: diff --git a/fastdeploy/router/router.py b/fastdeploy/router/router.py index f6253b0ad89..8e057b6657a 100644 --- a/fastdeploy/router/router.py +++ b/fastdeploy/router/router.py @@ -18,7 +18,7 @@ import aiohttp import uvicorn from fastapi import FastAPI, HTTPException, Request -from fastapi.responses import JSONResponse, ORJSONResponse, Response, StreamingResponse +from fastapi.responses import ORJSONResponse, Response, StreamingResponse from fastdeploy.router.utils import ( InstanceInfo, @@ -29,6 +29,7 @@ from fastdeploy.utils import router_logger as logger app = FastAPI() +_background_tasks = set() @dataclass @@ -588,39 +589,15 @@ async def abort_requests(request: Request): decode_servers = app.state.router.decode_servers all_servers = prefill_servers + decode_servers - async with aiohttp.ClientSession() as session: - tasks = [session.post(f"{server.url()}/v1/abort_requests", json=body) for server in all_servers] - responses = await asyncio.gather(*tasks, return_exceptions=True) - - # Aggregate results from Node D only - all_aborted = [] - all_not_found = [] - errors = [] - decode_start = len(prefill_servers) - for i, (server, resp) in enumerate(zip(all_servers, responses)): - if i < decode_start: - continue - if isinstance(resp, Exception): - errors.append({"server": server.url(), "error": str(resp)}) - elif resp.status == 200: - data = await resp.json() - result = data.get("result") or {} - all_aborted.extend(result.get("aborted", [])) - all_not_found.extend(result.get("not_found", [])) - else: - errors.append({"server": server.url(), "status": resp.status}) - - return JSONResponse( - content={ - "request_id": f"router-{uuid4()}", - "status": "success" if not errors else "error", - "error_message": None if 
not errors else str(errors), - "result": { - "aborted": all_aborted, - "not_found": list(set(all_not_found)), - }, - } - ) + async def _forward_abort(): + async with aiohttp.ClientSession() as session: + tasks = [session.post(f"{server.url()}/v1/abort_requests", json=body) for server in all_servers] + await asyncio.gather(*tasks, return_exceptions=True) + + task = asyncio.create_task(_forward_abort()) + _background_tasks.add(task) + task.add_done_callback(_background_tasks.discard) + return Response(status_code=200) def launch_router(router_args: RouterArgs): diff --git a/tests/engine/test_common_engine.py b/tests/engine/test_common_engine.py index ac30f26d9ab..0e9a1bca4eb 100644 --- a/tests/engine/test_common_engine.py +++ b/tests/engine/test_common_engine.py @@ -1137,22 +1137,29 @@ def test_control_pause_and_resume_paths(self): eng = self._make_mixed_engine() eng.is_paused = False eng._pause_cond = threading.Condition() - eng.engine_worker_queue = Mock(exist_tasks=Mock(return_value=False), put_tasks=Mock()) + eng.engine_worker_queue = Mock(exist_tasks=Mock(return_value=False)) eng.resource_manager = Mock( - preempted_all=Mock(return_value=[Request(request_id="r1", prompt_token_ids=[1], prompt_token_ids_len=1)]), - get_real_bsz=Mock(), - wait_worker_inflight_requests_finish=Mock(), + requests={"r1": Mock(output_token_ids=[1, 2, 3])}, + waiting_abort_req_id_set=set(), + to_be_aborted_req_id_set=set(), + add_abort_req_ids=Mock(), log_status=Mock(), cache_manager=Mock(reset=Mock()), - real_bsz=1, ) eng.token_processor = Mock(clear_data=Mock()) - eng.scheduler = Mock(get_inflight_requests=Mock(return_value=[]), reset=Mock()) + mock_scheduler = Mock(reset=Mock()) + mock_scheduler.requests = {} + mock_scheduler.mutex = threading.Lock() + mock_scheduler.responses = {} + mock_scheduler.batch_responses_per_step = [] + eng.scheduler = mock_scheduler eng._send_error_response = Mock() + eng._wait_inflight_drained = Mock() with 
patch("fastdeploy.engine.common_engine.envs.ENABLE_V1_KVCACHE_SCHEDULER", True): eng._control_pause(ControlRequest(request_id="ctrl1", method="pause")) self.assertTrue(eng.is_paused) + eng.resource_manager.add_abort_req_ids.assert_called_once() eng._control_resume(ControlRequest(request_id="ctrl2", method="resume")) self.assertFalse(eng.is_paused) @@ -3530,7 +3537,7 @@ def _fake_sleep(s): self.assertGreaterEqual(call_count[0], 1) self._detach_finalizer(eng) - # ── _control_abort_requests / _wait_abort_complete ─────────────── + # ── _resolve_abort_targets / _build_abort_results ─────────────── def _make_abort_engine(self, splitwise_role="mixed"): """Create an engine wired up for abort tests.""" @@ -3571,42 +3578,17 @@ def _make_fake_request(self, output_token_ids=None): req.metrics.engine_recv_first_token_time = 1000.2 return req - def test_control_abort_requests_not_v1_raises(self): - """abort_requests raises when ENABLE_V1_KVCACHE_SCHEDULER is off.""" - eng = self._make_abort_engine() - control_req = ControlRequest("ctrl-1", "abort_requests", {"abort_all": True, "req_ids": []}) - with patch("fastdeploy.engine.common_engine.envs.ENABLE_V1_KVCACHE_SCHEDULER", 0): - with self.assertRaises(Exception) as ctx: - eng._control_abort_requests(control_req) - self.assertIn("only supported", str(ctx.exception)) - self._detach_finalizer(eng) - - def test_control_abort_requests_abort_all(self): - """abort_all=True aborts all requests in resource_manager + scheduler.""" + def test_resolve_abort_targets_abort_all(self): + """abort_all=True returns all requests in resource_manager + scheduler.""" eng = self._make_abort_engine() eng.resource_manager.requests = {"req-1_0": self._make_fake_request([10, 20])} eng.scheduler.requests = {"req-2_0": MagicMock(raw=self._make_fake_request([30]))} - control_req = ControlRequest("ctrl-1", "abort_requests", {"abort_all": True, "req_ids": []}) - - def clear_abort_sets(req_id): - # Simulate immediate abort completion - 
eng.resource_manager.waiting_abort_req_id_set.discard(req_id) - - eng.resource_manager.add_abort_req_ids = MagicMock(side_effect=clear_abort_sets) - - with patch("fastdeploy.engine.common_engine.envs.ENABLE_V1_KVCACHE_SCHEDULER", 1): - result = eng._control_abort_requests(control_req) - - self.assertEqual(len(result["aborted"]), 2) - self.assertEqual(result["not_found"], []) - ids = {a["request_id"] for a in result["aborted"]} - self.assertEqual(ids, {"req-1_0", "req-2_0"}) - # put_results should have been called (not prefill) - eng.scheduler.put_results.assert_called_once() + target = eng._resolve_abort_targets(abort_all=True, req_ids=[]) + self.assertEqual(set(target), {"req-1_0", "req-2_0"}) self._detach_finalizer(eng) - def test_control_abort_requests_by_req_ids_with_suffix_match(self): + def test_resolve_abort_targets_by_req_ids_with_suffix_match(self): """req_ids match both exact and _0 suffix.""" eng = self._make_abort_engine() eng.resource_manager.requests = { @@ -3614,136 +3596,18 @@ def test_control_abort_requests_by_req_ids_with_suffix_match(self): "req-B": self._make_fake_request([4, 5]), } - control_req = ControlRequest( - "ctrl-1", - "abort_requests", - { - "abort_all": False, - "req_ids": ["req-A", "req-B", "req-C"], - }, - ) - - def clear_abort_sets(req_id): - eng.resource_manager.waiting_abort_req_id_set.discard(req_id) - - eng.resource_manager.add_abort_req_ids = MagicMock(side_effect=clear_abort_sets) - - with patch("fastdeploy.engine.common_engine.envs.ENABLE_V1_KVCACHE_SCHEDULER", 1): - result = eng._control_abort_requests(control_req) - - aborted_ids = {a["request_id"] for a in result["aborted"]} - self.assertIn("req-A_0", aborted_ids) # matched via _0 suffix - self.assertIn("req-B", aborted_ids) # exact match - self.assertEqual(result["not_found"], ["req-C"]) - self._detach_finalizer(eng) - - def test_control_abort_requests_no_match(self): - """No requests found returns empty aborted and all in not_found.""" - eng = self._make_abort_engine() 
- control_req = ControlRequest( - "ctrl-1", - "abort_requests", - { - "abort_all": False, - "req_ids": ["nonexistent"], - }, - ) - - with patch("fastdeploy.engine.common_engine.envs.ENABLE_V1_KVCACHE_SCHEDULER", 1): - result = eng._control_abort_requests(control_req) - - self.assertEqual(result["aborted"], []) - self.assertEqual(result["not_found"], ["nonexistent"]) - self._detach_finalizer(eng) - - def test_control_abort_requests_prefill_skips_wait_and_put(self): - """Prefill role skips _wait_abort_complete and put_results.""" - eng = self._make_abort_engine(splitwise_role="prefill") - eng.resource_manager.requests = {"req-1_0": self._make_fake_request()} - - control_req = ControlRequest("ctrl-1", "abort_requests", {"abort_all": True, "req_ids": []}) - eng.resource_manager.add_abort_req_ids = MagicMock() - - with patch("fastdeploy.engine.common_engine.envs.ENABLE_V1_KVCACHE_SCHEDULER", 1): - result = eng._control_abort_requests(control_req) - - self.assertEqual(len(result["aborted"]), 1) - eng.scheduler.put_results.assert_not_called() - self._detach_finalizer(eng) - - def test_control_abort_requests_output_token_count(self): - """output_token_count reflects partial_token_ids length.""" - eng = self._make_abort_engine() - eng.resource_manager.requests = {"req-1_0": self._make_fake_request([10, 20, 30, 40, 50])} - - control_req = ControlRequest("ctrl-1", "abort_requests", {"abort_all": True, "req_ids": []}) - - def clear_abort_sets(req_id): - eng.resource_manager.waiting_abort_req_id_set.discard(req_id) - - eng.resource_manager.add_abort_req_ids = MagicMock(side_effect=clear_abort_sets) - - with patch("fastdeploy.engine.common_engine.envs.ENABLE_V1_KVCACHE_SCHEDULER", 1): - result = eng._control_abort_requests(control_req) - - self.assertEqual(result["aborted"][0]["output_token_count"], 5) - self._detach_finalizer(eng) - - def test_wait_abort_complete_immediate(self): - """_wait_abort_complete returns immediately when all requests already cleaned.""" - eng = 
self._make_abort_engine() - # Empty abort sets → remaining is empty → returns immediately - eng._wait_abort_complete(["req-1_0"]) - self._detach_finalizer(eng) - - def test_wait_abort_complete_progress(self): - """_wait_abort_complete exits when background thread cleans up.""" - eng = self._make_abort_engine() - eng.resource_manager.waiting_abort_req_id_set = {"req-1_0"} - # Add the request to requests dict so it won't be filtered out - eng.resource_manager.requests = {"req-1_0": self._make_fake_request()} - - call_count = [0] - - def fake_sleep(s): - call_count[0] += 1 - # Simulate background thread cleaning up after first sleep - eng.resource_manager.waiting_abort_req_id_set.discard("req-1_0") - - with patch("fastdeploy.engine.common_engine.time.sleep", fake_sleep): - eng._wait_abort_complete(["req-1_0"]) - - self.assertGreaterEqual(call_count[0], 1) + target = eng._resolve_abort_targets(abort_all=False, req_ids=["req-A", "req-B", "req-C"]) + self.assertIn("req-A_0", target) # matched via _0 suffix + self.assertIn("req-B", target) # exact match + self.assertNotIn("req-C", target) + self.assertNotIn("req-C_0", target) self._detach_finalizer(eng) - def test_wait_abort_complete_force_cleanup_stuck_in_to_be_aborted(self): - """Stall timeout triggers force cleanup for requests in to_be_aborted_req_id_set.""" + def test_resolve_abort_targets_no_match(self): + """No matching request ids returns empty list.""" eng = self._make_abort_engine() - eng.resource_manager.to_be_aborted_req_id_set = {"req-1_0"} - # Add the request to requests dict so it won't be filtered out - eng.resource_manager.requests = {"req-1_0": self._make_fake_request()} - - def mock_recycle(req_id): - eng.resource_manager.to_be_aborted_req_id_set.discard(req_id) - - eng.resource_manager.recycle_abort_task = MagicMock(side_effect=mock_recycle) - - # Make time.time() advance past stall_timeout - time_values = [100.0, 100.0, 102.0, 102.0, 102.0] - time_idx = [0] - - def fake_time(): - idx = 
min(time_idx[0], len(time_values) - 1) - time_idx[0] += 1 - return time_values[idx] - - with ( - patch("fastdeploy.engine.common_engine.time.time", fake_time), - patch("fastdeploy.engine.common_engine.time.sleep", lambda s: None), - ): - eng._wait_abort_complete(["req-1_0"], stall_timeout=1) - - eng.resource_manager.recycle_abort_task.assert_called_with("req-1_0") + target = eng._resolve_abort_targets(abort_all=False, req_ids=["nonexistent"]) + self.assertEqual(target, []) self._detach_finalizer(eng) diff --git a/tests/entrypoints/openai/test_api_server.py b/tests/entrypoints/openai/test_api_server.py index 11b4a2df4bc..94f9d641bc2 100644 --- a/tests/entrypoints/openai/test_api_server.py +++ b/tests/entrypoints/openai/test_api_server.py @@ -828,44 +828,30 @@ def _mock_abort_control_response(api_server, result, status_code=200): async def test_abort_requests_with_req_ids(): args = _build_args() api_server = _reload_api_server(args) - _mock_abort_control_response( - api_server, - { - "aborted": [{"request_id": "req-1_0", "output_token_count": 10}], - "not_found": ["req-999"], - }, - ) + api_server.app.state.engine_client = MagicMock() + api_server.app.state.engine_client.abort_reqs = AsyncMock(return_value=None) req = MagicMock() req.json = AsyncMock(return_value={"req_ids": ["req-1", "req-999"]}) resp = await api_server.abort_requests(req) assert resp.status_code == 200 - control_req = api_server.app.state.engine_client.run_control_method.await_args.args[0] - assert control_req.method == "abort_requests" - assert control_req.args["req_ids"] == ["req-1", "req-999"] - assert control_req.args["abort_all"] is False + call_kwargs = api_server.app.state.engine_client.abort_reqs.await_args.kwargs + assert call_kwargs["req_ids"] == ["req-1", "req-999"] + assert call_kwargs["abort_all"] is False @pytest.mark.asyncio async def test_abort_requests_with_abort_all(): args = _build_args() api_server = _reload_api_server(args) - _mock_abort_control_response( - api_server, - { - 
"aborted": [ - {"request_id": "req-1_0", "output_token_count": 5}, - {"request_id": "req-2_0", "output_token_count": 12}, - ], - "not_found": [], - }, - ) + api_server.app.state.engine_client = MagicMock() + api_server.app.state.engine_client.abort_reqs = AsyncMock(return_value=None) req = MagicMock() req.json = AsyncMock(return_value={"abort_all": True}) resp = await api_server.abort_requests(req) assert resp.status_code == 200 - control_req = api_server.app.state.engine_client.run_control_method.await_args.args[0] - assert control_req.args["abort_all"] is True - assert control_req.args["req_ids"] == [] + call_kwargs = api_server.app.state.engine_client.abort_reqs.await_args.kwargs + assert call_kwargs["abort_all"] is True + assert call_kwargs["req_ids"] == [] @pytest.mark.asyncio diff --git a/tests/router/test_router.py b/tests/router/test_router.py index 3ebebb72b0a..18c75d5d68a 100644 --- a/tests/router/test_router.py +++ b/tests/router/test_router.py @@ -20,6 +20,7 @@ We mock it at the network boundary to test Router's registration and selection logic. 
""" +import asyncio import unittest from types import SimpleNamespace from unittest.mock import AsyncMock, MagicMock, patch @@ -192,7 +193,7 @@ async def _coro(): @patch("fastdeploy.router.router.check_service_health_async", new_callable=AsyncMock, return_value=True) async def test_abort_broadcasts_to_all_but_returns_decode_only(self, mock_health): - """P and D both receive the request, but only D results are aggregated.""" + """Router returns 200 immediately and forwards to all (P + D) servers in background.""" from fastdeploy.router.router import abort_requests as abort_fn from fastdeploy.router.router import app @@ -203,24 +204,8 @@ async def test_abort_broadcasts_to_all_but_returns_decode_only(self, mock_health prefill_resp = AsyncMock() prefill_resp.status = 200 - prefill_resp.json = AsyncMock( - return_value={ - "request_id": "control-p", - "status": "success", - "error_message": None, - "result": {"aborted": [{"request_id": "req-1_0", "output_token_count": 0}], "not_found": []}, - } - ) decode_resp = AsyncMock() decode_resp.status = 200 - decode_resp.json = AsyncMock( - return_value={ - "request_id": "control-d", - "status": "success", - "error_message": None, - "result": {"aborted": [{"request_id": "req-1_0", "output_token_count": 15}], "not_found": []}, - } - ) mock_session = self._make_mock_session([prefill_resp, decode_resp]) mock_request = AsyncMock() @@ -228,18 +213,17 @@ async def test_abort_broadcasts_to_all_but_returns_decode_only(self, mock_health with patch("fastdeploy.router.router.aiohttp.ClientSession", return_value=mock_session): resp = await abort_fn(mock_request) + # Give the background task a chance to run + await asyncio.sleep(0) + await asyncio.sleep(0) - import json - - body = json.loads(resp.body) - self.assertEqual(len(body["result"]["aborted"]), 1) - self.assertEqual(body["result"]["aborted"][0]["output_token_count"], 15) - self.assertEqual(body["status"], "success") + self.assertEqual(resp.status_code, 200) + # Forwarded to both 
prefill + decode self.assertEqual(mock_session.post.call_count, 2) @patch("fastdeploy.router.router.check_service_health_async", new_callable=AsyncMock, return_value=True) - async def test_abort_decode_error_returns_error_status(self, mock_health): - """When D node returns a non-200 status, status should be 'error'.""" + async def test_abort_returns_200_even_when_decode_errors(self, mock_health): + """Router fire-and-forgets: still returns 200 when D returns non-200.""" from fastdeploy.router.router import abort_requests as abort_fn from fastdeploy.router.router import app @@ -250,14 +234,6 @@ async def test_abort_decode_error_returns_error_status(self, mock_health): prefill_resp = AsyncMock() prefill_resp.status = 200 - prefill_resp.json = AsyncMock( - return_value={ - "request_id": "control-p", - "status": "success", - "error_message": None, - "result": {"aborted": [], "not_found": []}, - } - ) decode_resp = AsyncMock() decode_resp.status = 500 @@ -267,16 +243,14 @@ async def test_abort_decode_error_returns_error_status(self, mock_health): with patch("fastdeploy.router.router.aiohttp.ClientSession", return_value=mock_session): resp = await abort_fn(mock_request) + await asyncio.sleep(0) + await asyncio.sleep(0) - import json - - body = json.loads(resp.body) - self.assertEqual(body["status"], "error") - self.assertIsNotNone(body["error_message"]) + self.assertEqual(resp.status_code, 200) @patch("fastdeploy.router.router.check_service_health_async", new_callable=AsyncMock, return_value=True) - async def test_abort_decode_exception_returns_error(self, mock_health): - """When D node connection fails (exception), error should be captured.""" + async def test_abort_returns_200_when_decode_raises(self, mock_health): + """Router fire-and-forgets: still returns 200 when a downstream raises.""" from fastdeploy.router.router import abort_requests as abort_fn from fastdeploy.router.router import app @@ -287,30 +261,20 @@ async def 
test_abort_decode_exception_returns_error(self, mock_health): prefill_resp = AsyncMock() prefill_resp.status = 200 - prefill_resp.json = AsyncMock( - return_value={ - "request_id": "control-p", - "status": "success", - "error_message": None, - "result": {"aborted": [], "not_found": []}, - } - ) - - # D node raises exception — but asyncio.gather(return_exceptions=True) captures it - # So we pass the exception as a response directly + mock_session = self._make_mock_session([prefill_resp, prefill_resp]) # placeholder call_idx = [0] def post_with_exception(*args, **kwargs): call_idx[0] += 1 if call_idx[0] == 1: - # prefill: normal + async def _coro(): return prefill_resp return _coro() else: - # decode: raise (gather with return_exceptions=True will catch) + async def _coro_err(): raise ConnectionError("refused") @@ -322,12 +286,10 @@ async def _coro_err(): with patch("fastdeploy.router.router.aiohttp.ClientSession", return_value=mock_session): resp = await abort_fn(mock_request) + await asyncio.sleep(0) + await asyncio.sleep(0) - import json - - body = json.loads(resp.body) - self.assertEqual(body["status"], "error") - self.assertIn("refused", body["error_message"]) + self.assertEqual(resp.status_code, 200) if __name__ == "__main__":