-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathproxy-classifier.py
More file actions
127 lines (108 loc) · 4.1 KB
/
proxy-classifier.py
File metadata and controls
127 lines (108 loc) · 4.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import requests
import json
import time
import os
import uuid
from typing import List, Optional
from pydantic import BaseModel, Field
from fastapi import FastAPI, HTTPException, Header
from create_mxbai_v2_reranker_prompt_template import create_mxbai_v2_reranker_prompt_template as templating
# Hostname of the Infinity classification server; overridable via env var.
inference_host = os.environ.get("INFERENCE_HOST", "localhost")
# FastAPI app exposing a Cohere-style /v1/rerank endpoint that proxies to Infinity.
app = FastAPI(title="Rerank proxy for Infinity Classifier")
# Define request/response models
class RerankRequest(BaseModel):
    # Incoming rerank request (Cohere-style /v1/rerank shape).
    # query: the search query to score documents against.
    query: str
    # documents: candidate texts to be reranked.
    documents: List[str]
    # model: upstream classifier model name forwarded to Infinity.
    model: str = "michaelfeil/mxbai-rerank-large-v2-seq"
    # top_n: if set (> 0), only the best-scoring N results are returned.
    top_n: Optional[int] = None
    # return_documents: when True, echo each document text in its result.
    return_documents: bool = False
    # max_chunks_per_doc: accepted for API compatibility; not used below.
    max_chunks_per_doc: Optional[int] = None
class RerankResult(BaseModel):
    # One scored document in the rerank response.
    relevance_score: float
    # index: position of the document in the original request list.
    index: int
    # document: original text, only populated when return_documents=True.
    document: Optional[str] = None
class RerankUsage(BaseModel):
    # Token accounting block of the rerank response.
    prompt_tokens: int
    total_tokens: int
class RerankResponse(BaseModel):
    # Top-level rerank response envelope (Cohere/OpenAI-style).
    object: str = "rerank"
    results: List[RerankResult]
    model: str
    usage: RerankUsage
    # id: upstream id if provided, otherwise a locally generated "rerank-<uuid>".
    id: str
    # created: unix timestamp (upstream value or local time at response build).
    created: int
@app.post("/v1/rerank", response_model=RerankResponse)
def rerank(request: RerankRequest, authorization: Optional[str] = Header(None)):
    """Proxy a Cohere-style /v1/rerank request onto Infinity's /v1/classify.

    Each (query, document) pair is rendered through the reranker prompt
    template, scored by the upstream classifier, and the scores are
    repackaged into the rerank response shape.

    Declared as a plain ``def`` (not ``async``) on purpose: the blocking
    ``requests`` call would otherwise stall the event loop; FastAPI runs
    sync endpoints in a worker thread.

    Raises:
        HTTPException(500): on upstream call failure or an unexpected
            upstream response shape.
    """
    # One templated classifier input per candidate document.
    inputs = [templating(request.query, doc) for doc in request.documents]

    payload = {
        "input": inputs,
        "model": request.model,
        "raw_scores": True,
    }

    headers = {
        "accept": "application/json",
        "Content-Type": "application/json",
    }
    # Forward the caller's bearer token, if any, to the upstream service.
    if authorization:
        headers["Authorization"] = authorization

    try:
        # Call the classification endpoint. A timeout is essential here:
        # without one, a dead upstream would hang this worker forever.
        response = requests.post(
            f"http://{inference_host}:7997/v1/classify",
            headers=headers,
            json=payload,
            timeout=60,
        )
        response.raise_for_status()
        response_data = response.json()
    except requests.exceptions.RequestException as e:
        raise HTTPException(status_code=500, detail=f"Error calling classification API: {str(e)}")

    if "data" not in response_data:
        raise HTTPException(status_code=500, detail="Unexpected response format from classification API")

    # Extract one relevance score per document: prefer the "1" (relevant)
    # label's score, falling back to the highest score if labels differ.
    scores = []
    for item in response_data["data"]:
        relevant_score = next(
            (x["score"] for x in item if x["label"] == "1"),
            max(x["score"] for x in item),
        )
        scores.append(relevant_score)

    # Pair each score with its original index (and text, if requested).
    formatted_results = [
        {
            "relevance_score": score,
            "index": idx,
            "document": doc if request.return_documents else None,
        }
        for idx, (score, doc) in enumerate(zip(scores, request.documents))
    ]

    # Best matches first.
    formatted_results.sort(key=lambda x: x["relevance_score"], reverse=True)

    # Truncate to the requested top_n, when given and positive.
    if request.top_n is not None and request.top_n > 0:
        formatted_results = formatted_results[:request.top_n]

    # Use upstream token counts when available; the fallback is a rough
    # CHARACTER count of the prompts, not real tokens — TODO confirm
    # whether callers rely on this number.
    total_tokens = response_data.get("usage", {}).get("prompt_tokens", len(''.join(inputs)))

    return {
        "object": "rerank",
        "results": formatted_results,
        "model": response_data.get("model", request.model),
        "usage": {
            "prompt_tokens": total_tokens,
            "total_tokens": total_tokens,
        },
        "id": response_data.get("id", f"rerank-{uuid.uuid4()}"),
        "created": response_data.get("created", int(time.time())),
    }
if __name__ == "__main__":
    # Dev entry point: serve on all interfaces, port 8002, with auto-reload.
    # The string target lets uvicorn re-import the module on reload.
    import uvicorn
    uvicorn.run("proxy-classifier:app", host="0.0.0.0", port=8002, reload=True)