Skip to content

Commit 762dc9d

Browse files
committed
Update mock_server.py
1 parent fa63343 commit 762dc9d

File tree

1 file changed

+40
-16
lines changed

1 file changed

+40
-16
lines changed

tests/mock_server.py

Lines changed: 40 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,6 @@ class Parameters(BaseModel):
118118
presence_penalty: Optional[float] = 0.0
119119
repetition_penalty: Optional[float] = 1.0
120120

121-
# Supported by Pydantic parsing but logic will reject them to pass tests
122121
logprobs: Optional[bool] = None
123122
top_logprobs: Optional[int] = None
124123

@@ -198,7 +197,6 @@ async def generate(
198197
has_input = req_data.input is not None
199198
has_content = False
200199
if has_input:
201-
# Strict mutual exclusion check
202200
if (
203201
req_data.input.messages is not None
204202
and req_data.input.prompt is not None
@@ -231,7 +229,6 @@ async def generate(
231229

232230
params = req_data.parameters
233231

234-
# Logprobs not supported
235232
if params.logprobs:
236233
return JSONResponse(
237234
status_code=400,
@@ -241,7 +238,6 @@ async def generate(
241238
},
242239
)
243240

244-
# max_tokens range check
245241
if params.max_tokens is not None and params.max_tokens < 1:
246242
return JSONResponse(
247243
status_code=400,
@@ -398,7 +394,6 @@ async def generate(
398394
"message": "<400> InternalError.Algo.InvalidParameter: Repetition_penalty should be greater than 0.0",
399395
},
400396
)
401-
# Limit repetition_penalty to (0, 2] range, clamp values that exceed 2
402397
clamped_penalty = min(params.repetition_penalty, 2.0)
403398
if clamped_penalty != params.repetition_penalty:
404399
logger.warning(
@@ -590,13 +585,11 @@ async def _stream_generator(
590585
current_tool_calls_payload = None
591586

592587
if delta and delta.tool_calls:
593-
# 只有 is_incremental=True 才直接发送 delta,否则我们发送完整列表
594588
if is_incremental:
595589
current_tool_calls_payload = [
596590
tc.model_dump() for tc in delta.tool_calls
597591
]
598592

599-
# 始终进行聚合,以备 incremental_output=False 使用
600593
for tc in delta.tool_calls:
601594
idx = tc.index
602595
if idx not in accumulated_tool_calls:
@@ -610,15 +603,13 @@ async def _stream_generator(
610603
},
611604
}
612605
else:
613-
# 合并字段
614606
if tc.id:
615607
accumulated_tool_calls[idx]["id"] = tc.id
616608
if tc.function.name:
617609
accumulated_tool_calls[idx]["function"][
618610
"name"
619611
] = tc.function.name
620612

621-
# 拼接参数
622613
if tc.function.arguments:
623614
accumulated_tool_calls[idx]["function"][
624615
"arguments"
@@ -630,14 +621,11 @@ async def _stream_generator(
630621
if is_incremental:
631622
content_to_send = delta_content
632623
reasoning_to_send = delta_reasoning
633-
# 如果是增量模式,使用上面提取的 delta payload
634624
final_tool_calls = current_tool_calls_payload
635625
else:
636626
content_to_send = full_text
637627
reasoning_to_send = full_reasoning
638-
# 如果是非增量模式,发送当前聚合的所有工具调用的列表
639628
if accumulated_tool_calls:
640-
# 将字典转回列表并按 index 排序
641629
final_tool_calls = sorted(
642630
accumulated_tool_calls.values(), key=lambda x: x["index"]
643631
)
@@ -794,6 +782,45 @@ async def validation_exception_handler(request, exc):
794782
path_str = ".".join([str(x) for x in loc if x != "body"])
795783
err_type = err.get("type")
796784

785+
input_value = err.get("input")
786+
787+
if err_type == "int_parsing":
788+
if isinstance(input_value, str):
789+
if param_name in ["max_tokens", "max_length"]:
790+
return JSONResponse(
791+
status_code=400,
792+
content={
793+
"code": "InvalidParameter",
794+
"message": f"<400> InternalError.Algo.InvalidParameter: Input should be a valid integer, unable to parse string as an integer: {path_str}",
795+
},
796+
)
797+
else:
798+
return JSONResponse(
799+
status_code=400,
800+
content={
801+
"code": "InvalidParameter",
802+
"message": f"<400> InternalError.Algo.InvalidParameter: Input should be a valid integer: {path_str}",
803+
},
804+
)
805+
806+
if err_type == "int_from_float":
807+
if param_name in ["max_tokens", "max_length"]:
808+
return JSONResponse(
809+
status_code=400,
810+
content={
811+
"code": "InvalidParameter",
812+
"message": f"<400> InternalError.Algo.InvalidParameter: Input should be a valid integer, got a number with a fractional part: {path_str}",
813+
},
814+
)
815+
else:
816+
return JSONResponse(
817+
status_code=400,
818+
content={
819+
"code": "InvalidParameter",
820+
"message": f"<400> InternalError.Algo.InvalidParameter: Input should be a valid integer: {path_str}",
821+
},
822+
)
823+
797824
if "stop" in loc:
798825
return JSONResponse(
799826
status_code=400,
@@ -812,13 +839,12 @@ async def validation_exception_handler(request, exc):
812839
},
813840
)
814841

815-
# Catch generic typing errors for content (including lists passed to str)
816842
if param_name == "content" or (len(loc) > 1 and loc[-2] == "content"):
817843
if (
818844
"valid string" in error_msg
819845
or "str" in error_msg
820846
or err_type == "string_type"
821-
or err_type == "list_type" # Added list_type for robustness
847+
or err_type == "list_type"
822848
or err_type == "dict_type"
823849
):
824850
return JSONResponse(
@@ -839,8 +865,6 @@ async def validation_exception_handler(request, exc):
839865
)
840866

841867
type_msg_map = {
842-
"int_parsing": "Input should be a valid integer",
843-
"int_from_float": "Input should be a valid integer",
844868
"float_parsing": "Input should be a valid number, unable to parse string as a number",
845869
"bool_parsing": "Input should be a valid boolean, unable to interpret input",
846870
"string_type": "Input should be a valid string",

0 commit comments

Comments
 (0)