Commit f1e58f0

committed
trace id
1 parent 7009c2c commit f1e58f0

File tree

1 file changed: +23 -96 lines changed

tests/mock_server.py

Lines changed: 23 additions & 96 deletions
@@ -9,7 +9,7 @@
 from contextlib import asynccontextmanager
 
 import uvicorn
-from fastapi import FastAPI, HTTPException, Request
+from fastapi import FastAPI, HTTPException, Request, Response
 from fastapi.responses import StreamingResponse, JSONResponse
 from fastapi.middleware.cors import CORSMiddleware
 from pydantic import BaseModel, Field
@@ -67,7 +67,6 @@ def decrement_request(self):
         with self.lock:
             self.active_requests -= 1
 
-    # [FIX 1] Add a thread-safe snapshot read to keep concurrent state consistent
     @property
     def snapshot(self):
         """Returns a consistent snapshot of the state."""
@@ -101,9 +100,7 @@ class Parameters(BaseModel):
     stop: Optional[Union[str, List[str]]] = None
     enable_thinking: bool = False
     thinking_budget: Optional[int] = None
-    # [ADDED] Tools Support
     tools: Optional[List[Dict[str, Any]]] = None
-    # Allowed: "none", "auto", "required" (str) OR {"type": "function", ...} (dict)
     tool_choice: Optional[Union[str, Dict[str, Any]]] = None
 
 class GenerationRequest(BaseModel):
@@ -139,23 +136,15 @@ def _convert_input_to_messages(self, input_data: InputData) -> List[Dict[str, st
         messages.append({"role": "user", "content": input_data.prompt})
         return messages
 
-    async def generate(self, req_data: GenerationRequest, request_id: str):
-        """
-        Standard generation logic with strict invariant checks.
-        """
+    async def generate(self, req_data: GenerationRequest, initial_request_id: str):
         params = req_data.parameters
 
-        # --- [Invariant Checks] ---
-        # 1. Format Constraint (Predicate A)
         if params.tools and params.result_format != "message":
             return JSONResponse(
                 status_code=400,
                 content={"code": "InvalidParameter", "message": "When 'tools' are provided, 'result_format' must be 'message'."}
             )
 
-        # 2. R1 Orthogonality (Predicate B)
-        # DeepSeek R1 Thinking Mode is mutually exclusive with FORCED SPECIFIC tool choice (Dict).
-        # However, abstract constraints like "required" (String) are allowed.
         is_r1 = "deepseek-r1" in req_data.model or params.enable_thinking
         if is_r1 and params.tool_choice and isinstance(params.tool_choice, dict):
             return JSONResponse(
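For reference (not part of the diff): hypothetical payloads that would hit the two 400 branches kept above. Field names under "parameters" are taken from the Parameters model in this diff; the top-level "input" key is an assumption, since the request schema is outside this hunk.

# First payload: tools present but result_format is not "message" -> InvalidParameter.
bad_format = {
    "model": "deepseek-r1",
    "input": {"prompt": "hi"},
    "parameters": {
        "result_format": "text",  # must be "message" when tools are provided
        "tools": [{"type": "function", "function": {"name": "get_time"}}],
    },
}

# Second payload: R1/thinking mode combined with a forced, dict-valued tool_choice -> rejected.
# An abstract string choice such as "auto" or "required" would be allowed.
bad_r1_tool_choice = {
    "model": "deepseek-r1",
    "input": {"prompt": "hi"},
    "parameters": {
        "result_format": "message",
        "enable_thinking": True,
        "tools": [{"type": "function", "function": {"name": "get_time"}}],
        "tool_choice": {"type": "function", "function": {"name": "get_time"}},
    },
}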
@@ -174,11 +163,9 @@ async def generate(self, req_data: GenerationRequest, request_id: str):
             "stream": params.incremental_output or params.enable_thinking,
         }
 
-        # [ADDED] Pass tools to upstream if present
         if params.tools:
             openai_params["tools"] = params.tools
         if params.tool_choice:
-            # This will pass "required" (str) effectively to OpenAI/SiliconFlow
             openai_params["tool_choice"] = params.tool_choice
 
         if params.max_tokens: openai_params["max_tokens"] = params.max_tokens
@@ -187,13 +174,17 @@ async def generate(self, req_data: GenerationRequest, request_id: str):
 
         try:
             if openai_params["stream"]:
+                # Fetch raw response for headers in stream mode
+                raw_resp = self.client.chat.completions.with_raw_response.create(**openai_params)
+                trace_id = raw_resp.headers.get("X-SiliconCloud-Trace-Id", initial_request_id)
                 return StreamingResponse(
-                    self._stream_generator(openai_params, request_id),
+                    self._stream_generator(raw_resp.parse(), trace_id),
                     media_type="text/event-stream"
                 )
             else:
-                completion = self.client.chat.completions.create(**openai_params)
-                return self._format_unary_response(completion, request_id)
+                raw_resp = self.client.chat.completions.with_raw_response.create(**openai_params)
+                trace_id = raw_resp.headers.get("X-SiliconCloud-Trace-Id", initial_request_id)
+                return self._format_unary_response(raw_resp.parse(), trace_id)
 
         except APIError as e:
             logger.error(f"Upstream API Error: {str(e)}")
@@ -203,20 +194,10 @@ async def generate(self, req_data: GenerationRequest, request_id: str):
 
             return JSONResponse(
                 status_code=e.status_code or 500,
-                content={"code": error_code, "message": str(e), "request_id": request_id}
+                content={"code": error_code, "message": str(e), "request_id": initial_request_id}
             )
 
-    async def _stream_generator(self, openai_params: Dict, request_id: str) -> AsyncGenerator[str, None]:
-        if "stream_options" not in openai_params:
-            openai_params["stream_options"] = {"include_usage": True}
-
-        try:
-            stream = self.client.chat.completions.create(**openai_params)
-        except Exception as e:
-            logger.error(f"Stream creation failed: {e}")
-            yield f"data: {json.dumps({'code': 'StreamError', 'message': str(e)}, ensure_ascii=False)}\n\n"
-            return
-
+    async def _stream_generator(self, stream, request_id: str) -> AsyncGenerator[str, None]:
         accumulated_usage = {
             "total_tokens": 0, "input_tokens": 0, "output_tokens": 0,
             "output_tokens_details": {"text_tokens": 0, "reasoning_tokens": 0}
@@ -234,34 +215,21 @@ async def _stream_generator(self, openai_params: Dict, request_id: str) -> Async
                 accumulated_usage["output_tokens_details"]["text_tokens"] = accumulated_usage["output_tokens"] - accumulated_usage["output_tokens_details"]["reasoning_tokens"]
 
             delta = chunk.choices[0].delta if chunk.choices else None
-
             content = delta.content if delta and delta.content else ""
             reasoning = getattr(delta, "reasoning_content", "") if delta else ""
 
             tool_calls = None
             if delta and delta.tool_calls:
-                # Forward the raw list of tool call chunks
-                # Note: model_dump() is retained per original design, ensuring Pydantic serialization
                 tool_calls = [tc.model_dump() for tc in delta.tool_calls]
 
             if chunk.choices and chunk.choices[0].finish_reason:
                 finish_reason = chunk.choices[0].finish_reason
 
-            message_body = {
-                "role": "assistant",
-                "content": content,
-                "reasoning_content": reasoning
-            }
-            if tool_calls:
-                message_body["tool_calls"] = tool_calls
+            message_body = {"role": "assistant", "content": content, "reasoning_content": reasoning}
+            if tool_calls: message_body["tool_calls"] = tool_calls
 
             response_body = {
-                "output": {
-                    "choices": [{
-                        "message": message_body,
-                        "finish_reason": finish_reason
-                    }]
-                },
+                "output": {"choices": [{"message": message_body, "finish_reason": finish_reason}]},
                 "usage": accumulated_usage,
                 "request_id": request_id
             }
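For reference (not part of the diff): each loop iteration yields one SSE event whose data field carries the body built above, so a single intermediate event would look roughly like the line below; all values are illustrative. Note that per-chunk usage only arrives if the stream was created with stream_options={"include_usage": True} (the old _stream_generator injected that itself; after this change the stream is created in generate()).

data: {"output": {"choices": [{"message": {"role": "assistant", "content": "Hello", "reasoning_content": ""}, "finish_reason": null}]}, "usage": {"total_tokens": 12, "input_tokens": 8, "output_tokens": 4, "output_tokens_details": {"text_tokens": 4, "reasoning_tokens": 0}}, "request_id": "req-123"}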
@@ -271,12 +239,11 @@ async def _stream_generator(self, openai_params: Dict, request_id: str) -> Async
     def _format_unary_response(self, completion, request_id: str):
         choice = completion.choices[0]
         msg = choice.message
-
         usage_data = {
             "total_tokens": completion.usage.total_tokens,
             "input_tokens": completion.usage.prompt_tokens,
             "output_tokens": completion.usage.completion_tokens,
-            "output_tokens_details": {"text_tokens": 0, "reasoning_tokens": 0}
+            "output_tokens_details": {"text_tokens": 0, "reasoning_tokens": 0}
         }
         details = getattr(completion.usage, "completion_tokens_details", None)
         if details:
@@ -292,12 +259,7 @@ def _format_unary_response(self, completion, request_id: str):
             message_body["tool_calls"] = [tc.model_dump() for tc in msg.tool_calls]
 
         response_body = {
-            "output": {
-                "choices": [{
-                    "message": message_body,
-                    "finish_reason": choice.finish_reason
-                }]
-            },
+            "output": {"choices": [{"message": message_body, "finish_reason": choice.finish_reason}]},
             "usage": usage_data,
             "request_id": request_id
         }
@@ -311,15 +273,12 @@ async def lifespan(app: FastAPI):
     def epoch_clock():
         while not stop_event.is_set():
             time.sleep(2)
-
-            # [FIX 1 usage & FIX 2] Read state via the snapshot property and log at INFO level
             state = SERVER_STATE.snapshot
             if state["active_requests"] > 0 or state["is_mock_mode"]:
                 logger.info(
                     f"[Epoch Clock] Active Requests: {state['active_requests']} | "
                     f"Mode: {'MOCK' if state['is_mock_mode'] else 'PROXY'}"
                 )
-
     monitor_thread = threading.Thread(target=epoch_clock, daemon=True)
     monitor_thread.start()
     yield
@@ -344,12 +303,8 @@ async def request_tracker(request: Request, call_next):
     duration = (time.time() - start_time) * 1000
     logger.info(f"{request.method} {request.url.path} - {duration:.2f}ms")
 
-# --- [New Endpoint] Health Check ---
 @app.get("/health_check")
 async def health_check():
-    """
-    Liveness probe verifying server status and current mode.
-    """
     return JSONResponse(
         status_code=200,
         content={
@@ -372,39 +327,25 @@ async def generation(request: Request, body: GenerationRequest = None):
         if not SERVER_STATE.is_mock_mode:
             raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}")
 
-    # === [Logic Branch]: Mock Mode with Shadow Verification ===
     if SERVER_STATE.is_mock_mode:
         if body:
-            logger.info(f"[Shadow] Validating request against upstream: {body.model_dump_json(exclude_none=True)}")
+            logger.info(f"[Shadow] Validating request against upstream...")
             try:
-                shadow_resp = await proxy.generate(body, f"shadow-{request_id}")
-                if isinstance(shadow_resp, StreamingResponse):
-                    async for _ in shadow_resp.body_iterator: pass
-                    logger.info("[Shadow] Upstream stream validation: PASSED")
-                elif isinstance(shadow_resp, JSONResponse):
-                    status = shadow_resp.status_code
-                    if 200 <= status < 300:
-                        logger.info(f"[Shadow] Upstream unary validation: PASSED (Status {status})")
-                    else:
-                        logger.warning(f"[Shadow] Upstream validation FAILED (Status {status})")
+                await proxy.generate(body, f"shadow-{request_id}")
             except Exception as e:
                 logger.error(f"[Shadow] Validation Exception: {str(e)}")
 
         try:
             raw_body = await request.json()
             SERVER_STATE.request_queue.put(raw_body)
             response_data = SERVER_STATE.response_queue.get(timeout=10)
-            if isinstance(response_data, str):
-                response_json = json.loads(response_data)
-            else:
-                response_json = response_data
+            response_json = json.loads(response_data) if isinstance(response_data, str) else response_data
             status_code = response_json.pop("status_code", 200)
             return JSONResponse(content=response_json, status_code=status_code)
         except Exception as e:
             logger.critical(f"[Mock] DEADLOCK/ERROR: {e}")
             return JSONResponse(status_code=500, content={"code": "MockError", "message": f"Mock Server Error: {str(e)}"})
 
-    # === [Logic Branch]: Production Proxy Mode ===
     return await proxy.generate(body, request_id)
 
 @app.api_route("/{path_name:path}", methods=["GET", "POST", "DELETE", "PUT"])
@@ -413,21 +354,12 @@ async def catch_all(path_name: str, request: Request):
     try:
         body = None
         if request.method in ["POST", "PUT"]:
-            try:
-                body = await request.json()
+            try: body = await request.json()
             except: pass
-        req_record = {
-            "path": f"/{path_name}",
-            "method": request.method,
-            "headers": dict(request.headers),
-            "body": body
-        }
+        req_record = {"path": f"/{path_name}", "method": request.method, "headers": dict(request.headers), "body": body}
         SERVER_STATE.request_queue.put(req_record)
         response_data = SERVER_STATE.response_queue.get(timeout=5)
-        if isinstance(response_data, str):
-            response_json = json.loads(response_data)
-        else:
-            response_json = response_data
+        response_json = json.loads(response_data) if isinstance(response_data, str) else response_data
         status_code = response_json.pop("status_code", 200)
         return JSONResponse(content=response_json, status_code=status_code)
     except Exception as e:
@@ -451,21 +383,16 @@ def __init__(self) -> None:
 
 def create_mock_server(*args, **kwargs):
     mock_server = MockServer()
-    proc = multiprocessing.Process(
-        target=run_server_process,
-        args=(mock_server.requests, mock_server.responses, "0.0.0.0", 8089)
-    )
+    proc = multiprocessing.Process(target=run_server_process, args=(mock_server.requests, mock_server.responses, "0.0.0.0", 8089))
     proc.start()
     mock_server.proc = proc
     time.sleep(1.5)
     logger.info("Mock Server (Proxy Mode) started on port 8089")
-
     if args and hasattr(args[0], "addfinalizer"):
         def stop_server():
             if proc.is_alive():
                 proc.terminate()
                 proc.join()
-            logger.info("Mock Server stopped")
         args[0].addfinalizer(stop_server)
     return mock_server
 
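For reference (not part of the diff): create_mock_server accepts an optional pytest request object as its first positional argument and, when present, registers termination of the server subprocess as a finalizer. A sketch of a fixture wrapping it; the fixture name, scope, and import path are assumptions.

# conftest.py (sketch)
import pytest
from tests.mock_server import create_mock_server

@pytest.fixture(scope="session")
def mock_server(request):
    # Passing the pytest `request` object lets create_mock_server call
    # request.addfinalizer(...) to stop the uvicorn subprocess after the session.
    return create_mock_server(request)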