Commit 3bf92d4

feat: support /v1/models in direct response
Signed-off-by: bitliu <[email protected]>
1 parent b3658bc commit 3bf92d4
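
With this change the router answers GET /v1/models itself via an ext_proc immediate response instead of forwarding the call upstream; the list always contains the router's virtual "auto" model plus every model configured on the vLLM endpoints. A minimal client sketch of how the new endpoint can be consumed — the listener address localhost:8801 and this standalone program are illustrative assumptions, not part of the commit:

package main

import (
    "encoding/json"
    "fmt"
    "net/http"
)

// Mirrors the response types added in request_handler.go below.
type openAIModel struct {
    ID      string `json:"id"`
    Object  string `json:"object"`
    Created int64  `json:"created"`
    OwnedBy string `json:"owned_by"`
}

type openAIModelList struct {
    Object string        `json:"object"`
    Data   []openAIModel `json:"data"`
}

func main() {
    // Address is illustrative; use wherever Envoy exposes the router.
    resp, err := http.Get("http://localhost:8801/v1/models")
    if err != nil {
        panic(err)
    }
    defer resp.Body.Close()

    var list openAIModelList
    if err := json.NewDecoder(resp.Body).Decode(&list); err != nil {
        panic(err)
    }

    // Expect "auto" plus every model configured on the vLLM endpoints.
    for _, m := range list.Data {
        fmt.Printf("%s (owned_by=%s)\n", m.ID, m.OwnedBy)
    }
}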

File tree

2 files changed: +370 -0 lines changed

Lines changed: 236 additions & 0 deletions
@@ -0,0 +1,236 @@
package extproc

import (
    "encoding/json"
    "testing"

    core "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
    ext_proc "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3"
    typev3 "github.com/envoyproxy/go-control-plane/envoy/type/v3"
    "github.com/vllm-project/semantic-router/src/semantic-router/pkg/config"
)

func TestHandleModelsRequest(t *testing.T) {
    // Create a test router with mock config
    cfg := &config.RouterConfig{
        VLLMEndpoints: []config.VLLMEndpoint{
            {
                Name:    "primary",
                Address: "127.0.0.1",
                Port:    8000,
                Models:  []string{"gpt-4o-mini", "llama-3.1-8b-instruct"},
                Weight:  1,
            },
        },
    }

    router := &OpenAIRouter{
        Config: cfg,
    }

    tests := []struct {
        name           string
        path           string
        expectedModels []string
        expectedCount  int
    }{
        {
            name:           "GET /v1/models - all models",
            path:           "/v1/models",
            expectedModels: []string{"auto", "gpt-4o-mini", "llama-3.1-8b-instruct"},
            expectedCount:  3,
        },
        {
            name:           "GET /v1/models?model=auto - all models (no filtering implemented)",
            path:           "/v1/models?model=auto",
            expectedModels: []string{"auto", "gpt-4o-mini", "llama-3.1-8b-instruct"},
            expectedCount:  3,
        },
        {
            name:           "GET /v1/models?model=gpt-4o-mini - all models (no filtering)",
            path:           "/v1/models?model=gpt-4o-mini",
            expectedModels: []string{"auto", "gpt-4o-mini", "llama-3.1-8b-instruct"},
            expectedCount:  3,
        },
        {
            name:           "GET /v1/models?model= - all models (empty param)",
            path:           "/v1/models?model=",
            expectedModels: []string{"auto", "gpt-4o-mini", "llama-3.1-8b-instruct"},
            expectedCount:  3,
        },
    }

    for _, tt := range tests {
        t.Run(tt.name, func(t *testing.T) {
            response, err := router.handleModelsRequest(tt.path)
            if err != nil {
                t.Fatalf("handleModelsRequest failed: %v", err)
            }

            // Verify it's an immediate response
            immediateResp := response.GetImmediateResponse()
            if immediateResp == nil {
                t.Fatal("Expected immediate response, got nil")
            }

            // Verify status code is 200 OK
            if immediateResp.Status.Code != typev3.StatusCode_OK {
                t.Errorf("Expected status code OK, got %v", immediateResp.Status.Code)
            }

            // Verify content-type header
            found := false
            for _, header := range immediateResp.Headers.SetHeaders {
                if header.Header.Key == "content-type" {
                    if string(header.Header.RawValue) != "application/json" {
                        t.Errorf("Expected content-type application/json, got %s", string(header.Header.RawValue))
                    }
                    found = true
                    break
                }
            }
            if !found {
                t.Error("Expected content-type header not found")
            }

            // Parse response body
            var modelList OpenAIModelList
            if err := json.Unmarshal(immediateResp.Body, &modelList); err != nil {
                t.Fatalf("Failed to parse response body: %v", err)
            }

            // Verify response structure
            if modelList.Object != "list" {
                t.Errorf("Expected object 'list', got %s", modelList.Object)
            }

            if len(modelList.Data) != tt.expectedCount {
                t.Errorf("Expected %d models, got %d", tt.expectedCount, len(modelList.Data))
            }

            // Verify expected models are present
            modelMap := make(map[string]bool)
            for _, model := range modelList.Data {
                modelMap[model.ID] = true

                // Verify model structure
                if model.Object != "model" {
                    t.Errorf("Expected model object 'model', got %s", model.Object)
                }
                if model.Created == 0 {
                    t.Error("Expected non-zero created timestamp")
                }
                if model.OwnedBy != "vllm-semantic-router" {
                    t.Errorf("Expected model owned_by 'vllm-semantic-router', got %s", model.OwnedBy)
                }
            }

            for _, expectedModel := range tt.expectedModels {
                if !modelMap[expectedModel] {
                    t.Errorf("Expected model %s not found in response", expectedModel)
                }
            }
        })
    }
}

func TestHandleRequestHeadersWithModelsEndpoint(t *testing.T) {
    // Create a test router
    cfg := &config.RouterConfig{
        VLLMEndpoints: []config.VLLMEndpoint{
            {
                Name:    "primary",
                Address: "127.0.0.1",
                Port:    8000,
                Models:  []string{"gpt-4o-mini"},
                Weight:  1,
            },
        },
    }

    router := &OpenAIRouter{
        Config: cfg,
    }

    tests := []struct {
        name            string
        method          string
        path            string
        expectImmediate bool
    }{
        {
            name:            "GET /v1/models - should return immediate response",
            method:          "GET",
            path:            "/v1/models",
            expectImmediate: true,
        },
        {
            name:            "GET /v1/models?model=auto - should return immediate response",
            method:          "GET",
            path:            "/v1/models?model=auto",
            expectImmediate: true,
        },
        {
            name:            "POST /v1/chat/completions - should continue processing",
            method:          "POST",
            path:            "/v1/chat/completions",
            expectImmediate: false,
        },
        {
            name:            "POST /v1/models - should continue processing",
            method:          "POST",
            path:            "/v1/models",
            expectImmediate: false,
        },
    }

    for _, tt := range tests {
        t.Run(tt.name, func(t *testing.T) {
            // Create request headers
            requestHeaders := &ext_proc.ProcessingRequest_RequestHeaders{
                RequestHeaders: &ext_proc.HttpHeaders{
                    Headers: &core.HeaderMap{
                        Headers: []*core.HeaderValue{
                            {
                                Key:   ":method",
                                Value: tt.method,
                            },
                            {
                                Key:   ":path",
                                Value: tt.path,
                            },
                            {
                                Key:   "content-type",
                                Value: "application/json",
                            },
                        },
                    },
                },
            }

            ctx := &RequestContext{
                Headers: make(map[string]string),
            }

            response, err := router.handleRequestHeaders(requestHeaders, ctx)
            if err != nil {
                t.Fatalf("handleRequestHeaders failed: %v", err)
            }

            if tt.expectImmediate {
                // Should return immediate response
                if response.GetImmediateResponse() == nil {
                    t.Error("Expected immediate response for /v1/models endpoint")
                }
            } else {
                // Should return continue response
                if response.GetRequestHeaders() == nil {
                    t.Error("Expected request headers response for non-models endpoint")
                }
                if response.GetRequestHeaders().Response.Status != ext_proc.CommonResponse_CONTINUE {
                    t.Error("Expected CONTINUE status for non-models endpoint")
                }
            }
        })
    }
}

src/semantic-router/pkg/extproc/request_handler.go

Lines changed: 134 additions & 0 deletions
@@ -7,6 +7,7 @@ import (
 
     core "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
     ext_proc "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3"
+    typev3 "github.com/envoyproxy/go-control-plane/envoy/type/v3"
     "github.com/openai/openai-go"
     "google.golang.org/grpc/codes"
     "google.golang.org/grpc/status"
@@ -203,6 +204,15 @@ func (r *OpenAIRouter) handleRequestHeaders(v *ext_proc.ProcessingRequest_Reques
         }
     }
 
+    // Check if this is a GET request to /v1/models
+    method := ctx.Headers[":method"]
+    path := ctx.Headers[":path"]
+
+    if method == "GET" && strings.HasPrefix(path, "/v1/models") {
+        observability.Infof("Handling /v1/models request with path: %s", path)
+        return r.handleModelsRequest(path)
+    }
+
     // Prepare base response
     response := &ext_proc.ProcessingResponse{
         Response: &ext_proc.ProcessingResponse_RequestHeaders{
@@ -801,3 +811,127 @@ func (r *OpenAIRouter) updateRequestWithTools(openAIRequest *openai.ChatCompleti
 
     return nil
 }
+
+// OpenAIModel represents a single model in the OpenAI /v1/models response
+type OpenAIModel struct {
+    ID      string `json:"id"`
+    Object  string `json:"object"`
+    Created int64  `json:"created"`
+    OwnedBy string `json:"owned_by"`
+}
+
+// OpenAIModelList is the container for the models list response
+type OpenAIModelList struct {
+    Object string        `json:"object"`
+    Data   []OpenAIModel `json:"data"`
+}
+
+// handleModelsRequest handles GET /v1/models requests and returns a direct response
+func (r *OpenAIRouter) handleModelsRequest(path string) (*ext_proc.ProcessingResponse, error) {
+    now := time.Now().Unix()
+
+    // Start with the special "auto" model always available from the router
+    models := []OpenAIModel{
+        {
+            ID:      "auto",
+            Object:  "model",
+            Created: now,
+            OwnedBy: "vllm-semantic-router",
+        },
+    }
+
+    // Append underlying models from config (if available)
+    if r.Config != nil {
+        for _, m := range r.Config.GetAllModels() {
+            // Skip if already added as "auto" (or avoid duplicates in general)
+            if m == "auto" {
+                continue
+            }
+            models = append(models, OpenAIModel{
+                ID:      m,
+                Object:  "model",
+                Created: now,
+                OwnedBy: "vllm-semantic-router",
+            })
+        }
+    }
+
+    resp := OpenAIModelList{
+        Object: "list",
+        Data:   models,
+    }
+
+    return r.createJSONResponse(200, resp), nil
+}
+
+// statusCodeToEnum converts HTTP status code to typev3.StatusCode enum
+func statusCodeToEnum(statusCode int) typev3.StatusCode {
+    switch statusCode {
+    case 200:
+        return typev3.StatusCode_OK
+    case 400:
+        return typev3.StatusCode_BadRequest
+    case 404:
+        return typev3.StatusCode_NotFound
+    case 500:
+        return typev3.StatusCode_InternalServerError
+    default:
+        return typev3.StatusCode_OK
+    }
+}
+
+// createJSONResponseWithBody creates a direct response with pre-marshaled JSON body
+func (r *OpenAIRouter) createJSONResponseWithBody(statusCode int, jsonBody []byte) *ext_proc.ProcessingResponse {
+    return &ext_proc.ProcessingResponse{
+        Response: &ext_proc.ProcessingResponse_ImmediateResponse{
+            ImmediateResponse: &ext_proc.ImmediateResponse{
+                Status: &typev3.HttpStatus{
+                    Code: statusCodeToEnum(statusCode),
+                },
+                Headers: &ext_proc.HeaderMutation{
+                    SetHeaders: []*core.HeaderValueOption{
+                        {
+                            Header: &core.HeaderValue{
+                                Key:      "content-type",
+                                RawValue: []byte("application/json"),
+                            },
+                        },
+                    },
+                },
+                Body: jsonBody,
+            },
+        },
+    }
+}
+
+// createJSONResponse creates a direct response with JSON content
+func (r *OpenAIRouter) createJSONResponse(statusCode int, data interface{}) *ext_proc.ProcessingResponse {
+    jsonData, err := json.Marshal(data)
+    if err != nil {
+        observability.Errorf("Failed to marshal JSON response: %v", err)
+        return r.createErrorResponse(500, "Internal server error")
+    }
+
+    return r.createJSONResponseWithBody(statusCode, jsonData)
+}
+
+// createErrorResponse creates a direct error response
+func (r *OpenAIRouter) createErrorResponse(statusCode int, message string) *ext_proc.ProcessingResponse {
+    errorResp := map[string]interface{}{
+        "error": map[string]interface{}{
+            "message": message,
+            "type":    "invalid_request_error",
+            "code":    statusCode,
+        },
+    }
+
+    jsonData, err := json.Marshal(errorResp)
+    if err != nil {
+        observability.Errorf("Failed to marshal error response: %v", err)
+        jsonData = []byte(`{"error":{"message":"Internal server error","type":"internal_error","code":500}}`)
+        // Use 500 status code for fallback error
+        statusCode = 500
+    }
+
+    return r.createJSONResponseWithBody(statusCode, jsonData)
+}
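
Note that the header-phase check above matches on strings.HasPrefix, so query-string variants such as /v1/models?model=auto are also answered directly, which is exactly what the tests in the first file assert. A standalone sketch of that matching rule (illustrative only; servedDirectly is a hypothetical helper, not a function in the codebase):

package main

import (
    "fmt"
    "strings"
)

// servedDirectly mirrors the check in handleRequestHeaders: only GET requests
// whose path starts with /v1/models get the immediate JSON response.
func servedDirectly(method, path string) bool {
    return method == "GET" && strings.HasPrefix(path, "/v1/models")
}

func main() {
    fmt.Println(servedDirectly("GET", "/v1/models"))            // true
    fmt.Println(servedDirectly("GET", "/v1/models?model=auto")) // true: prefix match ignores the query string
    fmt.Println(servedDirectly("POST", "/v1/models"))           // false: continues normal processing
    fmt.Println(servedDirectly("POST", "/v1/chat/completions")) // false
}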
