Skip to content

Commit 8ab0c0f

Browse files
committed
add
1 parent f4e4b93 commit 8ab0c0f

File tree

1 file changed

+38
-35
lines changed

1 file changed

+38
-35
lines changed

tests/test_doc.py

Lines changed: 38 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,9 @@
1+
# ... (保留上面的 Server 代码不变) ...
2+
3+
# ==========================================
4+
# UPDATED TEST SUITE (Dynamic URL)
5+
# ==========================================
6+
17
import pytest
28
import json
39
import requests
@@ -6,17 +12,17 @@
612
import os
713

814
# --- CONSTANTS & CONFIGURATION ---
9-
GATEWAY_URL = "http://localhost:8000/api/v1/services/aigc/text-generation/generation"
10-
# 如果需要测试硅基流动的真实环境,请切换 URL 并设置 API KEY
11-
# GATEWAY_URL = "https://api-bailian.siliconflow.cn/api/v1/services/aigc/text-generation/generation"
15+
16+
# 修改为动态 URL 的基础路径
17+
GATEWAY_BASE_URL = "http://localhost:8000/siliconflow/models"
1218

1319
API_KEY = os.getenv("SILICONFLOW_API_KEY", "test_api_key")
1420

1521
HEADERS = {
1622
"Authorization": f"Bearer {API_KEY}",
1723
"Content-Type": "application/json",
1824
"Accept": "text/event-stream",
19-
"X-DashScope-SSE": "enable", # 模拟 DashScope 协议头
25+
"X-DashScope-SSE": "enable",
2026
}
2127

2228
# --- TOOL DEFINITIONS ---
@@ -43,17 +49,14 @@
4349

4450
# --- HELPERS ---
4551

46-
4752
@dataclass
4853
class SSEFrame:
4954
"""Formal representation of a Server-Sent Event frame for validation."""
50-
5155
id: str
5256
output: Dict[str, Any]
5357
usage: Dict[str, Any]
5458
request_id: str
5559

