Commit f982534

bitliu authored
feat: support /v1/models in direct response (#283)
Signed-off-by: bitliu <[email protected]>
1 parent 2d4d5ae commit f982534
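
For context, here is a minimal client-side sketch (not part of this commit) of what the new endpoint returns: the router's special "auto" model plus every model declared in the vLLM endpoint config, as JSON in the OpenAI /v1/models format. The listen address http://localhost:8801 and the standalone main package are assumptions for illustration; the field names mirror the OpenAIModel/OpenAIModelList structs added to request_handler.go below.

package main

import (
	"encoding/json"
	"fmt"
	"log"
	"net/http"
)

type model struct {
	ID      string `json:"id"`
	Object  string `json:"object"`
	Created int64  `json:"created"`
	OwnedBy string `json:"owned_by"`
}

type modelList struct {
	Object string  `json:"object"`
	Data   []model `json:"data"`
}

func main() {
	// GET /v1/models is answered directly by the ext_proc filter as an
	// ImmediateResponse, so the request never reaches a vLLM backend.
	resp, err := http.Get("http://localhost:8801/v1/models")
	if err != nil {
		log.Fatal(err)
	}
	defer resp.Body.Close()

	var list modelList
	if err := json.NewDecoder(resp.Body).Decode(&list); err != nil {
		log.Fatal(err)
	}

	// Expect "auto" plus every model from the configured vLLM endpoints.
	for _, m := range list.Data {
		fmt.Printf("%s (owned_by=%s)\n", m.ID, m.OwnedBy)
	}
}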

2 files changed: +370 -0 lines changed

Lines changed: 236 additions & 0 deletions
@@ -0,0 +1,236 @@
+package extproc
+
+import (
+	"encoding/json"
+	"testing"
+
+	core "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
+	ext_proc "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3"
+	typev3 "github.com/envoyproxy/go-control-plane/envoy/type/v3"
+	"github.com/vllm-project/semantic-router/src/semantic-router/pkg/config"
+)
+
+func TestHandleModelsRequest(t *testing.T) {
+	// Create a test router with mock config
+	cfg := &config.RouterConfig{
+		VLLMEndpoints: []config.VLLMEndpoint{
+			{
+				Name:    "primary",
+				Address: "127.0.0.1",
+				Port:    8000,
+				Models:  []string{"gpt-4o-mini", "llama-3.1-8b-instruct"},
+				Weight:  1,
+			},
+		},
+	}
+
+	router := &OpenAIRouter{
+		Config: cfg,
+	}
+
+	tests := []struct {
+		name           string
+		path           string
+		expectedModels []string
+		expectedCount  int
+	}{
+		{
+			name:           "GET /v1/models - all models",
+			path:           "/v1/models",
+			expectedModels: []string{"auto", "gpt-4o-mini", "llama-3.1-8b-instruct"},
+			expectedCount:  3,
+		},
+		{
+			name:           "GET /v1/models?model=auto - all models (no filtering implemented)",
+			path:           "/v1/models?model=auto",
+			expectedModels: []string{"auto", "gpt-4o-mini", "llama-3.1-8b-instruct"},
+			expectedCount:  3,
+		},
+		{
+			name:           "GET /v1/models?model=gpt-4o-mini - all models (no filtering)",
+			path:           "/v1/models?model=gpt-4o-mini",
+			expectedModels: []string{"auto", "gpt-4o-mini", "llama-3.1-8b-instruct"},
+			expectedCount:  3,
+		},
+		{
+			name:           "GET /v1/models?model= - all models (empty param)",
+			path:           "/v1/models?model=",
+			expectedModels: []string{"auto", "gpt-4o-mini", "llama-3.1-8b-instruct"},
+			expectedCount:  3,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			response, err := router.handleModelsRequest(tt.path)
+			if err != nil {
+				t.Fatalf("handleModelsRequest failed: %v", err)
+			}
+
+			// Verify it's an immediate response
+			immediateResp := response.GetImmediateResponse()
+			if immediateResp == nil {
+				t.Fatal("Expected immediate response, got nil")
+			}
+
+			// Verify status code is 200 OK
+			if immediateResp.Status.Code != typev3.StatusCode_OK {
+				t.Errorf("Expected status code OK, got %v", immediateResp.Status.Code)
+			}
+
+			// Verify content-type header
+			found := false
+			for _, header := range immediateResp.Headers.SetHeaders {
+				if header.Header.Key == "content-type" {
+					if string(header.Header.RawValue) != "application/json" {
+						t.Errorf("Expected content-type application/json, got %s", string(header.Header.RawValue))
+					}
+					found = true
+					break
+				}
+			}
+			if !found {
+				t.Error("Expected content-type header not found")
+			}
+
+			// Parse response body
+			var modelList OpenAIModelList
+			if err := json.Unmarshal(immediateResp.Body, &modelList); err != nil {
+				t.Fatalf("Failed to parse response body: %v", err)
+			}
+
+			// Verify response structure
+			if modelList.Object != "list" {
+				t.Errorf("Expected object 'list', got %s", modelList.Object)
+			}
+
+			if len(modelList.Data) != tt.expectedCount {
+				t.Errorf("Expected %d models, got %d", tt.expectedCount, len(modelList.Data))
+			}
+
+			// Verify expected models are present
+			modelMap := make(map[string]bool)
+			for _, model := range modelList.Data {
+				modelMap[model.ID] = true
+
+				// Verify model structure
+				if model.Object != "model" {
+					t.Errorf("Expected model object 'model', got %s", model.Object)
+				}
+				if model.Created == 0 {
+					t.Error("Expected non-zero created timestamp")
+				}
+				if model.OwnedBy != "vllm-semantic-router" {
+					t.Errorf("Expected model owned_by 'vllm-semantic-router', got %s", model.OwnedBy)
+				}
+			}
+
+			for _, expectedModel := range tt.expectedModels {
+				if !modelMap[expectedModel] {
+					t.Errorf("Expected model %s not found in response", expectedModel)
+				}
+			}
+		})
+	}
+}
+
+func TestHandleRequestHeadersWithModelsEndpoint(t *testing.T) {
+	// Create a test router
+	cfg := &config.RouterConfig{
+		VLLMEndpoints: []config.VLLMEndpoint{
+			{
+				Name:    "primary",
+				Address: "127.0.0.1",
+				Port:    8000,
+				Models:  []string{"gpt-4o-mini"},
+				Weight:  1,
+			},
+		},
+	}
+
+	router := &OpenAIRouter{
+		Config: cfg,
+	}
+
+	tests := []struct {
+		name            string
+		method          string
+		path            string
+		expectImmediate bool
+	}{
+		{
+			name:            "GET /v1/models - should return immediate response",
+			method:          "GET",
+			path:            "/v1/models",
+			expectImmediate: true,
+		},
+		{
+			name:            "GET /v1/models?model=auto - should return immediate response",
+			method:          "GET",
+			path:            "/v1/models?model=auto",
+			expectImmediate: true,
+		},
+		{
+			name:            "POST /v1/chat/completions - should continue processing",
+			method:          "POST",
+			path:            "/v1/chat/completions",
+			expectImmediate: false,
+		},
+		{
+			name:            "POST /v1/models - should continue processing",
+			method:          "POST",
+			path:            "/v1/models",
+			expectImmediate: false,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			// Create request headers
+			requestHeaders := &ext_proc.ProcessingRequest_RequestHeaders{
+				RequestHeaders: &ext_proc.HttpHeaders{
+					Headers: &core.HeaderMap{
+						Headers: []*core.HeaderValue{
+							{
+								Key:   ":method",
+								Value: tt.method,
+							},
+							{
+								Key:   ":path",
+								Value: tt.path,
+							},
+							{
+								Key:   "content-type",
+								Value: "application/json",
+							},
+						},
+					},
+				},
+			}
+
+			ctx := &RequestContext{
+				Headers: make(map[string]string),
+			}
+
+			response, err := router.handleRequestHeaders(requestHeaders, ctx)
+			if err != nil {
+				t.Fatalf("handleRequestHeaders failed: %v", err)
+			}
+
+			if tt.expectImmediate {
+				// Should return immediate response
+				if response.GetImmediateResponse() == nil {
+					t.Error("Expected immediate response for /v1/models endpoint")
+				}
+			} else {
+				// Should return continue response
+				if response.GetRequestHeaders() == nil {
+					t.Error("Expected request headers response for non-models endpoint")
+				}
+				if response.GetRequestHeaders().Response.Status != ext_proc.CommonResponse_CONTINUE {
+					t.Error("Expected CONTINUE status for non-models endpoint")
+				}
+			}
+		})
+	}
+}

src/semantic-router/pkg/extproc/request_handler.go

Lines changed: 134 additions & 0 deletions
@@ -7,6 +7,7 @@ import (
 
 	core "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
 	ext_proc "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3"
+	typev3 "github.com/envoyproxy/go-control-plane/envoy/type/v3"
 	"github.com/openai/openai-go"
 	"google.golang.org/grpc/codes"
 	"google.golang.org/grpc/status"
@@ -209,6 +210,15 @@ func (r *OpenAIRouter) handleRequestHeaders(v *ext_proc.ProcessingRequest_Reques
 		}
 	}
 
+	// Check if this is a GET request to /v1/models
+	method := ctx.Headers[":method"]
+	path := ctx.Headers[":path"]
+
+	if method == "GET" && strings.HasPrefix(path, "/v1/models") {
+		observability.Infof("Handling /v1/models request with path: %s", path)
+		return r.handleModelsRequest(path)
+	}
+
 	// Prepare base response
 	response := &ext_proc.ProcessingResponse{
 		Response: &ext_proc.ProcessingResponse_RequestHeaders{
@@ -821,3 +831,127 @@ func (r *OpenAIRouter) updateRequestWithTools(openAIRequest *openai.ChatCompleti
 
 	return nil
 }
+
+// OpenAIModel represents a single model in the OpenAI /v1/models response
+type OpenAIModel struct {
+	ID      string `json:"id"`
+	Object  string `json:"object"`
+	Created int64  `json:"created"`
+	OwnedBy string `json:"owned_by"`
+}
+
+// OpenAIModelList is the container for the models list response
+type OpenAIModelList struct {
+	Object string        `json:"object"`
+	Data   []OpenAIModel `json:"data"`
+}
+
+// handleModelsRequest handles GET /v1/models requests and returns a direct response
+func (r *OpenAIRouter) handleModelsRequest(path string) (*ext_proc.ProcessingResponse, error) {
+	now := time.Now().Unix()
+
+	// Start with the special "auto" model always available from the router
+	models := []OpenAIModel{
+		{
+			ID:      "auto",
+			Object:  "model",
+			Created: now,
+			OwnedBy: "vllm-semantic-router",
+		},
+	}
+
+	// Append underlying models from config (if available)
+	if r.Config != nil {
+		for _, m := range r.Config.GetAllModels() {
+			// Skip if already added as "auto" (or avoid duplicates in general)
+			if m == "auto" {
+				continue
+			}
+			models = append(models, OpenAIModel{
+				ID:      m,
+				Object:  "model",
+				Created: now,
+				OwnedBy: "vllm-semantic-router",
+			})
+		}
+	}
+
+	resp := OpenAIModelList{
+		Object: "list",
+		Data:   models,
+	}
+
+	return r.createJSONResponse(200, resp), nil
+}
+
+// statusCodeToEnum converts HTTP status code to typev3.StatusCode enum
+func statusCodeToEnum(statusCode int) typev3.StatusCode {
+	switch statusCode {
+	case 200:
+		return typev3.StatusCode_OK
+	case 400:
+		return typev3.StatusCode_BadRequest
+	case 404:
+		return typev3.StatusCode_NotFound
+	case 500:
+		return typev3.StatusCode_InternalServerError
+	default:
+		return typev3.StatusCode_OK
+	}
+}
+
+// createJSONResponseWithBody creates a direct response with pre-marshaled JSON body
+func (r *OpenAIRouter) createJSONResponseWithBody(statusCode int, jsonBody []byte) *ext_proc.ProcessingResponse {
+	return &ext_proc.ProcessingResponse{
+		Response: &ext_proc.ProcessingResponse_ImmediateResponse{
+			ImmediateResponse: &ext_proc.ImmediateResponse{
+				Status: &typev3.HttpStatus{
+					Code: statusCodeToEnum(statusCode),
+				},
+				Headers: &ext_proc.HeaderMutation{
+					SetHeaders: []*core.HeaderValueOption{
+						{
+							Header: &core.HeaderValue{
+								Key:      "content-type",
+								RawValue: []byte("application/json"),
+							},
+						},
+					},
+				},
+				Body: jsonBody,
+			},
+		},
+	}
+}
+
+// createJSONResponse creates a direct response with JSON content
+func (r *OpenAIRouter) createJSONResponse(statusCode int, data interface{}) *ext_proc.ProcessingResponse {
+	jsonData, err := json.Marshal(data)
+	if err != nil {
+		observability.Errorf("Failed to marshal JSON response: %v", err)
+		return r.createErrorResponse(500, "Internal server error")
+	}
+
+	return r.createJSONResponseWithBody(statusCode, jsonData)
+}
+
+// createErrorResponse creates a direct error response
+func (r *OpenAIRouter) createErrorResponse(statusCode int, message string) *ext_proc.ProcessingResponse {
+	errorResp := map[string]interface{}{
+		"error": map[string]interface{}{
+			"message": message,
+			"type":    "invalid_request_error",
+			"code":    statusCode,
+		},
+	}
+
+	jsonData, err := json.Marshal(errorResp)
+	if err != nil {
+		observability.Errorf("Failed to marshal error response: %v", err)
+		jsonData = []byte(`{"error":{"message":"Internal server error","type":"internal_error","code":500}}`)
+		// Use 500 status code for fallback error
+		statusCode = 500
+	}
+
+	return r.createJSONResponseWithBody(statusCode, jsonData)
}
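
To make the behavior of the new error helper concrete, here is a minimal companion test sketch, assuming it sits alongside the tests above in the same extproc package; the test name and the 404 example are hypothetical and not part of the commit. It checks that createErrorResponse maps the status code through statusCodeToEnum and emits the error JSON shape defined above.

package extproc

import (
	"encoding/json"
	"testing"

	typev3 "github.com/envoyproxy/go-control-plane/envoy/type/v3"
)

// TestCreateErrorResponseShape is a hypothetical companion test (not part of this
// commit) asserting the payload produced by createErrorResponse.
func TestCreateErrorResponseShape(t *testing.T) {
	router := &OpenAIRouter{}

	resp := router.createErrorResponse(404, "model not found")
	imm := resp.GetImmediateResponse()
	if imm == nil {
		t.Fatal("expected an immediate response")
	}
	if imm.Status.Code != typev3.StatusCode_NotFound {
		t.Errorf("expected StatusCode_NotFound, got %v", imm.Status.Code)
	}

	var body struct {
		Error struct {
			Message string `json:"message"`
			Type    string `json:"type"`
			Code    int    `json:"code"`
		} `json:"error"`
	}
	if err := json.Unmarshal(imm.Body, &body); err != nil {
		t.Fatalf("failed to parse error body: %v", err)
	}
	if body.Error.Code != 404 || body.Error.Type != "invalid_request_error" {
		t.Errorf("unexpected error payload: %+v", body.Error)
	}
}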

0 commit comments