
Commit d9f7e21

Update mock_server.py
1 parent fce5dc0 commit d9f7e21

1 file changed: +13 -8 lines changed

tests/mock_server.py

Lines changed: 13 additions & 8 deletions
@@ -13,7 +13,7 @@
 from fastapi.responses import StreamingResponse, JSONResponse
 from fastapi.middleware.cors import CORSMiddleware
 from pydantic import BaseModel, Field
-from openai import OpenAI, APIError, RateLimitError, AuthenticationError
+from openai import AsyncOpenAI, APIError, RateLimitError, AuthenticationError
 
 # --- [System Configuration] ---
 
@@ -112,8 +112,8 @@ class GenerationRequest(BaseModel):
 
 class DeepSeekProxy:
     def __init__(self):
-        self.client = OpenAI(
-            api_key=SILICON_FLOW_API_KEY if SILICON_FLOW_API_KEY else None,
+        self.client = AsyncOpenAI(
+            api_key=SILICON_FLOW_API_KEY if SILICON_FLOW_API_KEY else "dummy_key",
             base_url=SILICON_FLOW_BASE_URL
         )
 
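The two hunks above swap the blocking OpenAI client for AsyncOpenAI. For reference, a minimal self-contained sketch of the async client pattern; in mock_server.py the SILICON_FLOW_* values are module constants, so the environment-variable lookups and the model name below are assumptions made only for this sketch:

import asyncio
import os

from openai import AsyncOpenAI

async def main() -> None:
    # Env vars stand in for the module constants used in mock_server.py.
    client = AsyncOpenAI(
        api_key=os.getenv("SILICON_FLOW_API_KEY", "dummy_key"),
        base_url=os.getenv("SILICON_FLOW_BASE_URL", "https://example.invalid/v1"),  # placeholder URL
    )
    # Unlike the sync OpenAI client, every request method on AsyncOpenAI returns a coroutine.
    resp = await client.chat.completions.create(
        model="deepseek-chat",  # illustrative model name
        messages=[{"role": "user", "content": "ping"}],
    )
    print(resp.choices[0].message.content)

if __name__ == "__main__":
    asyncio.run(main())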
@@ -174,15 +174,17 @@ async def generate(self, req_data: GenerationRequest, initial_request_id: str):
 
         try:
             if openai_params["stream"]:
-                # Fetch raw response for headers in stream mode
-                raw_resp = self.client.chat.completions.with_raw_response.create(**openai_params)
+                # Fetch raw response for headers in stream mode (awaited)
+                raw_resp = await self.client.chat.completions.with_raw_response.create(**openai_params)
                 trace_id = raw_resp.headers.get("X-SiliconCloud-Trace-Id", initial_request_id)
+                # raw_resp.parse() returns the AsyncStream
                 return StreamingResponse(
                     self._stream_generator(raw_resp.parse(), trace_id),
                     media_type="text/event-stream"
                 )
             else:
-                raw_resp = self.client.chat.completions.with_raw_response.create(**openai_params)
+                # Standard response (awaited)
+                raw_resp = await self.client.chat.completions.with_raw_response.create(**openai_params)
                 trace_id = raw_resp.headers.get("X-SiliconCloud-Trace-Id", initial_request_id)
                 return self._format_unary_response(raw_resp.parse(), trace_id)
 
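The awaited with_raw_response calls above expose the upstream HTTP headers before the body is parsed. A minimal sketch of that pattern on the async client; the header name is taken from the diff, while the client setup and model name are illustrative assumptions:

import asyncio
import os

from openai import AsyncOpenAI

async def call_with_headers() -> None:
    client = AsyncOpenAI(
        api_key=os.getenv("SILICON_FLOW_API_KEY", "dummy_key"),
        base_url=os.getenv("SILICON_FLOW_BASE_URL", "https://example.invalid/v1"),  # placeholder URL
    )
    # On AsyncOpenAI, with_raw_response.create(...) is a coroutine; awaiting it yields the
    # raw HTTP response, whose headers can be read before parsing the body.
    raw = await client.chat.completions.with_raw_response.create(
        model="deepseek-chat",  # illustrative model name
        messages=[{"role": "user", "content": "ping"}],
    )
    trace_id = raw.headers.get("X-SiliconCloud-Trace-Id", "no-trace-id")
    completion = raw.parse()  # parse() returns the typed ChatCompletion object
    print(trace_id, completion.choices[0].message.content)

if __name__ == "__main__":
    asyncio.run(call_with_headers())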
@@ -204,7 +206,7 @@ async def _stream_generator(self, stream, request_id: str) -> AsyncGenerator[str
         }
         finish_reason = "null"
 
-        for chunk in stream:
+        async for chunk in stream:
             if chunk.usage:
                 accumulated_usage["total_tokens"] = chunk.usage.total_tokens
                 accumulated_usage["input_tokens"] = chunk.usage.prompt_tokens
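Switching to "async for" is needed because the parsed stream from the async client is an async iterator, not a regular one. A stripped-down sketch of the same shape, feeding an async generator into StreamingResponse; the fake chunk source and the SSE payload fields are illustrative stand-ins, not the mock server's own:

import asyncio
import json
from typing import AsyncGenerator

from fastapi import FastAPI
from fastapi.responses import StreamingResponse

app = FastAPI()

async def fake_chunks() -> AsyncGenerator[str, None]:
    # Stand-in for the AsyncStream returned by raw_resp.parse(); yields a few text deltas.
    for piece in ("Hello", ", ", "world"):
        await asyncio.sleep(0.05)
        yield piece

async def sse_events(stream) -> AsyncGenerator[str, None]:
    # The stream must be consumed with `async for`, mirroring the change from
    # `for chunk in stream` to `async for chunk in stream` above.
    async for delta in stream:
        yield f"data: {json.dumps({'text': delta})}\n\n"
    yield "data: [DONE]\n\n"

@app.get("/demo-stream")
async def demo_stream() -> StreamingResponse:
    return StreamingResponse(sse_events(fake_chunks()), media_type="text/event-stream")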
@@ -289,7 +291,6 @@ def create_app() -> FastAPI:
     app.add_middleware(
         CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"]
     )
-    proxy = DeepSeekProxy()
 
     @app.middleware("http")
     async def request_tracker(request: Request, call_next):
@@ -319,6 +320,9 @@ async def health_check():
     async def generation(request: Request, body: GenerationRequest = None):
         request_id = request.headers.get("x-request-id", str(uuid.uuid4()))
 
+        # Instantiate Proxy Per Request
+        proxy = DeepSeekProxy()
+
         if not body:
             try:
                 raw_json = await request.json()
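Together, the previous two hunks replace the app-level proxy singleton with a per-request instance created inside the handler. A minimal sketch of that shape under the same assumptions as above (the proxy is reduced to its constructor and the route body is illustrative):

import os
import uuid

from fastapi import FastAPI, Request
from openai import AsyncOpenAI

app = FastAPI()

class DeepSeekProxy:
    def __init__(self) -> None:
        # Same constructor shape as the diff; env vars stand in for the module constants.
        self.client = AsyncOpenAI(
            api_key=os.getenv("SILICON_FLOW_API_KEY", "dummy_key"),
            base_url=os.getenv("SILICON_FLOW_BASE_URL", "https://example.invalid/v1"),
        )

@app.post("/generation")
async def generation(request: Request):
    request_id = request.headers.get("x-request-id", str(uuid.uuid4()))
    # Instantiated per request, as in the diff, rather than shared from create_app().
    proxy = DeepSeekProxy()
    return {"request_id": request_id, "client_ready": proxy.client is not None}

One trade-off of this shape: each request builds a fresh AsyncOpenAI instance and therefore a fresh HTTP connection pool, which is usually acceptable for a test mock server but would be worth revisiting for a production proxy.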
@@ -331,6 +335,7 @@ async def generation(request: Request, body: GenerationRequest = None):
         if body:
             logger.info(f"[Shadow] Validating request against upstream...")
             try:
+                # Async generate call on the local instance
                 await proxy.generate(body, f"shadow-{request_id}")
             except Exception as e:
                 logger.error(f"[Shadow] Validation Exception: {str(e)}")

0 commit comments
