
Commit 9f44128 (0 parents)

Initial commit

File tree

6 files changed: +518 -0 lines changed

.gitignore

Lines changed: 10 additions & 0 deletions

```gitignore
# Python-generated files
__pycache__/
*.py[oc]
build/
dist/
wheels/
*.egg-info

# Virtual environments
.venv
```

.python-version

Lines changed: 1 addition & 0 deletions

```
3.12
```

README.md

Whitespace-only changes.

openai_router/main.py

Lines changed: 184 additions & 0 deletions

```python
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import Response, StreamingResponse
import httpx
from loguru import logger
import os
from contextlib import asynccontextmanager


backend_servers = {}
# A reusable httpx client, created on startup and closed on shutdown.
client: httpx.AsyncClient | None = None


@asynccontextmanager
async def lifespan(app: FastAPI):
    global client
    # Startup: parse the MODELS environment variable into a name -> URL map.
    models = os.environ.get("MODELS", "").split(",")
    for model in models:
        model = model.strip()
        if model:
            try:
                # Split only on the first "=", since URLs may contain "=".
                model_name, model_url = model.split("=", 1)
                backend_servers[model_name] = model_url
            except ValueError:
                logger.warning(f"Skipping misformatted model entry: {model}")

    logger.info(f"Backend servers: {backend_servers}")

    # Initialize the httpx client with reasonable timeouts.
    # read=None disables the read timeout, which is necessary for
    # long-lived streaming responses.
    timeout = httpx.Timeout(10.0, connect=60.0, read=None, write=60.0)
    client = httpx.AsyncClient(timeout=timeout)

    yield  # Application runs here.

    # Shutdown: close the shared client.
    if client:
        await client.aclose()


# Create the FastAPI instance with the lifespan handler.
app = FastAPI(lifespan=lifespan)


async def _get_routing_info(request: Request):
    """
    Helper: parse the request body to determine the model and the
    target backend URL. This is required for the routing logic.
    """
    try:
        json_body = await request.json()
    except Exception as e:
        logger.error(f"Failed to parse request body: {e}")
        raise HTTPException(status_code=400, detail="Invalid JSON body")

    model = json_body.get("model")
    if model is None:
        raise HTTPException(
            status_code=400, detail="'model' field is required in request body"
        )

    server = backend_servers.get(model)
    if server is None:
        raise HTTPException(
            status_code=400,
            detail=f"Invalid model: {model}. Available models: {list(backend_servers.keys())}",
        )

    backend_url = server + request.url.path
    logger.info(f"Routing to backend_url: {backend_url} for model {model}")

    return backend_url, json_body


async def _stream_proxy(backend_url: str, request: Request, json_body: dict):
    """
    Async generator that proxies a streaming response from the backend.
    """
    # Prepare the headers to forward to the backend.
    # Drop 'host' and 'content-length'; httpx recomputes them.
    headers = {
        h: v
        for h, v in request.headers.items()
        if h.lower() not in ["host", "content-length"]
    }

    try:
        # Open a streaming request against the backend with client.stream.
        async with client.stream(
            request.method,
            backend_url,
            params=request.query_params,
            json=json_body,  # The body was already consumed, so pass it as json.
            headers=headers,
        ) as response:

            # Check the backend's response for errors before streaming.
            # The status code cannot be changed once StreamingResponse has
            # started sending, so on error we choose not to stream at all.
            if response.status_code >= 400:
                # Read the backend's error message and raise it as an HTTP exception.
                error_content = await response.aread()
                logger.warning(
                    f"Backend error: {response.status_code} - {error_content.decode()}"
                )
                raise HTTPException(
                    status_code=response.status_code, detail=error_content.decode()
                )

            # Iterate over the backend's streamed chunks
            # and yield each one through to FastAPI.
            async for chunk in response.aiter_bytes():
                yield chunk

    except httpx.ConnectError as e:
        logger.error(f"Connection error to backend {backend_url}: {e}")
        raise HTTPException(status_code=503, detail="Backend service unavailable")
    except Exception as e:
        logger.error(f"An error occurred during streaming proxy: {e}")
        # A partial response may already have been sent, so we can no longer
        # raise an HTTPException; log the error and stop.
        logger.error("Streaming interrupted due to an error.")


async def _non_stream_proxy(backend_url: str, request: Request, json_body: dict):
    """
    Proxy logic for non-streaming requests.
    """
    headers = {
        h: v
        for h, v in request.headers.items()
        if h.lower() not in ["host", "content-length"]
    }

    try:
        response = await client.post(
            backend_url, params=request.query_params, json=json_body, headers=headers
        )

        # Forward the backend's response as-is, including errors.
        return Response(
            content=response.content,
            media_type=response.headers.get("Content-Type"),
            status_code=response.status_code,
        )

    except httpx.ConnectError as e:
        logger.error(f"Connection error to backend {backend_url}: {e}")
        raise HTTPException(status_code=503, detail="Backend service unavailable")
    except httpx.ReadTimeout as e:
        logger.error(f"Read timeout from backend {backend_url}: {e}")
        raise HTTPException(status_code=504, detail="Backend request timed out")
    except Exception as e:
        logger.error(f"An error occurred during non-streaming proxy: {e}")
        raise HTTPException(status_code=500, detail=f"Internal proxy error: {e}")


@app.post("/v1/completions", summary="/v1/completions")
@app.post("/v1/chat/completions", summary="/v1/chat/completions")
async def post_completions(request: Request):
    backend_url, json_body = await _get_routing_info(request)

    if json_body.get("stream", False):
        logger.info("Handling as STREAMING request")
        return StreamingResponse(
            _stream_proxy(backend_url, request, json_body),
            media_type="text/event-stream",
        )
    else:
        logger.info("Handling as NON-STREAMING request")
        return await _non_stream_proxy(backend_url, request, json_body)


if __name__ == "__main__":
    import uvicorn

    # Configure models via the MODELS environment variable, e.g.:
    # MODELS="gpt-4=http://localhost:8080,llama=http://localhost:8081" uvicorn streaming_proxy:app --host 0.0.0.0 --port 8000
    if not os.environ.get("MODELS"):
        logger.warning(
            "MODELS environment variable is not set; falling back to a built-in "
            "default. Example: MODELS='model_name=http://backend_url'"
        )
        os.environ["MODELS"] = "qwen3=https://miyun.archermind.com"
    uvicorn.run(app, host="0.0.0.0", port=8000)
```
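The routing table is built entirely from the MODELS environment variable parsed in lifespan(). As a standalone sketch of that parsing, using the example value from the code comment above:

```python
# Standalone illustration of the MODELS parsing done in lifespan().
# The example value comes from the comment in main.py's __main__ block.
models = "gpt-4=http://localhost:8080,llama=http://localhost:8081"

backend_servers = {}
for entry in models.split(","):
    entry = entry.strip()
    if entry:
        # Split only on the first "=", since backend URLs may contain "=".
        name, url = entry.split("=", 1)
        backend_servers[name] = url

print(backend_servers)
# {'gpt-4': 'http://localhost:8080', 'llama': 'http://localhost:8081'}
```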

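For reference, a minimal client-side sketch of how requests reach the router. It assumes the router is running on localhost:8000 with a backend registered under the model name "qwen3", matching the defaults above; this is an illustration, not part of the commit:

```python
# Hypothetical client sketch (not part of this commit): assumes the router
# listens on localhost:8000 with a backend registered as "qwen3".
import httpx

ROUTER = "http://localhost:8000"
payload = {
    "model": "qwen3",
    "messages": [{"role": "user", "content": "Hello"}],
}

# Non-streaming: the router forwards the request via _non_stream_proxy and
# returns the backend's body, content type, and status code unchanged.
resp = httpx.post(f"{ROUTER}/v1/chat/completions", json=payload, timeout=60.0)
print(resp.status_code, resp.text)

# Streaming: with "stream": true the router answers with text/event-stream
# and relays the backend's chunks as they arrive (_stream_proxy).
with httpx.stream(
    "POST",
    f"{ROUTER}/v1/chat/completions",
    json={**payload, "stream": True},
    timeout=None,
) as stream:
    for line in stream.iter_lines():
        if line:
            print(line)
```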
pyproject.toml

Lines changed: 15 additions & 0 deletions

```toml
[project]
name = "openai-router"
version = "0.1.0"
description = "Add your description here"
readme = "README.md"
requires-python = ">=3.12"
dependencies = [
    "fastapi>=0.120.3",
    "httpx>=0.28.1",
    "loguru>=0.7.3",
    "uvicorn>=0.38.0",
]

[[tool.uv.index]]
url = "https://pypi.org/simple/"
default = true
```
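A quick, hypothetical way to confirm a local environment satisfies these constraints (not part of this commit):

```python
# Hypothetical sanity check, not part of this commit: verifies the running
# interpreter and installed packages against pyproject.toml's constraints.
import sys

import fastapi
import httpx
import loguru
import uvicorn

assert sys.version_info >= (3, 12), "pyproject.toml requires Python >= 3.12"
for mod in (fastapi, httpx, loguru, uvicorn):
    print(mod.__name__, getattr(mod, "__version__", "unknown"))
```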
