Skip to content

Commit 6e8e3c8

Browse files
committed
Update mock_server.py
1 parent ca75814 commit 6e8e3c8

File tree

1 file changed

+64
-31
lines changed

1 file changed

+64
-31
lines changed

tests/mock_server.py

Lines changed: 64 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from contextlib import asynccontextmanager
1010

1111
import uvicorn
12-
from fastapi import FastAPI, HTTPException, Request, Response
12+
from fastapi import FastAPI, HTTPException, Request, Header
1313
from fastapi.responses import StreamingResponse, JSONResponse
1414
from fastapi.middleware.cors import CORSMiddleware
1515
from pydantic import BaseModel, Field
@@ -18,14 +18,17 @@
1818
# --- [System Configuration] ---
1919

2020
logging.basicConfig(
21-
level=logging.INFO,
21+
level=logging.DEBUG, # DEBUG for verbose diagnostics — NOTE(review): comment previously said INFO; confirm intended production level
2222
format="%(asctime)s.%(msecs)03d | %(levelname)s | %(process)d | %(message)s",
2323
datefmt="%H:%M:%S"
2424
)
2525
logger = logging.getLogger("DeepSeekProxy")
2626

27+
# Upstream Base URL
2728
SILICON_FLOW_BASE_URL = os.getenv("SILICON_FLOW_BASE_URL", "https://api.siliconflow.cn/v1")
28-
SILICON_FLOW_API_KEY = os.getenv("SILICON_FLOW_API_KEY")
29+
30+
# MOCK/TEST ONLY: This key is never used in the production generation path
31+
_MOCK_ENV_API_KEY = os.getenv("SILICON_FLOW_API_KEY")
2932

3033
MODEL_MAPPING = {
3134
"deepseek-v3": "deepseek-ai/DeepSeek-V3",
@@ -35,7 +38,7 @@
3538
"default": "deepseek-ai/DeepSeek-V3"
3639
}
3740

38-
# --- [Shared State for Mock Mode] ---
41+
# --- [Shared State] ---
3942

4043
class ServerState:
4144
_instance = None
@@ -57,7 +60,7 @@ def set_queues(self, req_q, res_q):
5760
self.request_queue = req_q
5861
self.response_queue = res_q
5962
self.is_mock_mode = True
60-
logger.info("Server transitioned to MOCK MODE via Queue Injection.")
63+
logger.warning("!!! Server running in MOCK MODE via Queue Injection !!!")
6164

6265
def increment_request(self):
6366
with self.lock:
@@ -69,7 +72,6 @@ def decrement_request(self):
6972

