Skip to content

Commit f4e4b93

Browse files
committed
Update mock_server.py
1 parent 0a2b6d7 commit f4e4b93

File tree

1 file changed

+62
-33
lines changed

1 file changed

+62
-33
lines changed

tests/mock_server.py

Lines changed: 62 additions & 33 deletions
Original file line number | Diff line number | Diff line change
@@ -5,6 +5,7 @@
55
import logging
66
import threading
77
import multiprocessing
8+
import asyncio
89
from typing import List, Optional, Dict, Any, Union, AsyncGenerator, Tuple
910
from contextlib import asynccontextmanager
1011
from dataclasses import dataclass
@@ -76,7 +77,7 @@ def __init__(self):
7677
self.request_queue: Optional[multiprocessing.Queue] = None
7778
self.response_queue: Optional[multiprocessing.Queue] = None
7879
self.active_requests: int = 0
79-
self.lock = threading.Lock()
80+
self._lock: Optional[asyncio.Lock] = None
8081
self.is_mock_mode = False
8182

8283
@classmethod
@@ -85,28 +86,60 @@ def get_instance(cls):
8586
cls._instance = ServerState()
8687
return cls._instance
8788

89+
@property
90+
def lock(self) -> asyncio.Lock:
91+
"""Lazy initialization of the lock to ensure it attaches to the running loop."""
92+
if self._lock is None:
93+
self._lock = asyncio.Lock()
94+
return self._lock
95+
8896
def set_queues(self, req_q, res_q):
8997
self.request_queue = req_q
9098
self.response_queue = res_q
9199
self.is_mock_mode = True
92100
logger.warning("!!! Server running in MOCK MODE via Queue Injection !!!")
93101

94-
def increment_request(self):
95-
with self.lock:
102+
async def increment_request(self):
103+
async with self.lock:
96104
self.active_requests += 1
97105

98-
def decrement_request(self):
99-
with self.lock:
106+
async def decrement_request(self):
107+
async with self.lock:
100108
self.active_requests -= 1
101109

102-
@property
103-
def snapshot(self):
104-
with self.lock:
110+
async def get_snapshot(self) -> Dict[str, Any]:
111+
"""Async safe snapshot retrieval."""
112+
async with self.lock:
105113
return {
106114
"active_requests": self.active_requests,
107115
"is_mock_mode": self.is_mock_mode,
108116
}
109117

118+
async def mock_request_interaction(
119+
self, request_body: Dict[str, Any], timeout: float = 10.0
120+
) -> Dict[str, Any]:
121+
"""
122+
Handles the interaction with the multiprocessing Queue in a non-blocking way.
123+
This fixes the blocking issue by offloading queue.get/put to the executor.
124+
"""
125+
loop = asyncio.get_running_loop()
126+
127+
# 1. Offload the blocking put operation
128+
await loop.run_in_executor(None, self.request_queue.put, request_body)
129+
130+
# 2. Offload the blocking get operation
131+
# We use a nested function to pass the timeout to the queue.get method
132+
def blocking_get():
133+
return self.response_queue.get(timeout=timeout)
134+
135+
response_data = await loop.run_in_executor(None, blocking_get)
136+
137+
return (
138+
json.loads(response_data)
139+
if isinstance(response_data, str)
140+
else response_data
141+
)
142+
110143

111144
SERVER_STATE = ServerState.get_instance()
112145

@@ -943,22 +976,23 @@ def _format_unary_response(
943976

944977
@asynccontextmanager
945978
async def lifespan(app: FastAPI):
946-
stop_event = threading.Event()
947-
948-
def epoch_clock():
949-
while not stop_event.is_set():
950-
time.sleep(5)
951-
state = SERVER_STATE.snapshot
979+
async def epoch_clock():
980+
while True:
981+
await asyncio.sleep(5)
982+
state = await SERVER_STATE.get_snapshot()
952983
if state["active_requests"] > 0 or state["is_mock_mode"]:
953984
logger.info(
954985
f"[Monitor] Active: {state['active_requests']} | "
955986
f"Mode: {'MOCK' if state['is_mock_mode'] else 'PRODUCTION'}"
956987
)
957988

958-
monitor_thread = threading.Thread(target=epoch_clock, daemon=True)
959-
monitor_thread.start()
989+
monitor_task = asyncio.create_task(epoch_clock())
960990
yield
961-
stop_event.set()
991+
monitor_task.cancel()
992+
try:
993+
await monitor_task
994+
except asyncio.CancelledError:
995+
pass
962996

963997

964998
def _prepare_proxy_and_headers(
@@ -1148,13 +1182,15 @@ async def validation_exception_handler(request, exc):
11481182

11491183
@app.middleware("http")
11501184
async def request_tracker(request: Request, call_next):
1151-
SERVER_STATE.increment_request()
1185+
# Await the async increment
1186+
await SERVER_STATE.increment_request()
11521187
start_time = time.time()
11531188
try:
11541189
response = await call_next(request)
11551190
return response
11561191
finally:
1157-
SERVER_STATE.decrement_request()
1192+
# Await the async decrement
1193+
await SERVER_STATE.decrement_request()
11581194
duration = (time.time() - start_time) * 1000
11591195
if duration > 1000:
11601196
logger.warning(
@@ -1203,19 +1239,15 @@ async def generation(
12031239
if body:
12041240
try:
12051241
if _MOCK_ENV_API_KEY:
1206-
shadow_proxy = DeepSeekProxy(api_key=_MOCK_ENV_API_KEY)
1242+
# Instantiate only for init consistency; the instance itself is intentionally unused
1243+
_ = DeepSeekProxy(api_key=_MOCK_ENV_API_KEY)
12071244
except Exception:
12081245
pass
12091246

12101247
try:
12111248
raw_body = await request.json()
1212-
SERVER_STATE.request_queue.put(raw_body)
1213-
response_data = SERVER_STATE.response_queue.get(timeout=10)
1214-
response_json = (
1215-
json.loads(response_data)
1216-
if isinstance(response_data, str)
1217-
else response_data
1218-
)
1249+
# Use the new async-safe mock interaction
1250+
response_json = await SERVER_STATE.mock_request_interaction(raw_body)
12191251
status_code = response_json.pop("status_code", 200)
12201252
return JSONResponse(content=response_json, status_code=status_code)
12211253
except Exception as e:
@@ -1270,12 +1302,9 @@ async def catch_all(path_name: str, request: Request):
12701302
"headers": dict(request.headers),
12711303
"body": body,
12721304
}
1273-
SERVER_STATE.request_queue.put(req_record)
1274-
response_data = SERVER_STATE.response_queue.get(timeout=5)
1275-
response_json = (
1276-
json.loads(response_data)
1277-
if isinstance(response_data, str)
1278-
else response_data
1305+
# Use the new async-safe mock interaction
1306+
response_json = await SERVER_STATE.mock_request_interaction(
1307+
req_record, timeout=5.0
12791308
)
12801309
status_code = response_json.pop("status_code", 200)
12811310
return JSONResponse(content=response_json, status_code=status_code)

0 commit comments

Comments
 (0)