Skip to content

Commit cac9ef7

Browse files
committed
fix
1 parent 5a47396 commit cac9ef7

File tree

2 files changed

+38
-154
lines changed

2 files changed

+38
-154
lines changed

tests/test_doc.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
# --- CONFIGURATION ---
88
BASE_URL_PREFIX = "http://localhost:8000/siliconflow/models"
9+
BASE_URL_PREFIX = "https://api-bailian.siliconflow.cn/siliconflow/models"
910
API_KEY = os.getenv("SILICONFLOW_API_KEY", "test_api_key")
1011

1112
HEADERS = {

tests/test_tools.py

Lines changed: 37 additions & 154 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,16 @@
11
import pytest
22
import json
33
import requests
4-
from typing import Generator, List, Dict, Any
5-
from dataclasses import dataclass
64
import os
5+
from typing import Generator, List, Dict, Any, Optional
6+
from dataclasses import dataclass
7+
8+
# --- 1. CONFIGURATION & CONSTANTS ---
79

8-
# --- CONSTANTS & CONFIGURATION ---
9-
# Note: Ensure your MockServer/Proxy is running on this port
1010
# GATEWAY_URL = "https://api-bailian.siliconflow.cn/api/v1/services/aigc/text-generation/generation"
1111
GATEWAY_URL = "http://localhost:8000/api/v1/services/aigc/text-generation/generation"
12-
# GATEWAY_URL = "http://localhost:8000/siliconflow/models/deepseek-ai/DeepSeek-R1"
13-
# GATEWAY_URL = "https://api-bailian.siliconflow.cn/siliconflow/models/deepseek-ai/DeepSeek-R1"
1412
API_KEY = os.getenv("SILICONFLOW_API_KEY", "test_api_key")
1513

16-
# Define the Tool Schema Vector
1714
TOOL_VECTOR_WEATHER = [
1815
{
1916
"type": "function",
@@ -35,16 +32,31 @@
3532
}
3633
]
3734

35+
# --- 2. CORE UTILITIES & DATA STRUCTURES ---
3836

3937
@dataclass
4038
class SSEFrame:
4139
"""Formal representation of a Server-Sent Event frame for validation."""
42-
4340
id: str
4441
output: Dict[str, Any]
4542
usage: Dict[str, Any]
4643
request_id: str
4744

45+
@property
46+
def text_content(self) -> str:
47+
"""Helper to safely extract standard content."""
48+
choices = self.output.get("choices", [])
49+
if not choices:
50+
return ""
51+
return choices[0].get("message", {}).get("content", "")
52+
53+
@property
54+
def reasoning_content(self) -> str:
55+
"""Helper to safely extract reasoning content (for R1 models)."""
56+
choices = self.output.get("choices", [])
57+
if not choices:
58+
return ""
59+
return choices[0].get("message", {}).get("reasoning_content", "")
4860

4961
def parse_sse_stream(response: requests.Response) -> Generator[SSEFrame, None, None]:
5062
"""
@@ -58,27 +70,25 @@ def parse_sse_stream(response: requests.Response) -> Generator[SSEFrame, None, N
5870
json_str = decoded_line[5:].strip()
5971
try:
6072
data = json.loads(json_str)
73+
# Handle cases where usage might be missing in some frames if strictly required
74+
usage_data = data.get("usage", {})
75+
6176
yield SSEFrame(
62-
id=data.get("output", {})
63-
.get("choices", [{}])[0]
64-
.get("id", "unknown"),
77+
id=data.get("output", {}).get("choices", [{}])[0].get("id", "unknown"),
6578
output=data.get("output", {}),
66-
usage=data.get("usage", {}),
79+
usage=usage_data,
6780
request_id=data.get("request_id", ""),
6881
)
6982
except json.JSONDecodeError:
7083
continue
7184

72-
7385
# --- SUITE A: INVARIANT & PREDICATE VERIFICATION ---
7486

75-
7687
def test_invariant_format_constraint():
7788
"""
7889
Predicate A: If P_tools is not empty, P_result_format must be 'message'.
7990
"""
8091
headers = {"Authorization": f"Bearer {API_KEY}", "Content-Type": "application/json"}
81-
8292
payload = {
8393
"model": "deepseek-v3",
8494
"input": {"messages": [{"role": "user", "content": "What's the weather?"}]},
@@ -87,26 +97,21 @@ def test_invariant_format_constraint():
8797
"result_format": "text", # <--- INTENTIONAL VIOLATION
8898
},
8999
}
90-
91100
response = requests.post(GATEWAY_URL, headers=headers, json=payload)
92101

93102
assert response.status_code == 400
94103
error_data = response.json()
95104
assert "code" in error_data
96105
assert "result_format" in str(error_data).lower()
97106

98-
99107
def test_invariant_r1_orthogonality():
100108
"""
101109
Predicate B: DeepSeek R1 'Thinking Mode' is orthogonal to 'Forced Tool Choice (Dict)'.
102110
"""
103111
headers = {"Authorization": f"Bearer {API_KEY}", "Content-Type": "application/json"}
104-
105112
payload = {
106113
"model": "deepseek-r1",
107-
"input": {
108-
"messages": [{"role": "user", "content": "Analyze the weather logic."}]
109-
},
114+
"input": {"messages": [{"role": "user", "content": "Analyze the weather logic."}]},
110115
"parameters": {
111116
"enable_thinking": True,
112117
"tools": TOOL_VECTOR_WEATHER,
@@ -116,16 +121,13 @@ def test_invariant_r1_orthogonality():
116121
},
117122
},
118123
}
119-
120124
response = requests.post(GATEWAY_URL, headers=headers, json=payload)
121125

122126
assert response.status_code == 400
123127
assert "InvalidParameter" in response.json().get("code", "")
124128

125-
126129
# --- SUITE B: PROTOCOL ISOMORPHISM (SSE TELEMETRY) ---
127130

128-
129131
def test_telemetry_continuity_sse():
130132
"""
131133
Theorem: The 'usage' object must be persisted in EVERY SSE frame.
@@ -135,7 +137,6 @@ def test_telemetry_continuity_sse():
135137
"Content-Type": "application/json",
136138
"X-DashScope-SSE": "enable",
137139
}
138-
139140
payload = {
140141
"model": "deepseek-v3",
141142
"input": {"messages": [{"role": "user", "content": "Call the tool."}]},
@@ -157,14 +158,11 @@ def test_telemetry_continuity_sse():
157158

158159
assert frame_count > 0
159160

160-
161161
# --- SUITE C: TOOL INVOCATION & CONFIGURATION TESTS ---
162162

163-
164163
def test_unary_tool_invocation_structure():
165164
"""
166-
Validates that standard unary (non-streaming) responses maintain proper tool structures
167-
when tools are enabled but not necessarily forced.
165+
Validates standard unary responses maintain tool structures when tools enabled.
168166
"""
169167
headers = {"Authorization": f"Bearer {API_KEY}", "Content-Type": "application/json"}
170168
payload = {
@@ -185,16 +183,13 @@ def test_unary_tool_invocation_structure():
185183
assert "choices" in data["output"]
186184
assert len(data["output"]["choices"]) > 0
187185

188-
# Check structure integrity
189186
choice = data["output"]["choices"][0]
190187
assert "message" in choice
191188
assert "usage" in data
192189

193-
194190
def test_tool_choice_none_suppression():
195191
"""
196-
Validates that tool_choice='none' is accepted by the gateway and processed without error.
197-
This ensures the explicit suppression logic path is valid.
192+
Validates that tool_choice='none' is accepted and processes without error.
198193
"""
199194
headers = {"Authorization": f"Bearer {API_KEY}", "Content-Type": "application/json"}
200195
payload = {
@@ -209,120 +204,31 @@ def test_tool_choice_none_suppression():
209204
response = requests.post(GATEWAY_URL, headers=headers, json=payload)
210205
assert response.status_code == 200
211206

212-
# Even if mocked, the structure must be valid
213207
data = response.json()
214208
assert data["output"]["choices"][0]["finish_reason"] is not None
215209

216-
217-
import pytest
218-
import json
219-
import requests
220-
from typing import Generator, List, Dict, Any, Optional
221-
from dataclasses import dataclass
222-
import os
223-
224-
# --- CONSTANTS & CONFIGURATION ---
225-
GATEWAY_URL = "http://localhost:8000/api/v1/services/aigc/text-generation/generation"
226-
API_KEY = os.getenv("SILICONFLOW_API_KEY", "test_api_key")
227-
228-
TOOL_VECTOR_WEATHER = [
229-
{
230-
"type": "function",
231-
"function": {
232-
"name": "get_current_weather",
233-
"description": "Get the current weather in a given location",
234-
"parameters": {
235-
"type": "object",
236-
"properties": {
237-
"location": {"type": "string"},
238-
"unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
239-
},
240-
"required": ["location"],
241-
},
242-
},
243-
}
244-
]
245-
246-
247-
@dataclass
248-
class SSEFrame:
249-
id: str
250-
output: Dict[str, Any]
251-
usage: Dict[str, Any]
252-
request_id: str
253-
254-
@property
255-
def text_content(self) -> str:
256-
"""Helper to safely extract standard content."""
257-
choices = self.output.get("choices", [])
258-
if not choices:
259-
return ""
260-
return choices[0].get("message", {}).get("content", "")
261-
262-
@property
263-
def reasoning_content(self) -> str:
264-
"""Helper to safely extract reasoning content (for R1 models)."""
265-
choices = self.output.get("choices", [])
266-
if not choices:
267-
return ""
268-
return choices[0].get("message", {}).get("reasoning_content", "")
269-
270-
271-
def parse_sse_stream(response: requests.Response) -> Generator[SSEFrame, None, None]:
272-
for line in response.iter_lines():
273-
if line:
274-
decoded_line = line.decode("utf-8")
275-
if decoded_line.startswith("data:"):
276-
json_str = decoded_line[5:].strip()
277-
try:
278-
data = json.loads(json_str)
279-
yield SSEFrame(
280-
id=data.get("output", {})
281-
.get("choices", [{}])[0]
282-
.get("id", "unknown"),
283-
output=data.get("output", {}),
284-
usage=data.get("usage", {}),
285-
request_id=data.get("request_id", ""),
286-
)
287-
except json.JSONDecodeError:
288-
continue
289-
290-
291-
# --- EXISTING SUITES (A, B, C) OMITTED FOR BREVITY ---
292-
# (Assume Suite A, B, and C from your original code are here)
293-
294210
# --- SUITE D: INCREMENTAL OUTPUT BEHAVIOR ---
295211

296-
297212
def assert_stream_accumulation(frames: List[SSEFrame], check_reasoning: bool = False):
298213
"""
299214
Validates 'Accumulated' behavior (incremental_output=False).
300215
Theorem: For any frame N > 0, Content(N) must start with Content(N-1).
301216
"""
302217
previous_content = ""
303-
304-
# Filter out empty initial frames if necessary, though accumulation should start immediately
305218
for i, frame in enumerate(frames):
306-
# Select content type based on what the model is outputting (Reasoning vs Standard)
307-
current_content = (
308-
frame.reasoning_content if check_reasoning else frame.text_content
309-
)
219+
current_content = frame.reasoning_content if check_reasoning else frame.text_content
310220

311-
# Skip empty frames (sometimes initial frames are just metadata)
221+
# Skip empty frames
312222
if not current_content and not previous_content:
313223
continue
314224

315-
# Assertion: The new content must contain the previous content as a prefix
316-
# This proves the server is sending the full history every time.
317225
assert current_content.startswith(previous_content), (
318226
f"Frame {i} violation: Output is not accumulated.\n"
319227
f"Previous: {previous_content!r}\n"
320228
f"Current: {current_content!r}"
321229
)
322-
323230
previous_content = current_content
324231

325-
326232
def assert_stream_deltas(frames: List[SSEFrame], check_reasoning: bool = False):
327233
"""
328234
Validates 'Delta' behavior (incremental_output=True).
@@ -332,31 +238,20 @@ def assert_stream_deltas(frames: List[SSEFrame], check_reasoning: bool = False):
332238
previous_content = ""
333239

334240
for i, frame in enumerate(frames):
335-
current_content = (
336-
frame.reasoning_content if check_reasoning else frame.text_content
337-
)
338-
241+
current_content = frame.reasoning_content if check_reasoning else frame.text_content
339242
if not current_content:
340243
continue
341244

342-
# If current content strictly starts with previous content AND adds to it,
343-
# we might be in accumulation mode.
344-
# Note: We need a heuristic because "The" -> "The cat" is technically accumulation,
345-
# but in Delta mode it should be "The" -> " cat".
346-
if (
347-
previous_content
348-
and current_content.startswith(previous_content)
349-
and len(current_content) > len(previous_content)
350-
):
245+
# Heuristic: If content strictly grows and contains previous, it's likely accumulation
246+
if (previous_content and
247+
current_content.startswith(previous_content) and
248+
len(current_content) > len(previous_content)):
351249
accumulation_detected = True
352250
break
353251

354252
previous_content = current_content
355253

356-
assert (
357-
not accumulation_detected
358-
), "Stream appears to be accumulating full text, expected Deltas (chunks) only."
359-
254+
assert not accumulation_detected, "Stream appears to be accumulating full text, expected Deltas."
360255

361256
def test_incremental_output_false_explicit():
362257
"""
@@ -369,7 +264,7 @@ def test_incremental_output_false_explicit():
369264
"X-DashScope-SSE": "enable",
370265
}
371266
payload = {
372-
"model": "deepseek-r1", # Using R1 to check reasoning accumulation
267+
"model": "deepseek-r1",
373268
"input": {"messages": [{"role": "user", "content": "Count to 5"}]},
374269
"parameters": {
375270
"result_format": "message",
@@ -379,14 +274,10 @@ def test_incremental_output_false_explicit():
379274

380275
response = requests.post(GATEWAY_URL, headers=headers, json=payload, stream=True)
381276
assert response.status_code == 200
382-
383277
frames = list(parse_sse_stream(response))
384278
assert len(frames) > 0
385-
386-
# Verify accumulation on reasoning_content (since R1 outputs reasoning first)
387279
assert_stream_accumulation(frames, check_reasoning=True)
388280

389-
390281
def test_incremental_output_default_behavior():
391282
"""
392283
Case 2: incremental_output param is MISSING.
@@ -408,19 +299,14 @@ def test_incremental_output_default_behavior():
408299

409300
response = requests.post(GATEWAY_URL, headers=headers, json=payload, stream=True)
410301
assert response.status_code == 200
411-
412302
frames = list(parse_sse_stream(response))
413303
assert len(frames) > 0
414-
415-
# Verify the default behavior matches accumulation
416304
assert_stream_accumulation(frames, check_reasoning=True)
417305

418-
419306
def test_incremental_output_true_contrast():
420307
"""
421308
Case 3: Explicitly set incremental_output=True.
422309
Expectation: The response contains only DELTAS (chunks).
423-
Used to ensure the switch actually works.
424310
"""
425311
headers = {
426312
"Authorization": f"Bearer {API_KEY}",
@@ -438,9 +324,6 @@ def test_incremental_output_true_contrast():
438324

439325
response = requests.post(GATEWAY_URL, headers=headers, json=payload, stream=True)
440326
assert response.status_code == 200
441-
442327
frames = list(parse_sse_stream(response))
443328
assert len(frames) > 0
444-
445-
# Verify we are receiving deltas, not full text
446329
assert_stream_deltas(frames, check_reasoning=True)

0 commit comments

Comments
 (0)