7073
@property
7174
def snapshot(self):
72-
"""Returns a consistent snapshot of the state."""
7375
with self.lock:
7476
return {
7577
"active_requests": self.active_requests,
@@ -111,9 +113,10 @@ class GenerationRequest(BaseModel):
111113
# --- [DeepSeek Proxy Logic] ---
112114

113115
class DeepSeekProxy:
114-
def __init__(self):
116+
def __init__(self, api_key: str):
117+
# We instantiate a new client per request to ensure isolation of user credentials
115118
self.client = AsyncOpenAI(
116-
api_key=SILICON_FLOW_API_KEY if SILICON_FLOW_API_KEY else "dummy_key",
119+
api_key=api_key,
117120
base_url=SILICON_FLOW_BASE_URL
118121
)
119122

@@ -139,12 +142,14 @@ def _convert_input_to_messages(self, input_data: InputData) -> List[Dict[str, st
139142
async def generate(self, req_data: GenerationRequest, initial_request_id: str):
140143
params = req_data.parameters
141144

145+
# Validation: Tools require message format
142146
if params.tools and params.result_format != "message":
143147
return JSONResponse(
144148
status_code=400,
145149
content={"code": "InvalidParameter", "message": "When 'tools' are provided, 'result_format' must be 'message'."}
146150
)
147151

152+
# Validation: R1 + Tools constraint
148153
is_r1 = "deepseek-r1" in req_data.model or params.enable_thinking
149154
if is_r1 and params.tool_choice and isinstance(params.tool_choice, dict):
150155
return JSONResponse(
@@ -174,18 +179,15 @@ async def generate(self, req_data: GenerationRequest, initial_request_id: str):
174179

175180
try:
176181
if openai_params["stream"]:
177-
# Fetch raw response for headers in stream mode (awaited)
178182
raw_resp = await self.client.chat.completions.with_raw_response.create(**openai_params)
179183
trace_id = raw_resp.headers.get("X-SiliconCloud-Trace-Id", initial_request_id)
180184

181-
# raw_resp.parse() returns the AsyncStream
182185
return StreamingResponse(
183186
self._stream_generator(raw_resp.parse(), trace_id),
184187
media_type="text/event-stream",
185-
headers={"X-SiliconCloud-Trace-Id": trace_id} # <--- Added Header Propagation
188+
headers={"X-SiliconCloud-Trace-Id": trace_id}
186189
)
187190
else:
188-
# Standard response (awaited)
189191
raw_resp = await self.client.chat.completions.with_raw_response.create(**openai_params)
190192
trace_id = raw_resp.headers.get("X-SiliconCloud-Trace-Id", initial_request_id)
191193
return self._format_unary_response(raw_resp.parse(), trace_id)
@@ -280,12 +282,12 @@ async def lifespan(app: FastAPI):
280282
stop_event = threading.Event()
281283
def epoch_clock():
282284
while not stop_event.is_set():
283-
time.sleep(2)
285+
time.sleep(5)
284286
state = SERVER_STATE.snapshot
285287
if state["active_requests"] > 0 or state["is_mock_mode"]:
286288
logger.info(
287-
f"[Epoch Clock] Active Requests: {state['active_requests']} | "
288-
f"Mode: {'MOCK' if state['is_mock_mode'] else 'PROXY'}"
289+
f"[Monitor] Active: {state['active_requests']} | "
290+
f"Mode: {'MOCK' if state['is_mock_mode'] else 'PRODUCTION'}"
289291
)
290292
monitor_thread = threading.Thread(target=epoch_clock, daemon=True)
291293
monitor_thread.start()
@@ -308,43 +310,72 @@ async def request_tracker(request: Request, call_next):
308310
finally:
309311
SERVER_STATE.decrement_request()
310312
duration = (time.time() - start_time) * 1000
311-
logger.info(f"{request.method} {request.url.path} - {duration:.2f}ms")
313+
# Reduce log noise: only requests slower than 1000 ms are logged (as warnings)
314+
if duration > 1000:
315+
logger.warning(f"{request.method} {request.url.path} - SLOW {duration:.2f}ms")
312316

313317
@app.get("/health_check")
314318
async def health_check():
315319
return JSONResponse(
316320
status_code=200,
317321
content={
318322
"status": "healthy",
319-
"service": "DeepSeek-DashScope-Proxy",
320-
"timestamp": time.time(),
323+
"service": "DeepSeek-Proxy",
321324
"mode": "mock" if SERVER_STATE.is_mock_mode else "production"
322325
}
323326
)
324327

325328
@app.post("/api/v1/services/aigc/text-generation/generation")
326-
async def generation(request: Request, body: GenerationRequest = None):
329+
async def generation(
330+
request: Request,
331+
body: GenerationRequest = None,
332+
authorization: Optional[str] = Header(None)
333+
):
327334
request_id = request.headers.get("x-request-id", str(uuid.uuid4()))
328335

329-
# Instantiate Proxy Per Request
330-
proxy = DeepSeekProxy()
336+
# --- [Logic Switch: Mock vs Production] ---
331337

338+
if SERVER_STATE.is_mock_mode:
339+
# In MOCK mode, we can optionally use the shadow/env key
340+
api_key = _MOCK_ENV_API_KEY or "mock-key"
341+
else:
342+
# --- [PRODUCTION MODE] ---
343+
# STRICTLY require the Authorization header; do not fall back to env vars.
344+
if not authorization:
345+
logger.warning(f"Rejected request {request_id}: Missing Authorization Header")
346+
raise HTTPException(status_code=401, detail="Missing Authorization header")
347+
348+
if not authorization.startswith("Bearer "):
349+
logger.warning(f"Rejected request {request_id}: Invalid Authorization Format")
350+
raise HTTPException(status_code=401, detail="Invalid Authorization header format. Expected 'Bearer <token>'")
351+
352+
# Transparently forward the user's key
353+
api_key = authorization.replace("Bearer ", "")
354+
355+
logger.debug(f"using API Key: {api_key[:8]}... for request {request_id}")
356+
# Instantiate Proxy with the specific key (User's or Mock's)
357+
proxy = DeepSeekProxy(api_key=api_key)
358+
359+
# Parse Body if not injected
332360
if not body:
333361
try:
334362
raw_json = await request.json()
335363
body = GenerationRequest(**raw_json)
336364
except Exception as e:
337-
if not SERVER_STATE.is_mock_mode:
338-
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}")
365+
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}")
339366

367+
# --- [Mock Handling] ---
340368
if SERVER_STATE.is_mock_mode:
341369
if body:
342-
logger.info(f"[Shadow] Validating request against upstream...")
370+
# Shadow Traffic Logic (Optional validation against upstream)
343371
try:
344-
# Async generate call on the local instance
345-
await proxy.generate(body, f"shadow-{request_id}")
346-
except Exception as e:
347-
logger.error(f"[Shadow] Validation Exception: {str(e)}")
372+
# We only perform shadow traffic if a key is actually available
373+
if _MOCK_ENV_API_KEY:
374+
shadow_proxy = DeepSeekProxy(api_key=_MOCK_ENV_API_KEY)
375+
# Shadow upstream call is currently disabled; uncomment the line below to validate
376+
# await shadow_proxy.generate(body, f"shadow-{request_id}")
377+
except Exception:
378+
pass # Swallow shadow errors
348379

349380
try:
350381
raw_body = await request.json()
@@ -354,13 +385,15 @@ async def generation(request: Request, body: GenerationRequest = None):
354385
status_code = response_json.pop("status_code", 200)
355386
return JSONResponse(content=response_json, status_code=status_code)
356387
except Exception as e:
357-
logger.critical(f"[Mock] DEADLOCK/ERROR: {e}")
358-
return JSONResponse(status_code=500, content={"code": "MockError", "message": f"Mock Server Error: {str(e)}"})
388+
return JSONResponse(status_code=500, content={"code": "MockError", "message": str(e)})
359389

390+
# --- [Production Handling] ---
391+
# Forward request to upstream
360392
return await proxy.generate(body, request_id)
361393

362394
@app.api_route("/{path_name:path}", methods=["GET", "POST", "DELETE", "PUT"])
363395
async def catch_all(path_name: str, request: Request):
396+
# Catch-all only valid in Mock Mode
364397
if SERVER_STATE.is_mock_mode:
365398
try:
366399
body = None
@@ -398,7 +431,7 @@ def create_mock_server(*args, **kwargs):
398431
proc.start()
399432
mock_server.proc = proc
400433
time.sleep(1.5)
401-
logger.info("Mock Server (Proxy Mode) started on port 8089")
434+
logger.info("Mock Server started on port 8089")
402435
if args and hasattr(args[0], "addfinalizer"):
403436
def stop_server():
404437
if proc.is_alive():

0 commit comments

Comments
 (0)