Commit e5d643f

feat: support /v1/models in direct response
Signed-off-by: bitliu <[email protected]>
1 parent b3658bc
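
With this change the router answers GET /v1/models itself, via an Envoy ext_proc immediate response, instead of forwarding the request upstream. The new tests decode the response body into an OpenAIModelList. A minimal sketch of those types follows: the type and field names (OpenAIModelList, ID, Object, Created, OwnedBy) come from the tests below, but the JSON tags are assumptions following the OpenAI /v1/models convention, not something this commit confirms.

    package extproc

    // OpenAIModel is one entry in the list, e.g.
    //   {"id": "auto", "object": "model", "created": 1715000000, "owned_by": "semantic-router"}
    // JSON tags are assumed; only the Go field names are pinned down by the tests.
    type OpenAIModel struct {
        ID      string `json:"id"`
        Object  string `json:"object"`   // always "model"
        Created int64  `json:"created"`  // Unix timestamp, asserted non-zero
        OwnedBy string `json:"owned_by"` // "semantic-router" for "auto", "upstream-endpoint" otherwise
    }

    // OpenAIModelList is the top-level payload: {"object": "list", "data": [...]}.
    type OpenAIModelList struct {
        Object string        `json:"object"` // always "list"
        Data   []OpenAIModel `json:"data"`
    }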

2 files changed: +429 −0
Lines changed: 239 additions & 0 deletions
@@ -0,0 +1,239 @@
    package extproc

    import (
        "encoding/json"
        "testing"

        core "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
        ext_proc "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3"
        typev3 "github.com/envoyproxy/go-control-plane/envoy/type/v3"
        "github.com/vllm-project/semantic-router/src/semantic-router/pkg/config"
    )

    func TestHandleModelsRequest(t *testing.T) {
        // Create a test router with mock config
        cfg := &config.RouterConfig{
            VLLMEndpoints: []config.VLLMEndpoint{
                {
                    Name:    "primary",
                    Address: "127.0.0.1",
                    Port:    8000,
                    Models:  []string{"gpt-4o-mini", "llama-3.1-8b-instruct"},
                    Weight:  1,
                },
            },
        }

        router := &OpenAIRouter{
            Config: cfg,
        }

        tests := []struct {
            name           string
            path           string
            expectedModels []string
            expectedCount  int
        }{
            {
                name:           "GET /v1/models - all models",
                path:           "/v1/models",
                expectedModels: []string{"auto", "gpt-4o-mini", "llama-3.1-8b-instruct"},
                expectedCount:  3,
            },
            {
                name:           "GET /v1/models?model=auto - only auto model",
                path:           "/v1/models?model=auto",
                expectedModels: []string{"auto"},
                expectedCount:  1,
            },
            {
                name:           "GET /v1/models?model=gpt-4o-mini - all models (no filtering)",
                path:           "/v1/models?model=gpt-4o-mini",
                expectedModels: []string{"auto", "gpt-4o-mini", "llama-3.1-8b-instruct"},
                expectedCount:  3,
            },
            {
                name:           "GET /v1/models?model= - all models (empty param)",
                path:           "/v1/models?model=",
                expectedModels: []string{"auto", "gpt-4o-mini", "llama-3.1-8b-instruct"},
                expectedCount:  3,
            },
        }

        for _, tt := range tests {
            t.Run(tt.name, func(t *testing.T) {
                response, err := router.handleModelsRequest(tt.path)
                if err != nil {
                    t.Fatalf("handleModelsRequest failed: %v", err)
                }

                // Verify it's an immediate response
                immediateResp := response.GetImmediateResponse()
                if immediateResp == nil {
                    t.Fatal("Expected immediate response, got nil")
                }

                // Verify status code is 200 OK
                if immediateResp.Status.Code != typev3.StatusCode_OK {
                    t.Errorf("Expected status code OK, got %v", immediateResp.Status.Code)
                }

                // Verify content-type header
                found := false
                for _, header := range immediateResp.Headers.SetHeaders {
                    if header.Header.Key == "content-type" {
                        if string(header.Header.RawValue) != "application/json" {
                            t.Errorf("Expected content-type application/json, got %s", string(header.Header.RawValue))
                        }
                        found = true
                        break
                    }
                }
                if !found {
                    t.Error("Expected content-type header not found")
                }

                // Parse response body
                var modelList OpenAIModelList
                if err := json.Unmarshal(immediateResp.Body, &modelList); err != nil {
                    t.Fatalf("Failed to parse response body: %v", err)
                }

                // Verify response structure
                if modelList.Object != "list" {
                    t.Errorf("Expected object 'list', got %s", modelList.Object)
                }

                if len(modelList.Data) != tt.expectedCount {
                    t.Errorf("Expected %d models, got %d", tt.expectedCount, len(modelList.Data))
                }

                // Verify expected models are present
                modelMap := make(map[string]bool)
                for _, model := range modelList.Data {
                    modelMap[model.ID] = true

                    // Verify model structure
                    if model.Object != "model" {
                        t.Errorf("Expected model object 'model', got %s", model.Object)
                    }
                    if model.Created == 0 {
                        t.Error("Expected non-zero created timestamp")
                    }
                    if model.ID == "auto" && model.OwnedBy != "semantic-router" {
                        t.Errorf("Expected auto model owned_by 'semantic-router', got %s", model.OwnedBy)
                    }
                    if model.ID != "auto" && model.OwnedBy != "upstream-endpoint" {
                        t.Errorf("Expected non-auto model owned_by 'upstream-endpoint', got %s", model.OwnedBy)
                    }
                }

                for _, expectedModel := range tt.expectedModels {
                    if !modelMap[expectedModel] {
                        t.Errorf("Expected model %s not found in response", expectedModel)
                    }
                }
            })
        }
    }

    func TestHandleRequestHeadersWithModelsEndpoint(t *testing.T) {
        // Create a test router
        cfg := &config.RouterConfig{
            VLLMEndpoints: []config.VLLMEndpoint{
                {
                    Name:    "primary",
                    Address: "127.0.0.1",
                    Port:    8000,
                    Models:  []string{"gpt-4o-mini"},
                    Weight:  1,
                },
            },
        }

        router := &OpenAIRouter{
            Config: cfg,
        }

        tests := []struct {
            name            string
            method          string
            path            string
            expectImmediate bool
        }{
            {
                name:            "GET /v1/models - should return immediate response",
                method:          "GET",
                path:            "/v1/models",
                expectImmediate: true,
            },
            {
                name:            "GET /v1/models?model=auto - should return immediate response",
                method:          "GET",
                path:            "/v1/models?model=auto",
                expectImmediate: true,
            },
            {
                name:            "POST /v1/chat/completions - should continue processing",
                method:          "POST",
                path:            "/v1/chat/completions",
                expectImmediate: false,
            },
            {
                name:            "POST /v1/models - should continue processing",
                method:          "POST",
                path:            "/v1/models",
                expectImmediate: false,
            },
        }

        for _, tt := range tests {
            t.Run(tt.name, func(t *testing.T) {
                // Create request headers
                requestHeaders := &ext_proc.ProcessingRequest_RequestHeaders{
                    RequestHeaders: &ext_proc.HttpHeaders{
                        Headers: &core.HeaderMap{
                            Headers: []*core.HeaderValue{
                                {
                                    Key:   ":method",
                                    Value: tt.method,
                                },
                                {
                                    Key:   ":path",
                                    Value: tt.path,
                                },
                                {
                                    Key:   "content-type",
                                    Value: "application/json",
                                },
                            },
                        },
                    },
                }

                ctx := &RequestContext{
                    Headers: make(map[string]string),
                }

                response, err := router.handleRequestHeaders(requestHeaders, ctx)
                if err != nil {
                    t.Fatalf("handleRequestHeaders failed: %v", err)
                }

                if tt.expectImmediate {
                    // Should return immediate response
                    if response.GetImmediateResponse() == nil {
                        t.Error("Expected immediate response for /v1/models endpoint")
                    }
                } else {
                    // Should return continue response
                    if response.GetRequestHeaders() == nil {
                        t.Error("Expected request headers response for non-models endpoint")
                    }
                    if response.GetRequestHeaders().Response.Status != ext_proc.CommonResponse_CONTINUE {
                        t.Error("Expected CONTINUE status for non-models endpoint")
                    }
                }
            })
        }
    }
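
A plausible shape for the handler the first test exercises, inferred only from the assertions above: "auto" is always advertised first and attributed to the router itself, configured endpoint models follow, and only ?model=auto narrows the list. This is a sketch, not the commit's actual implementation; handleModelsRequestSketch and its filtering rule are reconstructions, and it assumes this repo's go-control-plane version exposes ImmediateResponse.Body as []byte (the test unmarshals it directly).

    package extproc

    import (
        "encoding/json"
        "net/url"
        "time"

        core "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
        ext_proc "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3"
        typev3 "github.com/envoyproxy/go-control-plane/envoy/type/v3"
    )

    // handleModelsRequestSketch reconstructs what handleModelsRequest plausibly
    // does, based solely on the test assertions in this commit.
    func (r *OpenAIRouter) handleModelsRequestSketch(path string) (*ext_proc.ProcessingResponse, error) {
        now := time.Now().Unix()

        // "auto" is always listed and owned by the router itself.
        models := []OpenAIModel{{ID: "auto", Object: "model", Created: now, OwnedBy: "semantic-router"}}
        for _, ep := range r.Config.VLLMEndpoints {
            for _, m := range ep.Models {
                models = append(models, OpenAIModel{ID: m, Object: "model", Created: now, OwnedBy: "upstream-endpoint"})
            }
        }

        // Per the tests, only ?model=auto filters; any other value (or none) is ignored.
        if u, err := url.Parse(path); err == nil && u.Query().Get("model") == "auto" {
            models = models[:1]
        }

        body, err := json.Marshal(OpenAIModelList{Object: "list", Data: models})
        if err != nil {
            return nil, err
        }

        // Answer directly from ext_proc: 200 OK, JSON body, no upstream round trip.
        return &ext_proc.ProcessingResponse{
            Response: &ext_proc.ProcessingResponse_ImmediateResponse{
                ImmediateResponse: &ext_proc.ImmediateResponse{
                    Status: &typev3.HttpStatus{Code: typev3.StatusCode_OK},
                    Headers: &ext_proc.HeaderMutation{
                        SetHeaders: []*core.HeaderValueOption{
                            {Header: &core.HeaderValue{Key: "content-type", RawValue: []byte("application/json")}},
                        },
                    },
                    Body: body,
                },
            },
        }, nil
    }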

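The second test pins down only the dispatch rule: GET requests whose path starts with /v1/models short-circuit into the direct response, while everything else (including POST /v1/models) continues upstream. A sketch of that branch, under the same caveat that the real handleRequestHeaders in this commit surely does more:

    package extproc

    import (
        "strings"

        ext_proc "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3"
    )

    // handleRequestHeadersSketch shows only the routing decision the test checks;
    // the name and the header-recording into ctx are illustrative assumptions.
    func (r *OpenAIRouter) handleRequestHeadersSketch(req *ext_proc.ProcessingRequest_RequestHeaders, ctx *RequestContext) (*ext_proc.ProcessingResponse, error) {
        var method, path string
        for _, h := range req.RequestHeaders.Headers.Headers {
            ctx.Headers[h.Key] = h.Value
            switch h.Key {
            case ":method":
                method = h.Value
            case ":path":
                path = h.Value
            }
        }

        // GET /v1/models, with or without query parameters, is answered directly.
        if method == "GET" && strings.HasPrefix(path, "/v1/models") {
            return r.handleModelsRequest(path)
        }

        // All other requests continue through the normal routing pipeline.
        return &ext_proc.ProcessingResponse{
            Response: &ext_proc.ProcessingResponse_RequestHeaders{
                RequestHeaders: &ext_proc.HeadersResponse{
                    Response: &ext_proc.CommonResponse{Status: ext_proc.CommonResponse_CONTINUE},
                },
            },
        }, nil
    }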