Skip to content

Commit 97fc65c

Browse files
committed
Revert "Update mock_server.py"
This reverts commit f4952b7.
1 parent f4952b7 commit 97fc65c

File tree

1 file changed

+49
-48
lines changed

1 file changed

+49
-48
lines changed

tests/mock_server.py

Lines changed: 49 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
# --- [System Configuration] ---
2020

2121
logging.basicConfig(
22-
level=logging.DEBUG,
22+
level=logging.DEBUG, # DEBUG level is active here; switch to INFO to reduce log noise in production
2323
format="%(asctime)s.%(msecs)03d | %(levelname)s | %(process)d | %(message)s",
2424
datefmt="%H:%M:%S"
2525
)
@@ -28,7 +28,7 @@
2828
# Upstream Base URL
2929
SILICON_FLOW_BASE_URL = os.getenv("SILICON_FLOW_BASE_URL", "https://api.siliconflow.cn/v1")
3030

31-
# MOCK/TEST ONLY
31+
# MOCK/TEST ONLY: This key is never used in the production generation path
3232
_MOCK_ENV_API_KEY = os.getenv("SILICON_FLOW_API_KEY")
3333

3434
MODEL_MAPPING = {
@@ -39,13 +39,6 @@
3939
"default": "deepseek-ai/DeepSeek-V3"
4040
}
4141

42-
# Headers that should NOT be forwarded to upstream to avoid conflicts
43-
EXCLUDED_UPSTREAM_HEADERS = {
44-
"host", "content-length", "content-type", "connection",
45-
"upgrade", "accept-encoding", "transfer-encoding",
46-
"keep-alive", "proxy-authorization", "authorization"
47-
}
48-
4942
# --- [Shared State] ---
5043

5144
class ServerState:
@@ -121,16 +114,11 @@ class GenerationRequest(BaseModel):
121114
# --- [DeepSeek Proxy Logic] ---
122115

123116
class DeepSeekProxy:
124-
def __init__(self, api_key: str, extra_headers: Optional[Dict[str, str]] = None):
125-
"""
126-
:param api_key: The API key for the upstream service.
127-
:param extra_headers: Custom headers to forward to the upstream (e.g. X-Request-ID).
128-
"""
117+
def __init__(self, api_key: str):
118+
# We instantiate a new client per request to ensure isolation of user credentials
129119
self.client = AsyncOpenAI(
130120
api_key=api_key,
131121
base_url=SILICON_FLOW_BASE_URL,
132-
# ✅ Inject headers here so they are sent with every request made by this client
133-
default_headers=extra_headers,
134122
timeout=httpx.Timeout(connect=10.0, read=600.0, write=600.0, pool=10.0)
135123
)
136124

