Skip to content

Commit 0731257

Browse files
committed
Update mock_server.py
1 parent fbb6476 commit 0731257

File tree

1 file changed

+81
-31
lines changed

1 file changed

+81
-31
lines changed

tests/mock_server.py

Lines changed: 81 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,10 @@ class Parameters(BaseModel):
118118
presence_penalty: Optional[float] = 0.0
119119
repetition_penalty: Optional[float] = 1.0
120120

121+
# Supported by Pydantic parsing but logic will reject them to pass tests
122+
logprobs: Optional[bool] = None
123+
top_logprobs: Optional[int] = None
124+
121125
stop: Optional[Union[str, List[str]]] = None
122126
stop_words: Optional[List[Dict[str, Any]]] = None
123127
enable_thinking: bool = False
@@ -156,7 +160,6 @@ def __init__(self, api_key: str, extra_headers: Optional[Dict[str, str]] = None)
156160
)
157161

158162
def _get_mapped_model(self, request_model: str) -> str:
159-
# Default is kept here for internal logic, but strict check is added in generate()
160163
return MODEL_MAPPING.get(request_model, MODEL_MAPPING["default"])
161164

162165
def _convert_input_to_messages(self, input_data: InputData) -> List[Dict[str, str]]:
@@ -194,15 +197,22 @@ async def generate(
194197
has_input = req_data.input is not None
195198
has_content = False
196199
if has_input:
200+
# Strict mutual exclusion check
201+
if req_data.input.messages is not None and req_data.input.prompt is not None:
202+
return JSONResponse(
203+
status_code=400,
204+
content={
205+
"code": "InvalidParameter",
206+
"message": '<400> InternalError.Algo.InvalidParameter: Only one of the parameters "prompt" and "messages" can be present',
207+
},
208+
)
209+
197210
has_content = (
198211
bool(req_data.input.messages)
199212
or bool(req_data.input.prompt)
200213
or bool(req_data.input.history)
201214
)
202215

203-
# 如果用户显式传了 prompt="" (空字符串),按照 DashScope 协议这是明确的非法参数,
204-
# 此时即使有 history,也应该强制报错。
205-
# 注意:如果 prompt 是 None (即没传该字段),则不受此逻辑影响,保留了 history 的原有行为。
206216
if req_data.input.prompt == "" and not req_data.input.messages:
207217
has_content = False
208218

@@ -217,6 +227,26 @@ async def generate(
217227

218228
params = req_data.parameters
219229

230+
# Logprobs not supported
231+
if params.logprobs:
232+
return JSONResponse(
233+
status_code=400,
234+
content={
235+
"code": "InvalidParameter",
236+
"message": "<400> InternalError.Algo.InvalidParameter: The parameters `logprobs` is not supported.",
237+
},
238+
)
239+
240+
# max_tokens range check
241+
if params.max_tokens is not None and params.max_tokens < 1:
242+
return JSONResponse(
243+
status_code=400,
244+
content={
245+
"code": "InvalidParameter",
246+
"message": f"<400> InternalError.Algo.InvalidParameter: Range of max_tokens should be [1, 2147483647]",
247+
},
248+
)
249+
220250
# --- Validation Logic ---
221251
if params.n is not None:
222252
if not (1 <= params.n <= 4):
@@ -238,6 +268,16 @@ async def generate(
238268
},
239269
)
240270

271+
if params.seed is not None:
272+
if not (0 <= params.seed <= 9223372036854775807):
273+
return JSONResponse(
274+
status_code=400,
275+
content={
276+
"code": "InvalidParameter",
277+
"message": "<400> InternalError.Algo.InvalidParameter: Range of seed should be [0, 9223372036854775807]",
278+
},
279+
)
280+
241281
if params.response_format:
242282
rf_type = params.response_format.get("type")
243283
if rf_type and rf_type not in ["json_object", "text"]:
@@ -322,7 +362,6 @@ async def generate(
322362
target_model = self._get_mapped_model(req_data.model)
323363
messages = self._convert_input_to_messages(req_data.input)
324364

325-
# Stream enablement condition = parameter requires incremental OR thinking enabled OR external forced stream (SSE)
326365
should_stream = (
327366
params.incremental_output or params.enable_thinking or force_stream
328367
)
@@ -332,7 +371,7 @@ async def generate(
332371
"messages": messages,
333372
"temperature": params.temperature,
334373
"top_p": params.top_p,
335-
"stream": should_stream, # Use calculated stream status
374+
"stream": should_stream,
336375
}
337376

338377
if params.response_format:
@@ -457,8 +496,6 @@ async def generate(
457496
"X-SiliconCloud-Trace-Id", initial_request_id
458497
)
459498

460-
# If params.incremental_output is True, return incremental (Delta)
461-
# If params.incremental_output is False, return full text (Accumulated)
462499
return StreamingResponse(
463500
self._stream_generator(
464501
raw_resp.parse(),
@@ -543,16 +580,10 @@ async def _stream_generator(
543580
if chunk.choices and chunk.choices[0].finish_reason:
544581
finish_reason = chunk.choices[0].finish_reason
545582

546-
# Decide what content to send based on is_incremental
547583
if is_incremental:
548-
# Mode: Incremental (Delta)
549-
# Bug Fix: Even if it is the last packet (finish_reason != null), only send delta.
550-
# Only send content if delta has content, otherwise send empty string.
551584
content_to_send = delta_content
552585
reasoning_to_send = delta_reasoning
553586
else:
554-
# Mode: Full (Accumulated)
555-
# Every packet returns the full accumulated text so far.
556587
content_to_send = full_text
557588
reasoning_to_send = full_reasoning
558589

@@ -702,18 +733,19 @@ async def validation_exception_handler(request, exc):
702733
error_msg = err.get("msg", "Invalid parameter")
703734
loc = err.get("loc", [])
704735
param_name = loc[-1] if loc else "unknown"
736+
path_str = ".".join([str(x) for x in loc if x != "body"])
737+
err_type = err.get("type")
705738

706739
if "stop" in loc:
707740
return JSONResponse(
708741
status_code=400,
709742
content={
710743
"code": "InvalidParameter",
711-
# 这里必须严格匹配测试用例期望的字符串
712744
"message": "<400> InternalError.Algo.InvalidParameter: Input should be a valid list: parameters.stop.list[any] & Input should be a valid string: parameters.stop.str",
713745
},
714746
)
715747

716-
if "model" in loc and err.get("type") == "missing":
748+
if "model" in loc and err_type == "missing":
717749
return JSONResponse(
718750
status_code=400,
719751
content={
@@ -722,19 +754,14 @@ async def validation_exception_handler(request, exc):
722754
},
723755
)
724756

725-
if err.get("type") == "int_parsing":
726-
# Reconstruct path "parameters.max_length" from ["body", "parameters", "max_length"]
727-
path_str = ".".join([str(x) for x in loc if x != "body"])
728-
return JSONResponse(
729-
status_code=400,
730-
content={
731-
"code": "InvalidParameter",
732-
"message": f"<400> InternalError.Algo.InvalidParameter: {error_msg}: {path_str}",
733-
},
734-
)
735-
757+
# Catch generic typing errors for content (including lists passed to str)
736758
if param_name == "content":
737-
if "valid string" in error_msg or "str" in error_msg:
759+
if (
760+
"valid string" in error_msg
761+
or "str" in error_msg
762+
or err_type == "string_type"
763+
or err_type == "list_type" # Added list_type for robustness
764+
):
738765
return JSONResponse(
739766
status_code=400,
740767
content={
@@ -743,6 +770,32 @@ async def validation_exception_handler(request, exc):
743770
},
744771
)
745772

773+
if "response_format" in loc and err_type == "dict_type":
774+
return JSONResponse(
775+
status_code=400,
776+
content={
777+
"code": "InvalidParameter",
778+
"message": "<400> InternalError.Algo.InvalidParameter: Unknown format of response_format, response_format should be a dict, includes 'type' and an optional key 'json_schema'. The response_format type from user is <class 'str'>.",
779+
},
780+
)
781+
782+
type_msg_map = {
783+
"int_parsing": "Input should be a valid integer",
784+
"int_from_float": "Input should be a valid integer",
785+
"float_parsing": "Input should be a valid number, unable to parse string as a number",
786+
"bool_parsing": "Input should be a valid boolean, unable to interpret input",
787+
"string_type": "Input should be a valid string",
788+
}
789+
790+
if err_type in type_msg_map:
791+
return JSONResponse(
792+
status_code=400,
793+
content={
794+
"code": "InvalidParameter",
795+
"message": f"<400> InternalError.Algo.InvalidParameter: {type_msg_map[err_type]}: {path_str}",
796+
},
797+
)
798+
746799
logger.error(f"Validation Error: {exc.errors()}")
747800

748801
return JSONResponse(
@@ -797,7 +850,6 @@ async def generation(
797850
accept_header = request.headers.get("accept", "")
798851
dashscope_sse = request.headers.get("x-dashscope-sse", "").lower()
799852

800-
# but do not modify the user's incremental_output parameter.
801853
force_stream = False
802854
if (
803855
"text/event-stream" in accept_header or dashscope_sse == "enable"
@@ -806,7 +858,6 @@ async def generation(
806858
f"SSE detected (Accept: {accept_header}, X-DashScope-SSE: {dashscope_sse}), enabling stream transport"
807859
)
808860
force_stream = True
809-
# Note: The line `body.parameters.incremental_output = True` was removed here.
810861

811862
if SERVER_STATE.is_mock_mode:
812863
if body:
@@ -832,7 +883,6 @@ async def generation(
832883
status_code=500, content={"code": "MockError", "message": str(e)}
833884
)
834885

835-
# Pass force_stream to generate
836886
return await proxy.generate(body, request_id, force_stream=force_stream)
837887

838888
@app.post("/siliconflow/models/{model_path:path}")

0 commit comments

Comments
 (0)