56-
5760
def parse_sse_stream(response: requests.Response) -> Generator[SSEFrame, None, None]:
5861
"""Parses the raw SSE stream."""
5962
for line in response.iter_lines():
@@ -64,25 +67,36 @@ def parse_sse_stream(response: requests.Response) -> Generator[SSEFrame, None, N
6467
try:
6568
data = json.loads(json_str)
6669
yield SSEFrame(
67-
id=data.get("output", {})
68-
.get("choices", [{}])[0]
69-
.get("id", "unknown"),
70+
id=data.get("output", {}).get("choices", [{}])[0].get("id", "unknown"),
7071
output=data.get("output", {}),
7172
usage=data.get("usage", {}),
7273
request_id=data.get("request_id", ""),
7374
)
7475
except json.JSONDecodeError:
7576
continue
7677

78+
def make_request(payload: Dict[str, Any]):
79+
"""
80+
Helper to send POST request using the Dynamic Path URL.
81+
Extracts 'model' from payload to construct the URL:
82+
POST /siliconflow/models/{model_path}
83+
"""
84+
# 提取模型名称用于构建 URL
85+
model_path = payload.get("model")
7786

78-
def make_request(payload):
79-
"""Helper to send POST request."""
80-
return requests.post(GATEWAY_URL, headers=HEADERS, json=payload, stream=True)
87+
if not model_path:
88+
raise ValueError("Payload must contain 'model' field for dynamic URL construction")
8189

90+
# 构建动态 URL
91+
# 例如: http://localhost:8000/siliconflow/models/pre-siliconflow/deepseek-v3
92+
url = f"{GATEWAY_BASE_URL}/{model_path}"
8293

83-
# --- TEST SUITE ---
94+
# 发送请求 (Payload 中保留 model 字段通常没问题,服务器端代码会再次处理或忽略)
95+
return requests.post(url, headers=HEADERS, json=payload, stream=True)
8496

8597

98+
# --- TEST SUITE ---
99+
86100
class TestParameterValidation:
87101
"""
88102
对应表格中参数校验相关的错误用例 (4xx Error Codes)
@@ -91,7 +105,7 @@ class TestParameterValidation:
91105
def test_invalid_parameter_type_top_p(self):
92106
"""
93107
Case: parameters.top_p 输入字符串 'a',预期返回 400 InvalidParameter。
94-
Bug描述: 曾返回 InternalError 500。
108+
URL: /siliconflow/models/pre-siliconflow/deepseek-v3
95109
"""
96110
payload = {
97111
"model": "pre-siliconflow/deepseek-v3",
@@ -101,21 +115,18 @@ def test_invalid_parameter_type_top_p(self):
101115
response = make_request(payload)
102116

103117
# 验证状态码不应为 500
104-
assert (
105-
response.status_code != 500
106-
), "Should not return 500 for invalid parameter type"
118+
assert response.status_code != 500, "Should not return 500 for invalid parameter type"
107119
assert response.status_code == 400
108120

109121
data = response.json()
110-
assert "InvalidParameter" in data.get(
111-
"code", ""
112-
) or "InvalidParameter" in data.get("message", "")
122+
assert "InvalidParameter" in data.get("code", "") or "InvalidParameter" in data.get("message", "")
113123

114124
@pytest.mark.parametrize("top_p_value", [0, 0.0])
115125
def test_invalid_parameter_range_top_p(self, top_p_value):
116126
"""
117127
Case: pre-siliconflow-deepseek-v3.1 top_p取值范围 (0, 1.0]。
118128
测试边界值 0,预期报错。
129+
URL: /siliconflow/models/pre-siliconflow/deepseek-v3.1
119130
"""
120131
payload = {
121132
"model": "pre-siliconflow/deepseek-v3.1",
@@ -154,6 +165,7 @@ def test_r1_usage_structure(self):
154165
"""
155166
Case: .usage.output_tokens_details 该路径下不应该返回 text_tokens 字段。
156167
R1 模型推理侧可能没有 text_tokens。
168+
URL: /siliconflow/models/pre-siliconflow/deepseek-r1
157169
"""
158170
payload = {
159171
"model": "pre-siliconflow/deepseek-r1",
@@ -172,9 +184,7 @@ def test_r1_usage_structure(self):
172184
# 验证 output_tokens_details 存在
173185
assert output_details, "output_tokens_details missing"
174186
# 验证不包含 text_tokens (根据表格描述这是预期行为)
175-
assert (
176-
"text_tokens" not in output_details
177-
), "R1 usage should not contain text_tokens"
187+
assert "text_tokens" not in output_details, "R1 usage should not contain text_tokens"
178188
# 验证包含 reasoning_tokens
179189
assert "reasoning_tokens" in output_details
180190

@@ -204,6 +214,7 @@ def test_prefix_completion_thinking_conflict(self):
204214
"""
205215
Case: 思考模式下(enable_thinking=true),不支持前缀续写(partial=true)。
206216
预期返回: 400 InvalidParameter.
217+
URL: /siliconflow/models/pre-siliconflow/deepseek-v3.2
207218
"""
208219
payload = {
209220
"model": "pre-siliconflow/deepseek-v3.2",
@@ -219,9 +230,7 @@ def test_prefix_completion_thinking_conflict(self):
219230

220231
assert response.status_code == 400
221232
data = response.json()
222-
assert "Partial mode is not supported when enable_thinking is true" in data.get(
223-
"message", ""
224-
)
233+
assert "Partial mode is not supported when enable_thinking is true" in data.get("message", "")
225234

226235
def test_history_with_tool_calls(self):
227236
"""
@@ -262,9 +271,7 @@ def test_history_with_tool_calls(self):
262271
response = make_request(payload)
263272

264273
# 核心验证:不能崩 (500)
265-
assert (
266-
response.status_code != 500
267-
), "Server returned 500 for history with tool calls"
274+
assert response.status_code != 500, "Server returned 500 for history with tool calls"
268275
assert response.status_code == 200
269276

270277
def test_r1_tool_call_format_wrapping(self):
@@ -285,18 +292,14 @@ def test_r1_tool_call_format_wrapping(self):
285292
"tool_choice": {
286293
"type": "function",
287294
"function": {"name": "get_current_weather"},
288-
}, # 修正后的 tool_choice 格式
295+
},
289296
"tools": TOOL_VECTOR_WEATHER,
290297
},
291298
}
292299

293-
# 注意:CSV中提到的错误是 `tool_choice` 格式问题导致的 400 被包了一层 500
294-
# 这里我们发送请求并检查状态码
295300
response = make_request(payload)
296301

297-
# 即使失败,也应该返回标准的 400 而不是 InternalError
298302
if response.status_code != 200:
299303
error_data = response.json()
300-
# 确保不是 500 或者 InternalError
301304
assert response.status_code != 500
302305
assert error_data.get("code") != "InternalError"

0 commit comments

Comments
 (0)