@@ -156,12 +144,14 @@ def _convert_input_to_messages(self, input_data: InputData) -> List[Dict[str, st
156144
async def generate(self, req_data: GenerationRequest, initial_request_id: str):
157145
params = req_data.parameters
158146

147+
# Validation: Tools require message format
159148
if params.tools and params.result_format != "message":
160149
return JSONResponse(
161150
status_code=400,
162151
content={"code": "InvalidParameter", "message": "When 'tools' are provided, 'result_format' must be 'message'."}
163152
)
164153

154+
# Validation: R1 + Tools constraint
165155
is_r1 = "deepseek-r1" in req_data.model or params.enable_thinking
166156
if is_r1 and params.tool_choice and isinstance(params.tool_choice, dict):
167157
return JSONResponse(
@@ -241,6 +231,7 @@ async def _stream_generator(self, stream, request_id: str) -> AsyncGenerator[str
241231
delta_content = delta.content if delta and delta.content else ""
242232
delta_reasoning = (getattr(delta, "reasoning_content", "") or "") if delta else ""
243233

234+
# ✅ Accumulate the full content
244235
if delta_content:
245236
full_text += delta_content
246237
if delta_reasoning:
@@ -253,6 +244,7 @@ async def _stream_generator(self, stream, request_id: str) -> AsyncGenerator[str
253244
if chunk.choices and chunk.choices[0].finish_reason:
254245
finish_reason = chunk.choices[0].finish_reason
255246

247+
# ✅ Key point: the stop packet emits the full accumulated content, so an empty final packet cannot leave the aggregated result empty
256248
if finish_reason != "null":
257249
content_to_send = full_text
258250
reasoning_to_send = full_reasoning
@@ -313,17 +305,6 @@ def _format_unary_response(self, completion, request_id: str):
313305

314306
# --- [FastAPI App & Lifecycle] ---
315307

316-
def get_forwardable_headers(request: Request) -> Dict[str, str]:
317-
"""
318-
Extracts headers from the request that are safe to forward to the upstream.
319-
Filters out hop-by-hop headers, content-related headers (which are rebuilt),
320-
and auth (which is handled separately).
321-
"""
322-
return {
323-
k: v for k, v in request.headers.items()
324-
if k.lower() not in EXCLUDED_UPSTREAM_HEADERS
325-
}
326-
327308
@asynccontextmanager
328309
async def lifespan(app: FastAPI):
329310
stop_event = threading.Event()
@@ -357,6 +338,7 @@ async def request_tracker(request: Request, call_next):
357338
finally:
358339
SERVER_STATE.decrement_request()
359340
duration = (time.time() - start_time) * 1000
341+
# Reduce log noise: only requests slower than 1000 ms are logged here, as warnings
360342
if duration > 1000:
361343
logger.warning(f"{request.method} {request.url.path} - SLOW {duration:.2f}ms")
362344

@@ -379,36 +361,57 @@ async def generation(
379361
):
380362
request_id = request.headers.get("x-request-id", str(uuid.uuid4()))
381363

364+
# --- [Logic Switch: Mock vs Production] ---
365+
382366
if SERVER_STATE.is_mock_mode:
367+
# In MOCK mode, we can optionally use the shadow/env key
383368
api_key = _MOCK_ENV_API_KEY or "mock-key"
384369
else:
370+
# --- [PRODUCTION MODE] ---
371+
# STRICTLY Require Header. Do not fallback to env vars.
385372
if not authorization:
373+
logger.warning(f"Rejected request {request_id}: Missing Authorization Header")
386374
raise HTTPException(status_code=401, detail="Missing Authorization header")
387-
if not authorization.startswith("Bearer "):
388-
raise HTTPException(status_code=401, detail="Invalid Authorization header format")
389-
api_key = authorization.replace("Bearer ", "")
390375

391-
# ✅ 1. Get filtered headers
392-
upstream_headers = get_forwardable_headers(request)
376+
if not authorization.startswith("Bearer "):
377+
logger.warning(f"Rejected request {request_id}: Invalid Authorization Format")
378+
raise HTTPException(status_code=401, detail="Invalid Authorization header format. Expected 'Bearer <token>'")
393379

394-
logger.debug(f"Request {request_id} Forwarding Headers: {upstream_headers.keys()}")
380+
# Transparently forward the user's key
381+
api_key = authorization.replace("Bearer ", "")
395382

396-
# ✅ 2. Initialize Proxy with headers
397-
proxy = DeepSeekProxy(api_key=api_key, extra_headers=upstream_headers)
383+
logger.debug(f"using API Key: {api_key[:8]}... for request {request_id}")
384+
# Instantiate Proxy with the specific key (User's or Mock's)
385+
proxy = DeepSeekProxy(api_key=api_key)
398386

387+
# Parse Body if not injected
399388
if not body:
400389
try:
401390
raw_json = await request.json()
402391
body = GenerationRequest(**raw_json)
403392
except Exception as e:
404393
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}")
405394

395+
# --- [Auto-enable stream mode based on header] ---
406396
accept_header = request.headers.get("accept", "")
407397
if "text/event-stream" in accept_header and body.parameters:
398+
logger.info("SSE client detected, forcing incremental_output=True")
408399
body.parameters.incremental_output = True
409400

401+
402+
# --- [Mock Handling] ---
410403
if SERVER_STATE.is_mock_mode:
411-
# ... (Mock logic omitted for brevity, logic remains same)
404+
if body:
405+
# Shadow Traffic Logic (Optional validation against upstream)
406+
try:
407+
# We only perform shadow traffic if a key is actually available
408+
if _MOCK_ENV_API_KEY:
409+
shadow_proxy = DeepSeekProxy(api_key=_MOCK_ENV_API_KEY)
410+
# Fire and forget (or await if validation is strict)
411+
# await shadow_proxy.generate(body, f"shadow-{request_id}")
412+
except Exception:
413+
pass # Swallow shadow errors
414+
412415
try:
413416
raw_body = await request.json()
414417
SERVER_STATE.request_queue.put(raw_body)
@@ -419,6 +422,8 @@ async def generation(
419422
except Exception as e:
420423
return JSONResponse(status_code=500, content={"code": "MockError", "message": str(e)})
421424

425+
# --- [Production Handling] ---
426+
# Forward request to upstream
422427
return await proxy.generate(body, request_id)
423428

424429
@app.post("/siliconflow/models/{model_path:path}")
@@ -427,42 +432,38 @@ async def dynamic_path_generation(
427432
request: Request,
428433
authorization: Optional[str] = Header(None)
429434
):
435+
# 1. Strict Auth (No Mock Support)
430436
if not authorization or not authorization.startswith("Bearer "):
431437
raise HTTPException(status_code=401, detail="Invalid Authorization header")
432438

433439
request_id = request.headers.get("x-request-id", str(uuid.uuid4()))
440+
proxy = DeepSeekProxy(api_key=authorization.replace("Bearer ", ""))
434441

435-
# ✅ Header Forwarding Logic
436-
upstream_headers = get_forwardable_headers(request)
437-
proxy = DeepSeekProxy(api_key=authorization.replace("Bearer ", ""), extra_headers=upstream_headers)
438-
442+
# 2. Parse, Inject Model, and Validate
439443
try:
440444
payload = await request.json()
441-
payload["model"] = model_path
445+
payload["model"] = model_path # Force set model from URL
442446
body = GenerationRequest(**payload)
443447
except Exception as e:
444448
raise HTTPException(status_code=400, detail=f"Invalid Request: {e}")
445449

450+
# 3. Handle SSE
446451
if "text/event-stream" in request.headers.get("accept", "") and body.parameters:
447452
body.parameters.incremental_output = True
448453

454+
# 4. Generate
449455
return await proxy.generate(body, request_id)
450456

451457
@app.api_route("/{path_name:path}", methods=["GET", "POST", "DELETE", "PUT"])
452458
async def catch_all(path_name: str, request: Request):
459+
# Catch-all only valid in Mock Mode
453460
if SERVER_STATE.is_mock_mode:
454461
try:
455462
body = None
456463
if request.method in ["POST", "PUT"]:
457464
try: body = await request.json()
458465
except: pass
459-
# We forward headers in the mock record too, in case tests need to verify them
460-
req_record = {
461-
"path": f"/{path_name}",
462-
"method": request.method,
463-
"headers": dict(request.headers),
464-
"body": body
465-
}
466+
req_record = {"path": f"/{path_name}", "method": request.method, "headers": dict(request.headers), "body": body}
466467
SERVER_STATE.request_queue.put(req_record)
467468
response_data = SERVER_STATE.response_queue.get(timeout=5)
468469
response_json = json.loads(response_data) if isinstance(response_data, str) else response_data

0 commit comments

Comments
 (0)