Skip to content

Commit 5b540af

Browse files
committed
Update mock_server.py
Update mock_server.py
1 parent 50783ac commit 5b540af

File tree

1 file changed

+44
-14
lines changed

1 file changed

+44
-14
lines changed

tests/mock_server.py

Lines changed: 44 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,12 @@ def _convert_input_to_messages(self, input_data: InputData) -> List[Dict[str, st
170170
messages.append({"role": "user", "content": input_data.prompt})
171171
return messages
172172

173-
async def generate(self, req_data: GenerationRequest, initial_request_id: str):
173+
async def generate(
174+
self,
175+
req_data: GenerationRequest,
176+
initial_request_id: str,
177+
force_stream: bool = False,
178+
):
174179
params = req_data.parameters
175180

176181
if params.n is not None:
@@ -276,12 +281,17 @@ async def generate(self, req_data: GenerationRequest, initial_request_id: str):
276281
target_model = self._get_mapped_model(req_data.model)
277282
messages = self._convert_input_to_messages(req_data.input)
278283

284+
# Core change: streaming is on when = params request incremental output OR thinking is enabled OR streaming is forced externally (SSE)
285+
should_stream = (
286+
params.incremental_output or params.enable_thinking or force_stream
287+
)
288+
279289
openai_params = {
280290
"model": target_model,
281291
"messages": messages,
282292
"temperature": params.temperature,
283293
"top_p": params.top_p,
284-
"stream": params.incremental_output or params.enable_thinking,
294+
"stream": should_stream, # 使用计算后的 stream 状态
285295
}
286296

287297
if params.response_format:
@@ -404,8 +414,15 @@ async def generate(self, req_data: GenerationRequest, initial_request_id: str):
404414
"X-SiliconCloud-Trace-Id", initial_request_id
405415
)
406416

417+
# Core change: pass in the is_incremental flag
418+
# If params.incremental_output is True, return increments (Delta)
419+
# If params.incremental_output is False, return the full text (Accumulated)
407420
return StreamingResponse(
408-
self._stream_generator(raw_resp.parse(), trace_id),
421+
self._stream_generator(
422+
raw_resp.parse(),
423+
trace_id,
424+
is_incremental=params.incremental_output,
425+
),
409426
media_type="text/event-stream",
410427
headers={"X-SiliconCloud-Trace-Id": trace_id},
411428
)
@@ -438,7 +455,7 @@ async def generate(self, req_data: GenerationRequest, initial_request_id: str):
438455
)
439456

440457
async def _stream_generator(
441-
self, stream, request_id: str
458+
self, stream, request_id: str, is_incremental: bool
442459
) -> AsyncGenerator[str, None]:
443460
accumulated_usage = {
444461
"total_tokens": 0,
@@ -484,12 +501,19 @@ async def _stream_generator(
484501
if chunk.choices and chunk.choices[0].finish_reason:
485502
finish_reason = chunk.choices[0].finish_reason
486503

487-
if finish_reason != "null":
488-
content_to_send = full_text
489-
reasoning_to_send = full_reasoning
490-
else:
504+
# === Core logic change ===
505+
# Decide what to send based on is_incremental
506+
if is_incremental:
507+
# Mode: incremental (Delta)
508+
# BUG fix: even for the final packet (finish_reason != null), send only the delta;
509+
# send content only when the delta has content, otherwise send an empty string.
491510
content_to_send = delta_content
492511
reasoning_to_send = delta_reasoning
512+
else:
513+
# Mode: full text (Accumulated)
514+
# Every packet returns the full text accumulated so far
515+
content_to_send = full_text
516+
reasoning_to_send = full_reasoning
493517

494518
message_body = {
495519
"role": "assistant",
@@ -701,13 +725,16 @@ async def generation(
701725
accept_header = request.headers.get("accept", "")
702726
dashscope_sse = request.headers.get("x-dashscope-sse", "").lower()
703727

728+
# Change: detect whether streaming transport (SSE) must be forced, without modifying the user's incremental_output parameter
729+
force_stream = False
704730
if (
705731
"text/event-stream" in accept_header or dashscope_sse == "enable"
706732
) and body.parameters:
707733
logger.debug(
708-
f"SSE detected (Accept: {accept_header}, X-DashScope-SSE: {dashscope_sse}), enabling stream"
734+
f"SSE detected (Accept: {accept_header}, X-DashScope-SSE: {dashscope_sse}), enabling stream transport"
709735
)
710-
body.parameters.incremental_output = True
736+
force_stream = True
737+
# Note: the line `body.parameters.incremental_output = True` was removed here
711738

712739
if SERVER_STATE.is_mock_mode:
713740
if body:
@@ -733,7 +760,8 @@ async def generation(
733760
status_code=500, content={"code": "MockError", "message": str(e)}
734761
)
735762

736-
return await proxy.generate(body, request_id)
763+
# Pass force_stream through
764+
return await proxy.generate(body, request_id, force_stream=force_stream)
737765

738766
@app.post("/siliconflow/models/{model_path:path}")
739767
async def dynamic_path_generation(
@@ -751,15 +779,17 @@ async def dynamic_path_generation(
751779
accept_header = request.headers.get("accept", "")
752780
dashscope_sse = request.headers.get("x-dashscope-sse", "").lower()
753781

782+
# Change: same as above, use force_stream
783+
force_stream = False
754784
if (
755785
"text/event-stream" in accept_header or dashscope_sse == "enable"
756786
) and body.parameters:
757787
logger.debug(
758-
f"SSE detected (Accept: {accept_header}, X-DashScope-SSE: {dashscope_sse}), enabling stream"
788+
f"SSE detected (Accept: {accept_header}, X-DashScope-SSE: {dashscope_sse}), enabling stream transport"
759789
)
760-
body.parameters.incremental_output = True
790+
force_stream = True
761791

762-
return await proxy.generate(body, request_id)
792+
return await proxy.generate(body, request_id, force_stream=force_stream)
763793

764794
@app.api_route("/{path_name:path}", methods=["GET", "POST", "DELETE", "PUT"])
765795
async def catch_all(path_name: str, request: Request):

0 commit comments

Comments
 (0)