236 changes: 236 additions & 0 deletions src/semantic-router/pkg/extproc/models_endpoint_test.go
@@ -0,0 +1,236 @@
package extproc

import (
"encoding/json"
"testing"

core "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
ext_proc "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3"
typev3 "github.com/envoyproxy/go-control-plane/envoy/type/v3"
"github.com/vllm-project/semantic-router/src/semantic-router/pkg/config"
)

func TestHandleModelsRequest(t *testing.T) {
// Create a test router with mock config
cfg := &config.RouterConfig{
VLLMEndpoints: []config.VLLMEndpoint{
{
Name: "primary",
Address: "127.0.0.1",
Port: 8000,
Models: []string{"gpt-4o-mini", "llama-3.1-8b-instruct"},
Weight: 1,
},
},
}

router := &OpenAIRouter{
Config: cfg,
}

tests := []struct {
name string
path string
expectedModels []string
expectedCount int
}{
{
name: "GET /v1/models - all models",
path: "/v1/models",
expectedModels: []string{"auto", "gpt-4o-mini", "llama-3.1-8b-instruct"},
expectedCount: 3,
},
{
name: "GET /v1/models?model=auto - all models (no filtering implemented)",
path: "/v1/models?model=auto",
expectedModels: []string{"auto", "gpt-4o-mini", "llama-3.1-8b-instruct"},
expectedCount: 3,
},
{
name: "GET /v1/models?model=gpt-4o-mini - all models (no filtering)",
path: "/v1/models?model=gpt-4o-mini",
expectedModels: []string{"auto", "gpt-4o-mini", "llama-3.1-8b-instruct"},
expectedCount: 3,
},
{
name: "GET /v1/models?model= - all models (empty param)",
path: "/v1/models?model=",
expectedModels: []string{"auto", "gpt-4o-mini", "llama-3.1-8b-instruct"},
expectedCount: 3,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
response, err := router.handleModelsRequest(tt.path)
if err != nil {
t.Fatalf("handleModelsRequest failed: %v", err)
}

// Verify it's an immediate response
immediateResp := response.GetImmediateResponse()
if immediateResp == nil {
t.Fatal("Expected immediate response, got nil")
}

// Verify status code is 200 OK
if immediateResp.Status.Code != typev3.StatusCode_OK {
t.Errorf("Expected status code OK, got %v", immediateResp.Status.Code)
}

// Verify content-type header
found := false
for _, header := range immediateResp.Headers.SetHeaders {
if header.Header.Key == "content-type" {
if string(header.Header.RawValue) != "application/json" {
t.Errorf("Expected content-type application/json, got %s", string(header.Header.RawValue))
}
found = true
break
}
}
if !found {
t.Error("Expected content-type header not found")
}

// Parse response body
var modelList OpenAIModelList
if err := json.Unmarshal(immediateResp.Body, &modelList); err != nil {
t.Fatalf("Failed to parse response body: %v", err)
}

// Verify response structure
if modelList.Object != "list" {
t.Errorf("Expected object 'list', got %s", modelList.Object)
}

if len(modelList.Data) != tt.expectedCount {
t.Errorf("Expected %d models, got %d", tt.expectedCount, len(modelList.Data))
}

// Verify expected models are present
modelMap := make(map[string]bool)
for _, model := range modelList.Data {
modelMap[model.ID] = true

// Verify model structure
if model.Object != "model" {
t.Errorf("Expected model object 'model', got %s", model.Object)
}
if model.Created == 0 {
t.Error("Expected non-zero created timestamp")
}
if model.OwnedBy != "vllm-semantic-router" {
t.Errorf("Expected model owned_by 'vllm-semantic-router', got %s", model.OwnedBy)
}
}

for _, expectedModel := range tt.expectedModels {
if !modelMap[expectedModel] {
t.Errorf("Expected model %s not found in response", expectedModel)
}
}
})
}
}

func TestHandleRequestHeadersWithModelsEndpoint(t *testing.T) {
// Create a test router
cfg := &config.RouterConfig{
VLLMEndpoints: []config.VLLMEndpoint{
{
Name: "primary",
Address: "127.0.0.1",
Port: 8000,
Models: []string{"gpt-4o-mini"},
Weight: 1,
},
},
}

router := &OpenAIRouter{
Config: cfg,
}

tests := []struct {
name string
method string
path string
expectImmediate bool
}{
{
name: "GET /v1/models - should return immediate response",
method: "GET",
path: "/v1/models",
expectImmediate: true,
},
{
name: "GET /v1/models?model=auto - should return immediate response",
method: "GET",
path: "/v1/models?model=auto",
expectImmediate: true,
},
{
name: "POST /v1/chat/completions - should continue processing",
method: "POST",
path: "/v1/chat/completions",
expectImmediate: false,
},
{
name: "POST /v1/models - should continue processing",
method: "POST",
path: "/v1/models",
expectImmediate: false,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Create request headers
requestHeaders := &ext_proc.ProcessingRequest_RequestHeaders{
RequestHeaders: &ext_proc.HttpHeaders{
Headers: &core.HeaderMap{
Headers: []*core.HeaderValue{
{
Key: ":method",
Value: tt.method,
},
{
Key: ":path",
Value: tt.path,
},
{
Key: "content-type",
Value: "application/json",
},
},
},
},
}

ctx := &RequestContext{
Headers: make(map[string]string),
}

response, err := router.handleRequestHeaders(requestHeaders, ctx)
if err != nil {
t.Fatalf("handleRequestHeaders failed: %v", err)
}

if tt.expectImmediate {
// Should return immediate response
if response.GetImmediateResponse() == nil {
t.Error("Expected immediate response for /v1/models endpoint")
}
} else {
// Should return continue response
if response.GetRequestHeaders() == nil {
t.Error("Expected request headers response for non-models endpoint")
}
if response.GetRequestHeaders().Response.Status != ext_proc.CommonResponse_CONTINUE {
t.Error("Expected CONTINUE status for non-models endpoint")
}
}
})
}
}
134 changes: 134 additions & 0 deletions src/semantic-router/pkg/extproc/request_handler.go
@@ -7,6 +7,7 @@ import (

core "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
ext_proc "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3"
typev3 "github.com/envoyproxy/go-control-plane/envoy/type/v3"
"github.com/openai/openai-go"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
@@ -209,6 +210,15 @@ func (r *OpenAIRouter) handleRequestHeaders(v *ext_proc.ProcessingRequest_Reques
}
}

// Check if this is a GET request to /v1/models
method := ctx.Headers[":method"]
path := ctx.Headers[":path"]

if method == "GET" && strings.HasPrefix(path, "/v1/models") {
observability.Infof("Handling /v1/models request with path: %s", path)
return r.handleModelsRequest(path)
}

// Prepare base response
response := &ext_proc.ProcessingResponse{
Response: &ext_proc.ProcessingResponse_RequestHeaders{
@@ -821,3 +831,127 @@ func (r *OpenAIRouter) updateRequestWithTools(openAIRequest *openai.ChatCompleti

return nil
}

// OpenAIModel represents a single model in the OpenAI /v1/models response
type OpenAIModel struct {
ID string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
OwnedBy string `json:"owned_by"`
}

// OpenAIModelList is the container for the models list response
type OpenAIModelList struct {
Object string `json:"object"`
Data []OpenAIModel `json:"data"`
}

// handleModelsRequest handles GET /v1/models requests and returns a direct response
func (r *OpenAIRouter) handleModelsRequest(path string) (*ext_proc.ProcessingResponse, error) {
now := time.Now().Unix()

// Start with the special "auto" model always available from the router
models := []OpenAIModel{
{
ID: "auto",
Object: "model",
Created: now,
OwnedBy: "vllm-semantic-router",
},
}

// Append underlying models from config (if available)
if r.Config != nil {
for _, m := range r.Config.GetAllModels() {
// Skip "auto" so the entry added above is not duplicated
if m == "auto" {
continue
}
models = append(models, OpenAIModel{
ID: m,
Object: "model",
Created: now,
OwnedBy: "vllm-semantic-router",
})
}
}

resp := OpenAIModelList{
Object: "list",
Data: models,
}

return r.createJSONResponse(200, resp), nil
}
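
// Illustrative note (not part of this change): with the endpoint config used in the
// tests above, the JSON body built here would look roughly like
//
//	{"object":"list","data":[
//	  {"id":"auto","object":"model","created":1712345678,"owned_by":"vllm-semantic-router"},
//	  {"id":"gpt-4o-mini","object":"model","created":1712345678,"owned_by":"vllm-semantic-router"}]}
//
// where "created" is the Unix timestamp taken at request time and the model IDs
// come from the configured vLLM endpoints.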

// statusCodeToEnum converts HTTP status code to typev3.StatusCode enum
func statusCodeToEnum(statusCode int) typev3.StatusCode {
switch statusCode {
case 200:
return typev3.StatusCode_OK
case 400:
return typev3.StatusCode_BadRequest
case 404:
return typev3.StatusCode_NotFound
case 500:
return typev3.StatusCode_InternalServerError
default:
return typev3.StatusCode_OK
}
}

// createJSONResponseWithBody creates a direct response with pre-marshaled JSON body
func (r *OpenAIRouter) createJSONResponseWithBody(statusCode int, jsonBody []byte) *ext_proc.ProcessingResponse {
return &ext_proc.ProcessingResponse{
Response: &ext_proc.ProcessingResponse_ImmediateResponse{
ImmediateResponse: &ext_proc.ImmediateResponse{
Status: &typev3.HttpStatus{
Code: statusCodeToEnum(statusCode),
},
Headers: &ext_proc.HeaderMutation{
SetHeaders: []*core.HeaderValueOption{
{
Header: &core.HeaderValue{
Key: "content-type",
RawValue: []byte("application/json"),
},
},
},
},
Body: jsonBody,
},
},
}
}

// createJSONResponse creates a direct response with JSON content
func (r *OpenAIRouter) createJSONResponse(statusCode int, data interface{}) *ext_proc.ProcessingResponse {
jsonData, err := json.Marshal(data)
if err != nil {
observability.Errorf("Failed to marshal JSON response: %v", err)
return r.createErrorResponse(500, "Internal server error")
}

return r.createJSONResponseWithBody(statusCode, jsonData)
}

// createErrorResponse creates a direct error response
func (r *OpenAIRouter) createErrorResponse(statusCode int, message string) *ext_proc.ProcessingResponse {
errorResp := map[string]interface{}{
"error": map[string]interface{}{
"message": message,
"type": "invalid_request_error",
"code": statusCode,
},
}

jsonData, err := json.Marshal(errorResp)
if err != nil {
observability.Errorf("Failed to marshal error response: %v", err)
jsonData = []byte(`{"error":{"message":"Internal server error","type":"internal_error","code":500}}`)
// Use 500 status code for fallback error
statusCode = 500
}

return r.createJSONResponseWithBody(statusCode, jsonData)
}
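
Usage note: once this handler is deployed behind Envoy, any OpenAI-compatible client can discover the router's models with a plain GET request. Below is a minimal Go sketch of such a client; the listener address (http://localhost:8801) is an assumption for illustration, not a value taken from this PR.

package main

import (
	"encoding/json"
	"fmt"
	"net/http"
)

// Mirrors the response shape produced by handleModelsRequest above.
type model struct {
	ID      string `json:"id"`
	Object  string `json:"object"`
	Created int64  `json:"created"`
	OwnedBy string `json:"owned_by"`
}

type modelList struct {
	Object string  `json:"object"`
	Data   []model `json:"data"`
}

func main() {
	// Assumes Envoy (with the ext_proc router) is listening locally; adjust the address as needed.
	resp, err := http.Get("http://localhost:8801/v1/models")
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	var list modelList
	if err := json.NewDecoder(resp.Body).Decode(&list); err != nil {
		panic(err)
	}
	for _, m := range list.Data {
		fmt.Println(m.ID, m.OwnedBy) // e.g. "auto vllm-semantic-router"
	}
}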