
Commit 7966d27

Update mock_server.py
1 parent 9d52d21 commit 7966d27

File tree

1 file changed: +73 −6 lines

tests/mock_server.py

Lines changed: 73 additions & 6 deletions
@@ -37,6 +37,10 @@
     "deepseek-v3.2": "deepseek-ai/DeepSeek-V3.2",
     "deepseek-r1": "deepseek-ai/DeepSeek-R1",
     "default": "deepseek-ai/DeepSeek-V3",
+    "pre-siliconflow/deepseek-v3": "deepseek-ai/DeepSeek-V3",
+    "pre-siliconflow/deepseek-v3.1": "deepseek-ai/DeepSeek-V3.1",
+    "pre-siliconflow/deepseek-v3.2": "deepseek-ai/DeepSeek-V3.2",
+    "pre-siliconflow/deepseek-r1": "deepseek-ai/DeepSeek-R1",
 }
 
 DUMMY_KEY = "dummy-key"
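The new pre-siliconflow/ entries are plain aliases; lookup still goes through _get_mapped_model (next hunk), which falls back to the default entry for unknown names. A minimal sketch of that resolution, using only the keys visible in this diff:

MODEL_MAPPING = {
    "deepseek-r1": "deepseek-ai/DeepSeek-R1",
    "pre-siliconflow/deepseek-r1": "deepseek-ai/DeepSeek-R1",
    "default": "deepseek-ai/DeepSeek-V3",
}

def resolve(request_model: str) -> str:
    # Same pattern as _get_mapped_model: unknown names fall back to the default entry.
    return MODEL_MAPPING.get(request_model, MODEL_MAPPING["default"])

assert resolve("pre-siliconflow/deepseek-r1") == "deepseek-ai/DeepSeek-R1"
assert resolve("some-unlisted-model") == "deepseek-ai/DeepSeek-V3"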
@@ -165,10 +169,11 @@ def __init__(self, api_key: str, extra_headers: Optional[Dict[str, str]] = None)
     def _get_mapped_model(self, request_model: str) -> str:
         return MODEL_MAPPING.get(request_model, MODEL_MAPPING["default"])
 
-    def _convert_input_to_messages(self, input_data: InputData) -> List[Dict[str, str]]:
+    def _convert_input_to_messages(
+        self, input_data: InputData
+    ) -> List[Dict[str, Any]]:
         if input_data.messages:
-            return [m.model_dump() for m in input_data.messages]
-
+            return [m.model_dump(exclude_none=True) for m in input_data.messages]
         messages = []
         if input_data.history:
             for item in input_data.history:
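The switch to model_dump(exclude_none=True), together with the wider Dict[str, Any] return type, keeps unset optional fields out of the forwarded messages instead of emitting them as null. A minimal Pydantic sketch of the difference, using a hypothetical Message model since the real schema is not shown in this hunk:

from typing import Optional
from pydantic import BaseModel

class Message(BaseModel):
    # Hypothetical stand-in for the mock server's message schema.
    role: str
    content: str
    partial: Optional[bool] = None

m = Message(role="user", content="hi")
print(m.model_dump())                   # {'role': 'user', 'content': 'hi', 'partial': None}
print(m.model_dump(exclude_none=True))  # {'role': 'user', 'content': 'hi'}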
@@ -284,6 +289,28 @@ async def generate(
                 },
             )
 
+        # Top P check
+        if params.top_p is not None:
+            if params.top_p <= 0 or params.top_p > 1.0:
+                return JSONResponse(
+                    status_code=400,
+                    content={
+                        "code": "InvalidParameter",
+                        "message": f"<400> InternalError.Algo.InvalidParameter: Range of top_p should be (0.0, 1.0], but got {params.top_p}",
+                    },
+                )
+
+        # Temperature check
+        if params.temperature is not None:
+            if params.temperature < 0 or params.temperature > 2:
+                return JSONResponse(
+                    status_code=400,
+                    content={
+                        "code": "InvalidParameter",
+                        "message": f"<400> InternalError.Algo.InvalidParameter: Temperature should be in [0, 2], but got {params.temperature}",
+                    },
+                )
+
         # Thinking Budget check
         if params.thinking_budget is not None:
             if params.thinking_budget <= 0:
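These checks give top_p an exclusive lower bound and an inclusive upper bound, (0.0, 1.0], while temperature accepts the closed range [0, 2]. A standalone sketch of the same boundary logic, separate from the server's actual handler:

from typing import Optional, Tuple

def validate_sampling(top_p: Optional[float] = None,
                      temperature: Optional[float] = None) -> Tuple[int, str]:
    # Mirrors the ranges enforced above: top_p in (0.0, 1.0], temperature in [0, 2].
    if top_p is not None and (top_p <= 0 or top_p > 1.0):
        return 400, f"Range of top_p should be (0.0, 1.0], but got {top_p}"
    if temperature is not None and (temperature < 0 or temperature > 2):
        return 400, f"Temperature should be in [0, 2], but got {temperature}"
    return 200, "ok"

assert validate_sampling(top_p=0.0)[0] == 400       # lower bound is exclusive
assert validate_sampling(top_p=1.0)[0] == 200       # upper bound is inclusive
assert validate_sampling(temperature=2.0)[0] == 200
assert validate_sampling(temperature=2.1)[0] == 400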
@@ -388,6 +415,26 @@ async def generate(
                 },
             )
 
+        if "deepseek-r1" in req_data.model and params.enable_thinking:
+            return JSONResponse(
+                status_code=400,
+                content={
+                    "code": "InvalidParameter",
+                    "message": "Value error, current model does not support parameter `enable_thinking`.",
+                },
+            )
+
+        if params.enable_thinking:
+            for msg in messages:
+                if msg.get("partial"):
+                    return JSONResponse(
+                        status_code=400,
+                        content={
+                            "code": "InvalidParameter",
+                            "message": "<400> InternalError.Algo.InvalidParameter: Partial mode is not supported when enable_thinking is true",
+                        },
+                    )
+
         # Stop parameter extraction
         proxy_stop_list: List[str] = []
         if params.stop:
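Two guards around enable_thinking: it is rejected for any model name containing "deepseek-r1", and, when enabled, any message carrying a partial flag is rejected as well. A condensed sketch of that decision order, assuming messages is the list of dicts produced by _convert_input_to_messages:

from typing import Dict, List, Optional

def check_thinking(model: str, enable_thinking: bool,
                   messages: List[Dict]) -> Optional[str]:
    # R1-style models reject enable_thinking outright.
    if "deepseek-r1" in model and enable_thinking:
        return "current model does not support parameter `enable_thinking`."
    # With thinking enabled, messages carrying a partial flag are rejected.
    if enable_thinking:
        for msg in messages:
            if msg.get("partial"):
                return "Partial mode is not supported when enable_thinking is true"
    return None

assert check_thinking("pre-siliconflow/deepseek-r1", True, []) is not None
assert check_thinking("deepseek-v3.1", True, [{"role": "assistant", "content": "Hi", "partial": True}]) is not None
assert check_thinking("deepseek-v3.1", True, [{"role": "user", "content": "Hi"}]) is None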
@@ -526,6 +573,7 @@ async def generate(
 
         # --- Execution ---
         try:
+            is_r1_model = "deepseek-r1" in req_data.model
             if openai_params["stream"]:
                 raw_resp = await self.client.chat.completions.with_raw_response.create(
                     **openai_params
@@ -540,6 +588,7 @@ async def generate(
                         trace_id,
                         is_incremental=params.incremental_output,
                         stop_sequences=proxy_stop_list,
+                        is_r1_model=is_r1_model,
                     ),
                     media_type="text/event-stream",
                     headers={"X-SiliconCloud-Trace-Id": trace_id},
@@ -552,7 +601,10 @@ async def generate(
                     "X-SiliconCloud-Trace-Id", initial_request_id
                 )
                 return self._format_unary_response(
-                    raw_resp.parse(), trace_id, stop_sequences=proxy_stop_list
+                    raw_resp.parse(),
+                    trace_id,
+                    stop_sequences=proxy_stop_list,
+                    is_r1_model=is_r1_model,
                 )
 
         except APIError as e:
@@ -575,7 +627,12 @@ async def generate(
             )
 
     async def _stream_generator(
-        self, stream, request_id: str, is_incremental: bool, stop_sequences: List[str]
+        self,
+        stream,
+        request_id: str,
+        is_incremental: bool,
+        stop_sequences: List[str],
+        is_r1_model: bool = False,
     ) -> AsyncGenerator[str, None]:
         accumulated_usage = {
             "total_tokens": 0,
@@ -612,6 +669,9 @@ async def _stream_generator(
                     - accumulated_usage["output_tokens_details"]["reasoning_tokens"]
                 )
 
+            if is_r1_model:
+                accumulated_usage["output_tokens_details"].pop("text_tokens", None)
+
             delta = chunk.choices[0].delta if chunk.choices else None
 
             # --- Reasoning Content Handling ---
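For R1 requests the text_tokens detail is dropped from the accumulated usage, and the same trimming is applied to the unary path further below. dict.pop(key, None) is a no-op when the key is absent, so the call is safe either way. A minimal sketch, with the usage shape assumed from the field names in this diff:

accumulated_usage = {
    "total_tokens": 42,
    "output_tokens_details": {"reasoning_tokens": 30, "text_tokens": 12},
}

is_r1_model = True
if is_r1_model:
    # Drop text_tokens for R1; pop(..., None) avoids a KeyError if it was never set.
    accumulated_usage["output_tokens_details"].pop("text_tokens", None)

print(accumulated_usage["output_tokens_details"])  # {'reasoning_tokens': 30}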
@@ -827,7 +887,11 @@ def _build_stream_response(
         return f"data: {json.dumps(response_body, ensure_ascii=False)}\n\n"
 
     def _format_unary_response(
-        self, completion, request_id: str, stop_sequences: List[str]
+        self,
+        completion,
+        request_id: str,
+        stop_sequences: List[str],
+        is_r1_model: bool = False,
     ):
         choice = completion.choices[0]
         msg = choice.message
@@ -861,6 +925,9 @@ def _format_unary_response(
             - usage_data["output_tokens_details"]["reasoning_tokens"]
         )
 
+        if is_r1_model:
+            usage_data["output_tokens_details"].pop("text_tokens", None)
+
         message_body = {
             "role": msg.role,
             "content": final_content,  # Uses the potentially truncated content
