Skip to content

Commit f4952b7

Browse files
committed
Update mock_server.py
1 parent 3e3f0e9 commit f4952b7

File tree

1 file changed

+48
-49
lines changed

1 file changed

+48
-49
lines changed

tests/mock_server.py

Lines changed: 48 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
# --- [System Configuration] ---
2020

2121
logging.basicConfig(
22-
level=logging.DEBUG, # Switched to INFO for production noise reduction
22+
level=logging.DEBUG,
2323
format="%(asctime)s.%(msecs)03d | %(levelname)s | %(process)d | %(message)s",
2424
datefmt="%H:%M:%S"
2525
)
@@ -28,7 +28,7 @@
2828
# Upstream Base URL
2929
SILICON_FLOW_BASE_URL = os.getenv("SILICON_FLOW_BASE_URL", "https://api.siliconflow.cn/v1")
3030

31-
# MOCK/TEST ONLY: This key is never used in the production generation path
31+
# MOCK/TEST ONLY
3232
_MOCK_ENV_API_KEY = os.getenv("SILICON_FLOW_API_KEY")
3333

3434
MODEL_MAPPING = {
@@ -39,6 +39,13 @@
3939
"default": "deepseek-ai/DeepSeek-V3"
4040
}
4141

42+
# Request headers that must NOT be forwarded to the upstream service.
#   - Hop-by-hop / connection-scoped headers: only meaningful between the
#     client and this server, never end-to-end.
#   - Content headers: rebuilt for the outgoing request body.
#   - Credentials: the upstream API key is handled separately.
# All names are lowercase; compare with `header_name.lower()`.
EXCLUDED_UPSTREAM_HEADERS = {
    # Hop-by-hop / connection management
    "connection",
    "keep-alive",
    "proxy-authenticate",
    "proxy-authorization",
    "proxy-connection",
    "te",
    "trailer",
    "transfer-encoding",
    "upgrade",
    # Routing and body framing (rebuilt for the upstream request)
    "host",
    "content-length",
    "content-type",
    "accept-encoding",
    # Auth (handled separately)
    "authorization",
}
48+
4249
# --- [Shared State] ---
4350

4451
class ServerState:
@@ -114,11 +121,16 @@ class GenerationRequest(BaseModel):
114121
# --- [DeepSeek Proxy Logic] ---
115122

116123
class DeepSeekProxy:
117-
def __init__(self, api_key: str):
118-
# We instantiate a new client per request to ensure isolation of user credentials
124+
def __init__(self, api_key: str, extra_headers: Optional[Dict[str, str]] = None):
    """
    Build the AsyncOpenAI client this proxy uses to talk to the upstream.

    :param api_key: The API key for the upstream service.
    :param extra_headers: Custom headers to forward to the upstream
        (e.g. X-Request-ID). They are installed as the client's default
        headers, so every request made through this client carries them.
    """
    # Short connect/pool limits, long read/write windows.
    upstream_timeout = httpx.Timeout(connect=10.0, read=600.0, write=600.0, pool=10.0)

    client_options = {
        "api_key": api_key,
        "base_url": SILICON_FLOW_BASE_URL,
        "default_headers": extra_headers,
        "timeout": upstream_timeout,
    }
    self.client = AsyncOpenAI(**client_options)
124136

