55import logging
66import threading
77import multiprocessing
8+ import asyncio
89from typing import List , Optional , Dict , Any , Union , AsyncGenerator , Tuple
910from contextlib import asynccontextmanager
1011from dataclasses import dataclass
@@ -76,7 +77,7 @@ def __init__(self):
7677 self .request_queue : Optional [multiprocessing .Queue ] = None
7778 self .response_queue : Optional [multiprocessing .Queue ] = None
7879 self .active_requests : int = 0
79- self .lock = threading .Lock ()
80+ self ._lock : Optional [ asyncio .Lock ] = None
8081 self .is_mock_mode = False
8182
8283 @classmethod
@@ -85,28 +86,60 @@ def get_instance(cls):
8586 cls ._instance = ServerState ()
8687 return cls ._instance
8788
89+ @property
90+ def lock (self ) -> asyncio .Lock :
91+ """Lazy initialization of the lock to ensure it attaches to the running loop."""
92+ if self ._lock is None :
93+ self ._lock = asyncio .Lock ()
94+ return self ._lock
95+
8896 def set_queues (self , req_q , res_q ):
8997 self .request_queue = req_q
9098 self .response_queue = res_q
9199 self .is_mock_mode = True
92100 logger .warning ("!!! Server running in MOCK MODE via Queue Injection !!!" )
93101
94- def increment_request (self ):
95- with self .lock :
102+ async def increment_request (self ):
103+ async with self .lock :
96104 self .active_requests += 1
97105
98- def decrement_request (self ):
99- with self .lock :
106+ async def decrement_request (self ):
107+ async with self .lock :
100108 self .active_requests -= 1
101109
102- @ property
103- def snapshot ( self ):
104- with self .lock :
110+ async def get_snapshot ( self ) -> Dict [ str , Any ]:
111+ """Async safe snapshot retrieval."""
112+ async with self .lock :
105113 return {
106114 "active_requests" : self .active_requests ,
107115 "is_mock_mode" : self .is_mock_mode ,
108116 }
109117
118+ async def mock_request_interaction (
119+ self , request_body : Dict [str , Any ], timeout : float = 10.0
120+ ) -> Dict [str , Any ]:
121+ """
122+ Handles the interaction with the multiprocessing Queue in a non-blocking way.
123+ This fixes the blocking issue by offloading queue.get/put to the executor.
124+ """
125+ loop = asyncio .get_running_loop ()
126+
127+ # 1. Offload the blocking put operation
128+ await loop .run_in_executor (None , self .request_queue .put , request_body )
129+
130+ # 2. Offload the blocking get operation
131+ # We use a lambda to pass the timeout to the queue.get method
132+ def blocking_get ():
133+ return self .response_queue .get (timeout = timeout )
134+
135+ response_data = await loop .run_in_executor (None , blocking_get )
136+
137+ return (
138+ json .loads (response_data )
139+ if isinstance (response_data , str )
140+ else response_data
141+ )
142+
110143
111144SERVER_STATE = ServerState .get_instance ()
112145
@@ -943,22 +976,23 @@ def _format_unary_response(
943976
944977@asynccontextmanager
945978async def lifespan (app : FastAPI ):
946- stop_event = threading .Event ()
947-
948- def epoch_clock ():
949- while not stop_event .is_set ():
950- time .sleep (5 )
951- state = SERVER_STATE .snapshot
979+ async def epoch_clock ():
980+ while True :
981+ await asyncio .sleep (5 )
982+ state = await SERVER_STATE .get_snapshot ()
952983 if state ["active_requests" ] > 0 or state ["is_mock_mode" ]:
953984 logger .info (
954985 f"[Monitor] Active: { state ['active_requests' ]} | "
955986 f"Mode: { 'MOCK' if state ['is_mock_mode' ] else 'PRODUCTION' } "
956987 )
957988
958- monitor_thread = threading .Thread (target = epoch_clock , daemon = True )
959- monitor_thread .start ()
989+ monitor_task = asyncio .create_task (epoch_clock ())
960990 yield
961- stop_event .set ()
991+ monitor_task .cancel ()
992+ try :
993+ await monitor_task
994+ except asyncio .CancelledError :
995+ pass
962996
963997
964998def _prepare_proxy_and_headers (
@@ -1148,13 +1182,15 @@ async def validation_exception_handler(request, exc):
11481182
11491183 @app .middleware ("http" )
11501184 async def request_tracker (request : Request , call_next ):
1151- SERVER_STATE .increment_request ()
1185+ # Await the async increment
1186+ await SERVER_STATE .increment_request ()
11521187 start_time = time .time ()
11531188 try :
11541189 response = await call_next (request )
11551190 return response
11561191 finally :
1157- SERVER_STATE .decrement_request ()
1192+ # Await the async decrement
1193+ await SERVER_STATE .decrement_request ()
11581194 duration = (time .time () - start_time ) * 1000
11591195 if duration > 1000 :
11601196 logger .warning (
@@ -1203,19 +1239,15 @@ async def generation(
12031239 if body :
12041240 try :
12051241 if _MOCK_ENV_API_KEY :
1206- shadow_proxy = DeepSeekProxy (api_key = _MOCK_ENV_API_KEY )
1242+ # Just to ensure init consistency even if not used
1243+ _ = DeepSeekProxy (api_key = _MOCK_ENV_API_KEY )
12071244 except Exception :
12081245 pass
12091246
12101247 try :
12111248 raw_body = await request .json ()
1212- SERVER_STATE .request_queue .put (raw_body )
1213- response_data = SERVER_STATE .response_queue .get (timeout = 10 )
1214- response_json = (
1215- json .loads (response_data )
1216- if isinstance (response_data , str )
1217- else response_data
1218- )
1249+ # Use the new async-safe mock interaction
1250+ response_json = await SERVER_STATE .mock_request_interaction (raw_body )
12191251 status_code = response_json .pop ("status_code" , 200 )
12201252 return JSONResponse (content = response_json , status_code = status_code )
12211253 except Exception as e :
@@ -1270,12 +1302,9 @@ async def catch_all(path_name: str, request: Request):
12701302 "headers" : dict (request .headers ),
12711303 "body" : body ,
12721304 }
1273- SERVER_STATE .request_queue .put (req_record )
1274- response_data = SERVER_STATE .response_queue .get (timeout = 5 )
1275- response_json = (
1276- json .loads (response_data )
1277- if isinstance (response_data , str )
1278- else response_data
1305+ # Use the new async-safe mock interaction
1306+ response_json = await SERVER_STATE .mock_request_interaction (
1307+ req_record , timeout = 5.0
12791308 )
12801309 status_code = response_json .pop ("status_code" , 200 )
12811310 return JSONResponse (content = response_json , status_code = status_code )
0 commit comments