Skip to content

Commit 0193ac9

Browse files
committed
Update mock_server.py
1 parent 5b540af commit 0193ac9

File tree

1 file changed

+26
-19
lines changed

1 file changed

+26
-19
lines changed

tests/mock_server.py

Lines changed: 26 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import httpx
21
import os
32
import json
43
import time
@@ -9,6 +8,7 @@
98
from typing import List, Optional, Dict, Any, Union, AsyncGenerator
109
from contextlib import asynccontextmanager
1110

11+
import httpx
1212
import uvicorn
1313
from fastapi import FastAPI, HTTPException, Request, Header
1414
from fastapi.exceptions import RequestValidationError
@@ -17,17 +17,18 @@
1717
from pydantic import BaseModel, Field, AliasChoices, ConfigDict
1818
from openai import AsyncOpenAI, APIError, RateLimitError, AuthenticationError
1919

20+
# --- Logging Configuration ---
2021
logging.basicConfig(
2122
level=logging.DEBUG,
2223
format="%(asctime)s.%(msecs)03d | %(levelname)s | %(process)d | %(message)s",
2324
datefmt="%H:%M:%S",
2425
)
2526
logger = logging.getLogger("DeepSeekProxy")
2627

28+
# --- Constants & Environment Variables ---
2729
SILICON_FLOW_BASE_URL = os.getenv(
2830
"SILICON_FLOW_BASE_URL", "https://api.siliconflow.cn/v1"
2931
)
30-
3132
_MOCK_ENV_API_KEY = os.getenv("SILICON_FLOW_API_KEY")
3233

3334
MODEL_MAPPING = {
@@ -37,11 +38,12 @@
3738
"deepseek-r1": "deepseek-ai/DeepSeek-R1",
3839
"default": "deepseek-ai/DeepSeek-V3",
3940
}
40-
DUMMY_KEY = "dummy-key"
4141

42+
DUMMY_KEY = "dummy-key"
4243
MAX_NUM_MSG_CURL_DUMP = 5
4344

4445

46+
# --- Server State Management ---
4547
class ServerState:
4648
_instance = None
4749

@@ -84,6 +86,7 @@ def snapshot(self):
8486
SERVER_STATE = ServerState.get_instance()
8587

8688

89+
# --- Pydantic Models ---
8790
class Message(BaseModel):
8891
role: str
8992
content: Optional[str] = ""
@@ -133,6 +136,7 @@ class GenerationRequest(BaseModel):
133136
parameters: Optional[Parameters] = Field(default_factory=Parameters)
134137

135138

139+
# --- DeepSeek Proxy Logic ---
136140
class DeepSeekProxy:
137141
def __init__(self, api_key: str, extra_headers: Optional[Dict[str, str]] = None):
138142
if extra_headers is None:
@@ -178,6 +182,7 @@ async def generate(
178182
):
179183
params = req_data.parameters
180184

185+
# --- Validation Logic ---
181186
if params.n is not None:
182187
if not (1 <= params.n <= 4):
183188
return JSONResponse(
@@ -278,10 +283,11 @@ async def generate(
278283
},
279284
)
280285

286+
# --- Request Preparation ---
281287
target_model = self._get_mapped_model(req_data.model)
282288
messages = self._convert_input_to_messages(req_data.input)
283289

284-
# 核心修改:流式开启条件 = 参数要求增量 OR 开启思考 OR 外部强制流式(SSE)
290+
# Stream enablement condition = parameter requires incremental OR thinking enabled OR external forced stream (SSE)
285291
should_stream = (
286292
params.incremental_output or params.enable_thinking or force_stream
287293
)
@@ -291,7 +297,7 @@ async def generate(
291297
"messages": messages,
292298
"temperature": params.temperature,
293299
"top_p": params.top_p,
294-
"stream": should_stream, # 使用计算后的 stream 状态
300+
"stream": should_stream, # Use calculated stream status
295301
}
296302

