 from fastapi.responses import StreamingResponse, JSONResponse
 from fastapi.middleware.cors import CORSMiddleware
 from pydantic import BaseModel, Field
-from openai import OpenAI, APIError, RateLimitError, AuthenticationError
+from openai import AsyncOpenAI, APIError, RateLimitError, AuthenticationError
 
 # --- [System Configuration] ---
 
@@ -112,8 +112,8 @@ class GenerationRequest(BaseModel):
 
 class DeepSeekProxy:
     def __init__(self):
-        self.client = OpenAI(
-            api_key=SILICON_FLOW_API_KEY if SILICON_FLOW_API_KEY else None,
+        self.client = AsyncOpenAI(
+            api_key=SILICON_FLOW_API_KEY if SILICON_FLOW_API_KEY else "dummy_key",
             base_url=SILICON_FLOW_BASE_URL
         )
 
@@ -174,15 +174,17 @@ async def generate(self, req_data: GenerationRequest, initial_request_id: str):
 
         try:
             if openai_params["stream"]:
-                # Fetch raw response for headers in stream mode
-                raw_resp = self.client.chat.completions.with_raw_response.create(**openai_params)
+                # Fetch raw response for headers in stream mode (awaited)
+                raw_resp = await self.client.chat.completions.with_raw_response.create(**openai_params)
                 trace_id = raw_resp.headers.get("X-SiliconCloud-Trace-Id", initial_request_id)
+                # raw_resp.parse() returns the AsyncStream
                 return StreamingResponse(
                     self._stream_generator(raw_resp.parse(), trace_id),
                     media_type="text/event-stream"
                 )
             else:
-                raw_resp = self.client.chat.completions.with_raw_response.create(**openai_params)
+                # Standard response (awaited)
+                raw_resp = await self.client.chat.completions.with_raw_response.create(**openai_params)
                 trace_id = raw_resp.headers.get("X-SiliconCloud-Trace-Id", initial_request_id)
                 return self._format_unary_response(raw_resp.parse(), trace_id)
@@ -204,7 +206,7 @@ async def _stream_generator(self, stream, request_id: str) -> AsyncGenerator[str
             }
             finish_reason = "null"
 
-            for chunk in stream:
+            async for chunk in stream:
                 if chunk.usage:
                     accumulated_usage["total_tokens"] = chunk.usage.total_tokens
                     accumulated_usage["input_tokens"] = chunk.usage.prompt_tokens
@@ -289,7 +291,6 @@ def create_app() -> FastAPI:
     app.add_middleware(
         CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"]
     )
-    proxy = DeepSeekProxy()
 
     @app.middleware("http")
     async def request_tracker(request: Request, call_next):
@@ -319,6 +320,9 @@ async def health_check():
     async def generation(request: Request, body: GenerationRequest = None):
         request_id = request.headers.get("x-request-id", str(uuid.uuid4()))
 
+        # Instantiate Proxy Per Request
+        proxy = DeepSeekProxy()
+
         if not body:
             try:
                 raw_json = await request.json()
@@ -331,6 +335,7 @@ async def generation(request: Request, body: GenerationRequest = None):
         if body:
             logger.info(f"[Shadow] Validating request against upstream...")
             try:
+                # Async generate call on the local instance
+                await proxy.generate(body, f"shadow-{request_id}")
             except Exception as e:
                 logger.error(f"[Shadow] Validation Exception: {str(e)}")