from contextlib import asynccontextmanager

import uvicorn
-from fastapi import FastAPI, HTTPException, Request
+from fastapi import FastAPI, HTTPException, Request, Response
from fastapi.responses import StreamingResponse, JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
@@ -67,7 +67,6 @@ def decrement_request(self):
        with self.lock:
            self.active_requests -= 1

-    # [FIX 1] Add a thread-safe snapshot read to keep concurrent state consistent
    @property
    def snapshot(self):
        """Returns a consistent snapshot of the state."""
@@ -101,9 +100,7 @@ class Parameters(BaseModel):
    stop: Optional[Union[str, List[str]]] = None
    enable_thinking: bool = False
    thinking_budget: Optional[int] = None
-    # [ADDED] Tools Support
    tools: Optional[List[Dict[str, Any]]] = None
-    # Allowed: "none", "auto", "required" (str) OR {"type": "function", ...} (dict)
    tool_choice: Optional[Union[str, Dict[str, Any]]] = None

class GenerationRequest(BaseModel):
@@ -139,23 +136,15 @@ def _convert_input_to_messages(self, input_data: InputData) -> List[Dict[str, st
            messages.append({"role": "user", "content": input_data.prompt})
        return messages

-    async def generate(self, req_data: GenerationRequest, request_id: str):
-        """
-        Standard generation logic with strict invariant checks.
-        """
+    async def generate(self, req_data: GenerationRequest, initial_request_id: str):
        params = req_data.parameters

-        # --- [Invariant Checks] ---
-        # 1. Format Constraint (Predicate A)
        if params.tools and params.result_format != "message":
            return JSONResponse(
                status_code=400,
                content={"code": "InvalidParameter", "message": "When 'tools' are provided, 'result_format' must be 'message'."}
            )

-        # 2. R1 Orthogonality (Predicate B)
-        # DeepSeek R1 Thinking Mode is mutually exclusive with FORCED SPECIFIC tool choice (Dict).
-        # However, abstract constraints like "required" (String) are allowed.
        is_r1 = "deepseek-r1" in req_data.model or params.enable_thinking
        if is_r1 and params.tool_choice and isinstance(params.tool_choice, dict):
            return JSONResponse(
@@ -174,11 +163,9 @@ async def generate(self, req_data: GenerationRequest, request_id: str):
174163 "stream" : params .incremental_output or params .enable_thinking ,
175164 }
176165
177- # [ADDED] Pass tools to upstream if present
178166 if params .tools :
179167 openai_params ["tools" ] = params .tools
180168 if params .tool_choice :
181- # This will pass "required" (str) effectively to OpenAI/SiliconFlow
182169 openai_params ["tool_choice" ] = params .tool_choice
183170
184171 if params .max_tokens : openai_params ["max_tokens" ] = params .max_tokens
@@ -187,13 +174,17 @@ async def generate(self, req_data: GenerationRequest, request_id: str):

        try:
            if openai_params["stream"]:
+                # Fetch raw response for headers in stream mode
+                raw_resp = self.client.chat.completions.with_raw_response.create(**openai_params)
+                trace_id = raw_resp.headers.get("X-SiliconCloud-Trace-Id", initial_request_id)
                return StreamingResponse(
-                    self._stream_generator(openai_params, request_id),
+                    self._stream_generator(raw_resp.parse(), trace_id),
                    media_type="text/event-stream"
                )
            else:
-                completion = self.client.chat.completions.create(**openai_params)
-                return self._format_unary_response(completion, request_id)
+                raw_resp = self.client.chat.completions.with_raw_response.create(**openai_params)
+                trace_id = raw_resp.headers.get("X-SiliconCloud-Trace-Id", initial_request_id)
+                return self._format_unary_response(raw_resp.parse(), trace_id)

        except APIError as e:
            logger.error(f"Upstream API Error: {str(e)}")
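For reference, a minimal sketch of the with_raw_response pattern used above, assuming a synchronous OpenAI client pointed at a SiliconFlow-compatible endpoint; the header name X-SiliconCloud-Trace-Id is taken from the change itself, while the base URL, API key, and model below are placeholders:

from openai import OpenAI

client = OpenAI(base_url="https://api.siliconflow.cn/v1", api_key="sk-placeholder")

# with_raw_response returns the raw HTTP response wrapper instead of the parsed object,
# so upstream headers remain accessible; .parse() then yields the usual ChatCompletion.
raw_resp = client.chat.completions.with_raw_response.create(
    model="deepseek-ai/DeepSeek-R1",  # placeholder model name
    messages=[{"role": "user", "content": "ping"}],
)
trace_id = raw_resp.headers.get("X-SiliconCloud-Trace-Id", "fallback-request-id")
completion = raw_resp.parse()
print(trace_id, completion.choices[0].message.content)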
@@ -203,20 +194,10 @@ async def generate(self, req_data: GenerationRequest, request_id: str):

            return JSONResponse(
                status_code=e.status_code or 500,
-                content={"code": error_code, "message": str(e), "request_id": request_id}
+                content={"code": error_code, "message": str(e), "request_id": initial_request_id}
            )

-    async def _stream_generator(self, openai_params: Dict, request_id: str) -> AsyncGenerator[str, None]:
-        if "stream_options" not in openai_params:
-            openai_params["stream_options"] = {"include_usage": True}
-
-        try:
-            stream = self.client.chat.completions.create(**openai_params)
-        except Exception as e:
-            logger.error(f"Stream creation failed: {e}")
-            yield f"data: {json.dumps({'code': 'StreamError', 'message': str(e)}, ensure_ascii=False)}\n\n"
-            return
-
+    async def _stream_generator(self, stream, request_id: str) -> AsyncGenerator[str, None]:
        accumulated_usage = {
            "total_tokens": 0, "input_tokens": 0, "output_tokens": 0,
            "output_tokens_details": {"text_tokens": 0, "reasoning_tokens": 0}
@@ -234,34 +215,21 @@ async def _stream_generator(self, openai_params: Dict, request_id: str) -> Async
                accumulated_usage["output_tokens_details"]["text_tokens"] = accumulated_usage["output_tokens"] - accumulated_usage["output_tokens_details"]["reasoning_tokens"]

            delta = chunk.choices[0].delta if chunk.choices else None
-
            content = delta.content if delta and delta.content else ""
            reasoning = getattr(delta, "reasoning_content", "") if delta else ""

            tool_calls = None
            if delta and delta.tool_calls:
-                # Forward the raw list of tool call chunks
-                # Note: model_dump() is retained per original design, ensuring Pydantic serialization
                tool_calls = [tc.model_dump() for tc in delta.tool_calls]

            if chunk.choices and chunk.choices[0].finish_reason:
                finish_reason = chunk.choices[0].finish_reason

-            message_body = {
-                "role": "assistant",
-                "content": content,
-                "reasoning_content": reasoning
-            }
-            if tool_calls:
-                message_body["tool_calls"] = tool_calls
+            message_body = {"role": "assistant", "content": content, "reasoning_content": reasoning}
+            if tool_calls: message_body["tool_calls"] = tool_calls

            response_body = {
-                "output": {
-                    "choices": [{
-                        "message": message_body,
-                        "finish_reason": finish_reason
-                    }]
-                },
+                "output": {"choices": [{"message": message_body, "finish_reason": finish_reason}]},
                "usage": accumulated_usage,
                "request_id": request_id
            }
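For reference, a sketch of the SSE frame this generator emits, assuming each chunk is written out as a "data: <json>" event followed by a blank line (the framing mirrors the error yield removed above); the field values are illustrative:

import json

frame = {
    "output": {"choices": [{"message": {"role": "assistant", "content": "Hello",
                                        "reasoning_content": ""}, "finish_reason": None}]},
    "usage": {"total_tokens": 12, "input_tokens": 8, "output_tokens": 4,
              "output_tokens_details": {"text_tokens": 4, "reasoning_tokens": 0}},
    "request_id": "abc-123",  # illustrative trace/request id
}
# One SSE event per chunk: a "data:" line followed by a blank line.
sse_event = f"data: {json.dumps(frame, ensure_ascii=False)}\n\n"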
@@ -271,12 +239,11 @@ async def _stream_generator(self, openai_params: Dict, request_id: str) -> Async
    def _format_unary_response(self, completion, request_id: str):
        choice = completion.choices[0]
        msg = choice.message
-
        usage_data = {
            "total_tokens": completion.usage.total_tokens,
            "input_tokens": completion.usage.prompt_tokens,
            "output_tokens": completion.usage.completion_tokens,
-            "output_tokens_details": {"text_tokens": 0, "reasoning_tokens": 0}
+            "output_tokens_details": {"text_tokens": 0, "reasoning_tokens": 0}
        }
        details = getattr(completion.usage, "completion_tokens_details", None)
        if details:
@@ -292,12 +259,7 @@ def _format_unary_response(self, completion, request_id: str):
            message_body["tool_calls"] = [tc.model_dump() for tc in msg.tool_calls]

        response_body = {
-            "output": {
-                "choices": [{
-                    "message": message_body,
-                    "finish_reason": choice.finish_reason
-                }]
-            },
+            "output": {"choices": [{"message": message_body, "finish_reason": choice.finish_reason}]},
            "usage": usage_data,
            "request_id": request_id
        }
@@ -311,15 +273,12 @@ async def lifespan(app: FastAPI):
    def epoch_clock():
        while not stop_event.is_set():
            time.sleep(2)
-
-            # [FIX 1 Usage & FIX 2] Read the state via the thread-safe snapshot and log at INFO level
            state = SERVER_STATE.snapshot
            if state["active_requests"] > 0 or state["is_mock_mode"]:
                logger.info(
                    f"[Epoch Clock] Active Requests: {state['active_requests']} | "
                    f"Mode: {'MOCK' if state['is_mock_mode'] else 'PROXY'}"
                )
-
    monitor_thread = threading.Thread(target=epoch_clock, daemon=True)
    monitor_thread.start()
    yield
@@ -344,12 +303,8 @@ async def request_tracker(request: Request, call_next):
    duration = (time.time() - start_time) * 1000
    logger.info(f"{request.method} {request.url.path} - {duration:.2f}ms")

-# --- [New Endpoint] Health Check ---
@app.get("/health_check")
async def health_check():
-    """
-    Liveness probe verifying server status and current mode.
-    """
    return JSONResponse(
        status_code=200,
        content={
@@ -372,39 +327,25 @@ async def generation(request: Request, body: GenerationRequest = None):
        if not SERVER_STATE.is_mock_mode:
            raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}")

-    # === [Logic Branch]: Mock Mode with Shadow Verification ===
    if SERVER_STATE.is_mock_mode:
        if body:
-            logger.info(f"[Shadow] Validating request against upstream: {body.model_dump_json(exclude_none=True)}")
+            logger.info(f"[Shadow] Validating request against upstream...")
            try:
-                shadow_resp = await proxy.generate(body, f"shadow-{request_id}")
-                if isinstance(shadow_resp, StreamingResponse):
-                    async for _ in shadow_resp.body_iterator: pass
-                    logger.info("[Shadow] Upstream stream validation: PASSED")
-                elif isinstance(shadow_resp, JSONResponse):
-                    status = shadow_resp.status_code
-                    if 200 <= status < 300:
-                        logger.info(f"[Shadow] Upstream unary validation: PASSED (Status {status})")
-                    else:
-                        logger.warning(f"[Shadow] Upstream validation FAILED (Status {status})")
+                await proxy.generate(body, f"shadow-{request_id}")
            except Exception as e:
                logger.error(f"[Shadow] Validation Exception: {str(e)}")

        try:
            raw_body = await request.json()
            SERVER_STATE.request_queue.put(raw_body)
            response_data = SERVER_STATE.response_queue.get(timeout=10)
-            if isinstance(response_data, str):
-                response_json = json.loads(response_data)
-            else:
-                response_json = response_data
+            response_json = json.loads(response_data) if isinstance(response_data, str) else response_data
            status_code = response_json.pop("status_code", 200)
            return JSONResponse(content=response_json, status_code=status_code)
        except Exception as e:
            logger.critical(f"[Mock] DEADLOCK/ERROR: {e}")
            return JSONResponse(status_code=500, content={"code": "MockError", "message": f"Mock Server Error: {str(e)}"})

-    # === [Logic Branch]: Production Proxy Mode ===
    return await proxy.generate(body, request_id)

@app.api_route("/{path_name:path}", methods=["GET", "POST", "DELETE", "PUT"])
@@ -413,21 +354,12 @@ async def catch_all(path_name: str, request: Request):
    try:
        body = None
        if request.method in ["POST", "PUT"]:
-            try:
-                body = await request.json()
+            try: body = await request.json()
            except: pass
-        req_record = {
-            "path": f"/{path_name}",
-            "method": request.method,
-            "headers": dict(request.headers),
-            "body": body
-        }
+        req_record = {"path": f"/{path_name}", "method": request.method, "headers": dict(request.headers), "body": body}
        SERVER_STATE.request_queue.put(req_record)
        response_data = SERVER_STATE.response_queue.get(timeout=5)
-        if isinstance(response_data, str):
-            response_json = json.loads(response_data)
-        else:
-            response_json = response_data
+        response_json = json.loads(response_data) if isinstance(response_data, str) else response_data
        status_code = response_json.pop("status_code", 200)
        return JSONResponse(content=response_json, status_code=status_code)
    except Exception as e:
@@ -451,21 +383,16 @@ def __init__(self) -> None:

def create_mock_server(*args, **kwargs):
    mock_server = MockServer()
-    proc = multiprocessing.Process(
-        target=run_server_process,
-        args=(mock_server.requests, mock_server.responses, "0.0.0.0", 8089)
-    )
+    proc = multiprocessing.Process(target=run_server_process, args=(mock_server.requests, mock_server.responses, "0.0.0.0", 8089))
    proc.start()
    mock_server.proc = proc
    time.sleep(1.5)
    logger.info("Mock Server (Proxy Mode) started on port 8089")
-
    if args and hasattr(args[0], "addfinalizer"):
        def stop_server():
            if proc.is_alive():
                proc.terminate()
                proc.join()
-            logger.info("Mock Server stopped")
        args[0].addfinalizer(stop_server)
    return mock_server

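For reference, a sketch of how a test might drive this mock server, assuming mock_server.requests / mock_server.responses are the queues handed to the server process, a pytest request fixture for cleanup, and the requests HTTP library; the route path and payloads are illustrative:

import requests

def test_generation_against_mock(request):
    mock = create_mock_server(request)  # finalizer registered via request.addfinalizer

    # Queue the canned reply first; the endpoint pops "status_code" and returns the rest as JSON.
    mock.responses.put({
        "status_code": 200,
        "output": {"choices": [{"message": {"role": "assistant", "content": "mocked"},
                                "finish_reason": "stop"}]},
        "request_id": "mock-1",
    })

    resp = requests.post(
        "http://127.0.0.1:8089/api/v1/services/aigc/text-generation/generation",  # illustrative path
        json={"model": "deepseek-r1", "input": {"prompt": "hi"}, "parameters": {"result_format": "message"}},
    )
    assert resp.status_code == 200
    assert resp.json()["output"]["choices"][0]["message"]["content"] == "mocked"

    recorded = mock.requests.get(timeout=5)  # the request body the server captured
    assert recorded["model"] == "deepseek-r1"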