297303
if params.response_format:
@@ -376,6 +382,7 @@ async def generate(
376382
if stop_list:
377383
openai_params["stop"] = stop_list
378384

385+
# --- Debug Logging (CURL generation) ---
379386
curl_headers = [
380387
'-H "Authorization: Bearer ${SILICONFLOW_API_KEY}"',
381388
"-H 'Content-Type: application/json'",
@@ -405,6 +412,7 @@ async def generate(
405412

406413
logger.debug(f"[Curl Command]\n{curl_cmd}")
407414

415+
# --- Execution ---
408416
try:
409417
if openai_params["stream"]:
410418
raw_resp = await self.client.chat.completions.with_raw_response.create(
@@ -414,9 +422,8 @@ async def generate(
414422
"X-SiliconCloud-Trace-Id", initial_request_id
415423
)
416424

417-
# 核心修改:传入 is_incremental 标志
418-
# 如果 params.incremental_output 为 True,则返回增量 (Delta)
419-
# 如果 params.incremental_output 为 False,则返回全量 (Accumulated)
425+
# If params.incremental_output is True, return incremental (Delta)
426+
# If params.incremental_output is False, return full text (Accumulated)
420427
return StreamingResponse(
421428
self._stream_generator(
422429
raw_resp.parse(),
@@ -501,17 +508,16 @@ async def _stream_generator(
501508
if chunk.choices and chunk.choices[0].finish_reason:
502509
finish_reason = chunk.choices[0].finish_reason
503510

504-
# === 核心逻辑修改 ===
505-
# 根据 is_incremental 决定发送的内容
511+
# Decide what content to send based on is_incremental
506512
if is_incremental:
507-
# 模式:增量 (Delta)
508-
# 修复 BUG:即使是最后一包(finish_reason != null),也只发 delta
509-
# 只有 delta 有内容才发内容,否则发空字符串。
513+
# Mode: Incremental (Delta)
514+
# Bug Fix: Even if it is the last packet (finish_reason != null), only send delta.
515+
# Only send content if delta has content, otherwise send empty string.
510516
content_to_send = delta_content
511517
reasoning_to_send = delta_reasoning
512518
else:
513-
# 模式:全量 (Accumulated)
514-
# 每一包都返回当前累积的完整文本
519+
# Mode: Full (Accumulated)
520+
# Every packet returns the full accumulated text so far.
515521
content_to_send = full_text
516522
reasoning_to_send = full_reasoning
517523

@@ -579,6 +585,7 @@ def _format_unary_response(self, completion, request_id: str):
579585
)
580586

581587

588+
# --- FastAPI Application Lifecycle ---
582589
@asynccontextmanager
583590
async def lifespan(app: FastAPI):
584591
stop_event = threading.Event()
@@ -725,7 +732,7 @@ async def generation(
725732
accept_header = request.headers.get("accept", "")
726733
dashscope_sse = request.headers.get("x-dashscope-sse", "").lower()
727734

728-
# 修改:检测是否需要强制流式传输(SSE),但不修改用户的 incremental_output 参数
735+
# but do not modify the user's incremental_output parameter.
729736
force_stream = False
730737
if (
731738
"text/event-stream" in accept_header or dashscope_sse == "enable"
@@ -734,7 +741,7 @@ async def generation(
734741
f"SSE detected (Accept: {accept_header}, X-DashScope-SSE: {dashscope_sse}), enabling stream transport"
735742
)
736743
force_stream = True
737-
# 注意:这里删除了 body.parameters.incremental_output = True 这一行
744+
# Note: The line `body.parameters.incremental_output = True` was removed here.
738745

739746
if SERVER_STATE.is_mock_mode:
740747
if body:
@@ -760,7 +767,7 @@ async def generation(
760767
status_code=500, content={"code": "MockError", "message": str(e)}
761768
)
762769

763-
# 传入 force_stream
770+
# Pass force_stream to generate
764771
return await proxy.generate(body, request_id, force_stream=force_stream)
765772

766773
@app.post("/siliconflow/models/{model_path:path}")
@@ -779,7 +786,6 @@ async def dynamic_path_generation(
779786
accept_header = request.headers.get("accept", "")
780787
dashscope_sse = request.headers.get("x-dashscope-sse", "").lower()
781788

782-
# 修改:同上,使用 force_stream
783789
force_stream = False
784790
if (
785791
"text/event-stream" in accept_header or dashscope_sse == "enable"
@@ -830,6 +836,7 @@ async def catch_all(path_name: str, request: Request):
830836
return app
831837

832838

839+
# --- Mock Server Utilities ---
833840
def run_server_process(req_q, res_q, host="0.0.0.0", port=8000):
834841
if req_q and res_q:
835842
SERVER_STATE.set_queues(req_q, res_q)

0 commit comments

Comments
 (0)