Skip to content

Commit eea98de

Browse files
committed
Update mock_server.py
1 parent 97fc65c commit eea98de

File tree

1 file changed

+23
-15
lines changed

1 file changed

+23
-15
lines changed

tests/mock_server.py

Lines changed: 23 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -114,12 +114,13 @@ class GenerationRequest(BaseModel):
114114
# --- [DeepSeek Proxy Logic] ---
115115

116116
class DeepSeekProxy:
117-
def __init__(self, api_key: str):
117+
def __init__(self, api_key: str, extra_headers: Optional[Dict[str, str]] = None):
118118
# We instantiate a new client per request to ensure isolation of user credentials
119119
self.client = AsyncOpenAI(
120120
api_key=api_key,
121121
base_url=SILICON_FLOW_BASE_URL,
122-
timeout=httpx.Timeout(connect=10.0, read=600.0, write=600.0, pool=10.0)
122+
timeout=httpx.Timeout(connect=10.0, read=600.0, write=600.0, pool=10.0),
123+
default_headers=extra_headers # 透传 Header
123124
)
124125

125126
def _get_mapped_model(self, request_model: str) -> str:
@@ -363,16 +364,11 @@ async def generation(
363364

364365
# --- [Logic Switch: Mock vs Production] ---
365366

367+
api_key = "dummy-key"
366368
if SERVER_STATE.is_mock_mode:
367369
# In MOCK mode, we can optionally use the shadow/env key
368370
api_key = _MOCK_ENV_API_KEY or "mock-key"
369-
else:
370-
# --- [PRODUCTION MODE] ---
371-
# STRICTLY Require Header. Do not fallback to env vars.
372-
if not authorization:
373-
logger.warning(f"Rejected request {request_id}: Missing Authorization Header")
374-
raise HTTPException(status_code=401, detail="Missing Authorization header")
375-
371+
elif authorization:
376372
if not authorization.startswith("Bearer "):
377373
logger.warning(f"Rejected request {request_id}: Invalid Authorization Format")
378374
raise HTTPException(status_code=401, detail="Invalid Authorization header format. Expected 'Bearer <token>'")
@@ -381,8 +377,13 @@ async def generation(
381377
api_key = authorization.replace("Bearer ", "")
382378

383379
logger.debug(f"using API Key: {api_key[:8]}... for request {request_id}")
384-
# Instantiate Proxy with the specific key (User's or Mock's)
385-
proxy = DeepSeekProxy(api_key=api_key)
380+
381+
# 过滤掉不安全的或由 httpx 库自动管理的 Headers
382+
unsafe_headers = {"host", "content-length", "content-type", "authorization", "connection", "upgrade", "accept-encoding", "transfer-encoding"}
383+
forward_headers = {k: v for k, v in request.headers.items() if k.lower() not in unsafe_headers}
384+
385+
# Instantiate Proxy with the specific key AND headers
386+
proxy = DeepSeekProxy(api_key=api_key, extra_headers=forward_headers)
386387

387388
# Parse Body if not injected
388389
if not body:
@@ -432,12 +433,19 @@ async def dynamic_path_generation(
432433
request: Request,
433434
authorization: Optional[str] = Header(None)
434435
):
435-
# 1. Strict Auth (No Mock Support)
436-
if not authorization or not authorization.startswith("Bearer "):
437-
raise HTTPException(status_code=401, detail="Invalid Authorization header")
436+
api_key = "dummy-key"
437+
if authorization:
438+
if not authorization.startswith("Bearer "):
439+
logger.warning("Rejected request: Invalid Authorization Format")
440+
raise HTTPException(status_code=401, detail="Invalid Authorization header format. Expected 'Bearer <token>'")
441+
api_key = authorization.replace("Bearer ", "")
438442

439443
request_id = request.headers.get("x-request-id", str(uuid.uuid4()))
440-
proxy = DeepSeekProxy(api_key=authorization.replace("Bearer ", ""))
444+
445+
unsafe_headers = {"host", "content-length", "content-type", "authorization", "connection", "upgrade", "accept-encoding", "transfer-encoding"}
446+
forward_headers = {k: v for k, v in request.headers.items() if k.lower() not in unsafe_headers}
447+
448+
proxy = DeepSeekProxy(api_key=api_key, extra_headers=forward_headers)
441449

442450
# 2. Parse, Inject Model, and Validate
443451
try:

0 commit comments

Comments
 (0)