 # --- [System Configuration] ---
 
 logging.basicConfig(
-    level=logging.DEBUG,
+    level=logging.DEBUG,  # TODO: switch to INFO in production to cut log noise
     format="%(asctime)s.%(msecs)03d | %(levelname)s | %(process)d | %(message)s",
     datefmt="%H:%M:%S"
 )
 # Upstream Base URL
 SILICON_FLOW_BASE_URL = os.getenv("SILICON_FLOW_BASE_URL", "https://api.siliconflow.cn/v1")
 
-# MOCK/TEST ONLY
+# MOCK/TEST ONLY: This key is never used in the production generation path
 _MOCK_ENV_API_KEY = os.getenv("SILICON_FLOW_API_KEY")
 
 MODEL_MAPPING = {
     "default": "deepseek-ai/DeepSeek-V3"
 }
 
-# Headers that should NOT be forwarded to upstream to avoid conflicts
-EXCLUDED_UPSTREAM_HEADERS = {
-    "host", "content-length", "content-type", "connection",
-    "upgrade", "accept-encoding", "transfer-encoding",
-    "keep-alive", "proxy-authorization", "authorization"
-}
-
 # --- [Shared State] ---
 
 class ServerState:
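
A hypothetical illustration, since the mapping is consumed elsewhere in this file: a resolver over MODEL_MAPPING would typically fall back to the "default" entry for unknown aliases (the helper name resolve_model is not part of this diff).

    def resolve_model(requested: str) -> str:
        # Unknown aliases fall back to the "default" upstream model.
        return MODEL_MAPPING.get(requested, MODEL_MAPPING["default"])

    # resolve_model("no-such-model") == "deepseek-ai/DeepSeek-V3"
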
@@ -121,16 +114,11 @@ class GenerationRequest(BaseModel):
 # --- [DeepSeek Proxy Logic] ---
 
 class DeepSeekProxy:
-    def __init__(self, api_key: str, extra_headers: Optional[Dict[str, str]] = None):
-        """
-        :param api_key: The API key for the upstream service.
-        :param extra_headers: Custom headers to forward to the upstream (e.g. X-Request-ID).
-        """
+    def __init__(self, api_key: str):
+        # We instantiate a new client per request to ensure isolation of user credentials
         self.client = AsyncOpenAI(
             api_key=api_key,
             base_url=SILICON_FLOW_BASE_URL,
-            # ✅ Inject headers here so they are sent with every request made by this client
-            default_headers=extra_headers,
             timeout=httpx.Timeout(connect=10.0, read=600.0, write=600.0, pool=10.0)
         )
 
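
With extra_headers gone, every request now builds a fresh AsyncOpenAI client around the caller's key. If per-request construction ever shows up in profiles, the openai SDK also accepts a shared http_client, which keeps connection pooling while still isolating credentials per request. A sketch under that assumption (make_client is a hypothetical helper; SILICON_FLOW_BASE_URL is the constant defined earlier in this file):

    import httpx
    from openai import AsyncOpenAI

    SHARED_HTTP_CLIENT = httpx.AsyncClient(
        timeout=httpx.Timeout(connect=10.0, read=600.0, write=600.0, pool=10.0)
    )

    def make_client(api_key: str) -> AsyncOpenAI:
        # Credentials stay per-request; TCP connections are pooled and reused.
        return AsyncOpenAI(
            api_key=api_key,
            base_url=SILICON_FLOW_BASE_URL,
            http_client=SHARED_HTTP_CLIENT,
        )
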
@@ -156,12 +144,14 @@ def _convert_input_to_messages(self, input_data: InputData) -> List[Dict[str, st
     async def generate(self, req_data: GenerationRequest, initial_request_id: str):
         params = req_data.parameters
 
+        # Validation: Tools require message format
         if params.tools and params.result_format != "message":
             return JSONResponse(
                 status_code=400,
                 content={"code": "InvalidParameter", "message": "When 'tools' are provided, 'result_format' must be 'message'."}
             )
 
+        # Validation: R1 + Tools constraint
         is_r1 = "deepseek-r1" in req_data.model or params.enable_thinking
         if is_r1 and params.tool_choice and isinstance(params.tool_choice, dict):
             return JSONResponse(
@@ -241,6 +231,7 @@ async def _stream_generator(self, stream, request_id: str) -> AsyncGenerator[str
             delta_content = delta.content if delta and delta.content else ""
             delta_reasoning = (getattr(delta, "reasoning_content", "") or "") if delta else ""
 
+            # ✅ Accumulate the full content
             if delta_content:
                 full_text += delta_content
             if delta_reasoning:
@@ -253,6 +244,7 @@ async def _stream_generator(self, stream, request_id: str) -> AsyncGenerator[str
             if chunk.choices and chunk.choices[0].finish_reason:
                 finish_reason = chunk.choices[0].finish_reason
 
+            # ✅ Key: the stop chunk carries the full accumulated content, so an empty final chunk cannot make the aggregated result empty
             if finish_reason != "null":
                 content_to_send = full_text
                 reasoning_to_send = full_reasoning
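
The effect of the change above is that the chunk carrying finish_reason repeats the full accumulated text, so a consumer that keeps only the final event still receives the whole answer. A client-side sketch, assuming a DashScope-style payload shape (output.choices[0].message.content), which is not shown in this hunk:

    import json

    def final_text(sse_data_lines):
        # Keep only the chunk that carries a finish_reason; per the change
        # above it already contains the entire accumulated content.
        result = None
        for line in sse_data_lines:
            if not line.startswith("data:"):
                continue
            payload = json.loads(line[len("data:"):].strip())
            choice = payload["output"]["choices"][0]
            if choice.get("finish_reason") not in (None, "", "null"):
                result = choice["message"]["content"]
        return result
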
@@ -313,17 +305,6 @@ def _format_unary_response(self, completion, request_id: str):
 
 # --- [FastAPI App & Lifecycle] ---
 
-def get_forwardable_headers(request: Request) -> Dict[str, str]:
-    """
-    Extracts headers from the request that are safe to forward to the upstream.
-    Filters out hop-by-hop headers, content-related headers (which are rebuilt),
-    and auth (which is handled separately).
-    """
-    return {
-        k: v for k, v in request.headers.items()
-        if k.lower() not in EXCLUDED_UPSTREAM_HEADERS
-    }
-
 @asynccontextmanager
 async def lifespan(app: FastAPI):
     stop_event = threading.Event()
@@ -357,6 +338,7 @@ async def request_tracker(request: Request, call_next):
     finally:
         SERVER_STATE.decrement_request()
         duration = (time.time() - start_time) * 1000
+        # Only log slow requests; healthy fast requests stay quiet to cut noise
         if duration > 1000:
             logger.warning(f"{request.method} {request.url.path} - SLOW {duration:.2f}ms")
 
@@ -379,36 +361,57 @@ async def generation(
 ):
     request_id = request.headers.get("x-request-id", str(uuid.uuid4()))
 
+    # --- [Logic Switch: Mock vs Production] ---
+
     if SERVER_STATE.is_mock_mode:
+        # In MOCK mode, we can optionally use the shadow/env key
         api_key = _MOCK_ENV_API_KEY or "mock-key"
     else:
+        # --- [PRODUCTION MODE] ---
+        # Strictly require the Authorization header; do not fall back to env vars.
         if not authorization:
+            logger.warning(f"Rejected request {request_id}: Missing Authorization Header")
             raise HTTPException(status_code=401, detail="Missing Authorization header")
-        if not authorization.startswith("Bearer "):
-            raise HTTPException(status_code=401, detail="Invalid Authorization header format")
-        api_key = authorization.replace("Bearer ", "")
 
-    # ✅ 1. Get filtered headers
-    upstream_headers = get_forwardable_headers(request)
+        if not authorization.startswith("Bearer "):
+            logger.warning(f"Rejected request {request_id}: Invalid Authorization Format")
+            raise HTTPException(status_code=401, detail="Invalid Authorization header format. Expected 'Bearer <token>'")
 
-    logger.debug(f"Request {request_id} Forwarding Headers: {upstream_headers.keys()}")
+        # Transparently forward the user's key
+        api_key = authorization.replace("Bearer ", "")
 
-    # ✅ 2. Initialize Proxy with headers
-    proxy = DeepSeekProxy(api_key=api_key, extra_headers=upstream_headers)
+    logger.debug(f"Using API key: {api_key[:8]}... for request {request_id}")
+    # Instantiate the proxy with the resolved key (user's or mock)
+    proxy = DeepSeekProxy(api_key=api_key)
 
+    # Parse Body if not injected
     if not body:
         try:
             raw_json = await request.json()
             body = GenerationRequest(**raw_json)
         except Exception as e:
             raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}")
 
+    # --- [Auto-enable stream mode based on header] ---
     accept_header = request.headers.get("accept", "")
     if "text/event-stream" in accept_header and body.parameters:
+        logger.info("SSE client detected, forcing incremental_output=True")
         body.parameters.incremental_output = True
 
+
+    # --- [Mock Handling] ---
     if SERVER_STATE.is_mock_mode:
-        # ... (Mock logic omitted for brevity, logic remains same)
+        if body:
+            # Shadow Traffic Logic (Optional validation against upstream)
+            try:
+                # We only perform shadow traffic if a key is actually available
+                if _MOCK_ENV_API_KEY:
+                    shadow_proxy = DeepSeekProxy(api_key=_MOCK_ENV_API_KEY)
+                    # Fire and forget (or await if validation is strict)
+                    # await shadow_proxy.generate(body, f"shadow-{request_id}")
+            except Exception:
+                pass  # Swallow shadow errors
+
         try:
             raw_body = await request.json()
             SERVER_STATE.request_queue.put(raw_body)
@@ -419,6 +422,8 @@ async def generation(
         except Exception as e:
             return JSONResponse(status_code=500, content={"code": "MockError", "message": str(e)})
 
+    # --- [Production Handling] ---
+    # Forward request to upstream
     return await proxy.generate(body, request_id)
 
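
A minimal client sketch for the endpoint above (not part of this diff). The host, port, and route path are assumptions (the route decorator for generation sits outside this hunk), and the body shape mirrors the DashScope-style GenerationRequest, so the field names are assumptions too. The headers exercise what the handler checks: a Bearer token is mandatory in production mode, Accept: text/event-stream flips incremental_output on, and x-request-id becomes the request id.

    import asyncio
    import httpx

    GENERATION_PATH = "/<generation-route-not-shown-in-this-hunk>"

    async def call_proxy() -> None:
        headers = {
            "Authorization": "Bearer sk-your-upstream-key",
            "Accept": "text/event-stream",   # handler sets incremental_output=True
            "x-request-id": "demo-0001",
        }
        payload = {
            "model": "deepseek-ai/DeepSeek-V3",
            "input": {"messages": [{"role": "user", "content": "ping"}]},
            "parameters": {"result_format": "message"},
        }
        async with httpx.AsyncClient(base_url="http://localhost:8000", timeout=600) as client:
            resp = await client.post(GENERATION_PATH, headers=headers, json=payload)
            resp.raise_for_status()

    asyncio.run(call_proxy())
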
 @app.post("/siliconflow/models/{model_path:path}")
@@ -427,42 +432,38 @@ async def dynamic_path_generation(
     request: Request,
     authorization: Optional[str] = Header(None)
 ):
+    # 1. Strict Auth (No Mock Support)
     if not authorization or not authorization.startswith("Bearer "):
         raise HTTPException(status_code=401, detail="Invalid Authorization header")
 
     request_id = request.headers.get("x-request-id", str(uuid.uuid4()))
+    proxy = DeepSeekProxy(api_key=authorization.replace("Bearer ", ""))
 
-    # ✅ Header Forwarding Logic
-    upstream_headers = get_forwardable_headers(request)
-    proxy = DeepSeekProxy(api_key=authorization.replace("Bearer ", ""), extra_headers=upstream_headers)
-
+    # 2. Parse, Inject Model, and Validate
     try:
         payload = await request.json()
-        payload["model"] = model_path
+        payload["model"] = model_path  # Force set model from URL
         body = GenerationRequest(**payload)
     except Exception as e:
         raise HTTPException(status_code=400, detail=f"Invalid Request: {e}")
 
+    # 3. Handle SSE
     if "text/event-stream" in request.headers.get("accept", "") and body.parameters:
         body.parameters.incremental_output = True
 
+    # 4. Generate
     return await proxy.generate(body, request_id)
 
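
The strict auth on this route (no mock fallback) is straightforward to pin down in a test. A sketch, assuming the FastAPI app object is importable as proxy_server.app (the module name is not shown in this diff):

    from fastapi.testclient import TestClient
    from proxy_server import app  # assumed import path

    client = TestClient(app)
    PATH = "/siliconflow/models/deepseek-ai/DeepSeek-V3"
    BODY = {"input": {"messages": [{"role": "user", "content": "hi"}]}}

    def test_dynamic_route_requires_bearer_token():
        # Missing header and wrong scheme both hit the 401 branch above.
        assert client.post(PATH, json=BODY).status_code == 401
        assert client.post(
            PATH, json=BODY, headers={"Authorization": "Token abc"}
        ).status_code == 401
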
 @app.api_route("/{path_name:path}", methods=["GET", "POST", "DELETE", "PUT"])
 async def catch_all(path_name: str, request: Request):
+    # Catch-all only valid in Mock Mode
     if SERVER_STATE.is_mock_mode:
         try:
             body = None
             if request.method in ["POST", "PUT"]:
                 try: body = await request.json()
                 except: pass
-            # We forward headers in the mock record too, in case tests need to verify them
-            req_record = {
-                "path": f"/{path_name}",
-                "method": request.method,
-                "headers": dict(request.headers),
-                "body": body
-            }
+            req_record = {"path": f"/{path_name}", "method": request.method, "headers": dict(request.headers), "body": body}
             SERVER_STATE.request_queue.put(req_record)
             response_data = SERVER_STATE.response_queue.get(timeout=5)
             response_json = json.loads(response_data) if isinstance(response_data, str) else response_data
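
In mock mode the catch-all above blocks on response_queue and records every request on request_queue, so a test harness has to prime the response before firing the request and read the record afterwards. A sketch, again assuming the module is importable as proxy_server (name not shown in this diff):

    import json
    from proxy_server import SERVER_STATE  # assumed import path

    def prime_mock_response(payload: dict) -> None:
        # Must be queued before the request arrives; catch_all waits up to 5 seconds.
        SERVER_STATE.response_queue.put(json.dumps(payload))

    def pop_recorded_request() -> dict:
        # Each record has the shape {"path", "method", "headers", "body"}.
        return SERVER_STATE.request_queue.get(timeout=5)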