@@ -144,14 +156,12 @@ def _convert_input_to_messages(self, input_data: InputData) -> List[Dict[str, st
144156
async def generate(self, req_data: GenerationRequest, initial_request_id: str):
145157
params = req_data.parameters
146158

147-
# Validation: Tools require message format
148159
if params.tools and params.result_format != "message":
149160
return JSONResponse(
150161
status_code=400,
151162
content={"code": "InvalidParameter", "message": "When 'tools' are provided, 'result_format' must be 'message'."}
152163
)
153164

154-
# Validation: R1 + Tools constraint
155165
is_r1 = "deepseek-r1" in req_data.model or params.enable_thinking
156166
if is_r1 and params.tool_choice and isinstance(params.tool_choice, dict):
157167
return JSONResponse(
@@ -231,7 +241,6 @@ async def _stream_generator(self, stream, request_id: str) -> AsyncGenerator[str
231241
delta_content = delta.content if delta and delta.content else ""
232242
delta_reasoning = (getattr(delta, "reasoning_content", "") or "") if delta else ""
233243

234-
# ✅ 累积完整内容
235244
if delta_content:
236245
full_text += delta_content
237246
if delta_reasoning:
@@ -244,7 +253,6 @@ async def _stream_generator(self, stream, request_id: str) -> AsyncGenerator[str
244253
if chunk.choices and chunk.choices[0].finish_reason:
245254
finish_reason = chunk.choices[0].finish_reason
246255

247-
# ✅ 关键:stop 包输出“完整累积内容”,避免最后一包是空导致聚合为空
248256
if finish_reason != "null":
249257
content_to_send = full_text
250258
reasoning_to_send = full_reasoning
@@ -305,6 +313,17 @@ def _format_unary_response(self, completion, request_id: str):
305313

306314
# --- [FastAPI App & Lifecycle] ---
307315

316+
def get_forwardable_headers(request: Request) -> Dict[str, str]:
    """
    Extract the headers from the request that are safe to forward upstream.

    A header is dropped when its lowercased name is either:
      * listed in EXCLUDED_UPSTREAM_HEADERS, or
      * nominated as connection-specific by the incoming ``Connection``
        header (e.g. ``Connection: close, X-Internal`` also drops
        ``X-Internal``). Such headers apply only to the client-to-server
        hop and must not travel further.

    :param request: The incoming request whose headers are inspected.
    :return: Mapping of header name to value for every remaining header.
    """
    incoming = list(request.headers.items())

    # Collect every option named by any Connection header on the request.
    connection_options = {
        option.strip().lower()
        for name, value in incoming
        if name.lower() == "connection"
        for option in value.split(",")
        if option.strip()
    }
    blocked = connection_options.union(EXCLUDED_UPSTREAM_HEADERS)

    return {
        name: value
        for name, value in incoming
        if name.lower() not in blocked
    }
326+
308327
@asynccontextmanager
309328
async def lifespan(app: FastAPI):
310329
stop_event = threading.Event()
@@ -338,7 +357,6 @@ async def request_tracker(request: Request, call_next):
338357
finally:
339358
SERVER_STATE.decrement_request()
340359
duration = (time.time() - start_time) * 1000
341-
# Reduced log noise for healthy requests, kept for errors or slow ones if needed
342360
if duration > 1000:
343361
logger.warning(f"{request.method} {request.url.path} - SLOW {duration:.2f}ms")
344362

@@ -361,57 +379,36 @@ async def generation(
361379
):
362380
request_id = request.headers.get("x-request-id", str(uuid.uuid4()))
363381

364-
# --- [Logic Switch: Mock vs Production] ---
365-
366382
if SERVER_STATE.is_mock_mode:
367-
# In MOCK mode, we can optionally use the shadow/env key
368383
api_key = _MOCK_ENV_API_KEY or "mock-key"
369384
else:
370-
# --- [PRODUCTION MODE] ---
371-
# STRICTLY Require Header. Do not fallback to env vars.
372385
if not authorization:
373-
logger.warning(f"Rejected request {request_id}: Missing Authorization Header")
374386
raise HTTPException(status_code=401, detail="Missing Authorization header")
375-
376387
if not authorization.startswith("Bearer "):
377-
logger.warning(f"Rejected request {request_id}: Invalid Authorization Format")
378-
raise HTTPException(status_code=401, detail="Invalid Authorization header format. Expected 'Bearer <token>'")
379-
380-
# Transparently forward the user's key
388+
raise HTTPException(status_code=401, detail="Invalid Authorization header format")
381389
api_key = authorization.replace("Bearer ", "")
382390

383-
logger.debug(f"using API Key: {api_key[:8]}... for request {request_id}")
384-
# Instantiate Proxy with the specific key (User's or Mock's)
385-
proxy = DeepSeekProxy(api_key=api_key)
391+
# ✅ 1. Get filtered headers
392+
upstream_headers = get_forwardable_headers(request)
393+
394+
logger.debug(f"Request {request_id} Forwarding Headers: {upstream_headers.keys()}")
395+
396+
# ✅ 2. Initialize Proxy with headers
397+
proxy = DeepSeekProxy(api_key=api_key, extra_headers=upstream_headers)
386398

387-
# Parse Body if not injected
388399
if not body:
389400
try:
390401
raw_json = await request.json()
391402
body = GenerationRequest(**raw_json)
392403
except Exception as e:
393404
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}")
394405

395-
# --- [Auto-enable stream mode based on header] ---
396406
accept_header = request.headers.get("accept", "")
397407
if "text/event-stream" in accept_header and body.parameters:
398-
logger.info("SSE client detected, forcing incremental_output=True")
399408
body.parameters.incremental_output = True
400409

401-
402-
# --- [Mock Handling] ---
403410
if SERVER_STATE.is_mock_mode:
404-
if body:
405-
# Shadow Traffic Logic (Optional validation against upstream)
406-
try:
407-
# We only perform shadow traffic if a key is actually available
408-
if _MOCK_ENV_API_KEY:
409-
shadow_proxy = DeepSeekProxy(api_key=_MOCK_ENV_API_KEY)
410-
# Fire and forget (or await if validation is strict)
411-
# await shadow_proxy.generate(body, f"shadow-{request_id}")
412-
except Exception:
413-
pass # Swallow shadow errors
414-
411+
# Mock mode: enqueue the raw JSON request body on SERVER_STATE.request_queue.
415412
try:
416413
raw_body = await request.json()
417414
SERVER_STATE.request_queue.put(raw_body)
@@ -422,8 +419,6 @@ async def generation(
422419
except Exception as e:
423420
return JSONResponse(status_code=500, content={"code": "MockError", "message": str(e)})
424421

425-
# --- [Production Handling] ---
426-
# Forward request to upstream
427422
return await proxy.generate(body, request_id)
428423

429424
@app.post("/siliconflow/models/{model_path:path}")
@@ -432,38 +427,42 @@ async def dynamic_path_generation(
432427
request: Request,
433428
authorization: Optional[str] = Header(None)
434429
):
435-
# 1. Strict Auth (No Mock Support)
436430
if not authorization or not authorization.startswith("Bearer "):
437431
raise HTTPException(status_code=401, detail="Invalid Authorization header")
438432

439433
request_id = request.headers.get("x-request-id", str(uuid.uuid4()))
440-
proxy = DeepSeekProxy(api_key=authorization.replace("Bearer ", ""))
441434

442-
# 2. Parse, Inject Model, and Validate
435+
# ✅ Header Forwarding Logic
436+
upstream_headers = get_forwardable_headers(request)
437+
proxy = DeepSeekProxy(api_key=authorization.replace("Bearer ", ""), extra_headers=upstream_headers)
438+
443439
try:
444440
payload = await request.json()
445-
payload["model"] = model_path # Force set model from URL
441+
payload["model"] = model_path
446442
body = GenerationRequest(**payload)
447443
except Exception as e:
448444
raise HTTPException(status_code=400, detail=f"Invalid Request: {e}")
449445

450-
# 3. Handle SSE
451446
if "text/event-stream" in request.headers.get("accept", "") and body.parameters:
452447
body.parameters.incremental_output = True
453448

454-
# 4. Generate
455449
return await proxy.generate(body, request_id)
456450

457451
@app.api_route("/{path_name:path}", methods=["GET", "POST", "DELETE", "PUT"])
458452
async def catch_all(path_name: str, request: Request):
459-
# Catch-all only valid in Mock Mode
460453
if SERVER_STATE.is_mock_mode:
461454
try:
462455
body = None
463456
if request.method in ["POST", "PUT"]:
464457
try: body = await request.json()
465458
except: pass
466-
req_record = {"path": f"/{path_name}", "method": request.method, "headers": dict(request.headers), "body": body}
459+
# Include the request headers in the mock record too, in case tests need to verify them
460+
req_record = {
461+
"path": f"/{path_name}",
462+
"method": request.method,
463+
"headers": dict(request.headers),
464+
"body": body
465+
}
467466
SERVER_STATE.request_queue.put(req_record)
468467
response_data = SERVER_STATE.response_queue.get(timeout=5)
469468
response_json = json.loads(response_data) if isinstance(response_data, str) else response_data

0 commit comments

Comments
 (0)