diff --git a/src/semantic-router/pkg/extproc/models_endpoint_test.go b/src/semantic-router/pkg/extproc/models_endpoint_test.go
new file mode 100644
index 00000000..9fbd5d17
--- /dev/null
+++ b/src/semantic-router/pkg/extproc/models_endpoint_test.go
@@ -0,0 +1,236 @@
+package extproc
+
+import (
+	"encoding/json"
+	"testing"
+
+	core "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
+	ext_proc "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3"
+	typev3 "github.com/envoyproxy/go-control-plane/envoy/type/v3"
+	"github.com/vllm-project/semantic-router/src/semantic-router/pkg/config"
+)
+
+func TestHandleModelsRequest(t *testing.T) {
+	// Create a test router with mock config
+	cfg := &config.RouterConfig{
+		VLLMEndpoints: []config.VLLMEndpoint{
+			{
+				Name:    "primary",
+				Address: "127.0.0.1",
+				Port:    8000,
+				Models:  []string{"gpt-4o-mini", "llama-3.1-8b-instruct"},
+				Weight:  1,
+			},
+		},
+	}
+
+	router := &OpenAIRouter{
+		Config: cfg,
+	}
+
+	tests := []struct {
+		name           string
+		path           string
+		expectedModels []string
+		expectedCount  int
+	}{
+		{
+			name:           "GET /v1/models - all models",
+			path:           "/v1/models",
+			expectedModels: []string{"auto", "gpt-4o-mini", "llama-3.1-8b-instruct"},
+			expectedCount:  3,
+		},
+		{
+			name:           "GET /v1/models?model=auto - all models (no filtering implemented)",
+			path:           "/v1/models?model=auto",
+			expectedModels: []string{"auto", "gpt-4o-mini", "llama-3.1-8b-instruct"},
+			expectedCount:  3,
+		},
+		{
+			name:           "GET /v1/models?model=gpt-4o-mini - all models (no filtering)",
+			path:           "/v1/models?model=gpt-4o-mini",
+			expectedModels: []string{"auto", "gpt-4o-mini", "llama-3.1-8b-instruct"},
+			expectedCount:  3,
+		},
+		{
+			name:           "GET /v1/models?model= - all models (empty param)",
+			path:           "/v1/models?model=",
+			expectedModels: []string{"auto", "gpt-4o-mini", "llama-3.1-8b-instruct"},
+			expectedCount:  3,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			response, err := router.handleModelsRequest(tt.path)
+			if err != nil {
+				t.Fatalf("handleModelsRequest failed: %v", err)
+			}
+
+			// Verify it's an immediate response
+			immediateResp := response.GetImmediateResponse()
+			if immediateResp == nil {
+				t.Fatal("Expected immediate response, got nil")
+			}
+
+			// Verify status code is 200 OK
+			if immediateResp.Status.Code != typev3.StatusCode_OK {
+				t.Errorf("Expected status code OK, got %v", immediateResp.Status.Code)
+			}
+
+			// Verify content-type header
+			found := false
+			for _, header := range immediateResp.Headers.SetHeaders {
+				if header.Header.Key == "content-type" {
+					if string(header.Header.RawValue) != "application/json" {
+						t.Errorf("Expected content-type application/json, got %s", string(header.Header.RawValue))
+					}
+					found = true
+					break
+				}
+			}
+			if !found {
+				t.Error("Expected content-type header not found")
+			}
+
+			// Parse response body
+			var modelList OpenAIModelList
+			if err := json.Unmarshal(immediateResp.Body, &modelList); err != nil {
+				t.Fatalf("Failed to parse response body: %v", err)
+			}
+
+			// Verify response structure
+			if modelList.Object != "list" {
+				t.Errorf("Expected object 'list', got %s", modelList.Object)
+			}
+
+			if len(modelList.Data) != tt.expectedCount {
+				t.Errorf("Expected %d models, got %d", tt.expectedCount, len(modelList.Data))
+			}
+
+			// Verify expected models are present
+			modelMap := make(map[string]bool)
+			for _, model := range modelList.Data {
+				modelMap[model.ID] = true
+
+				// Verify model structure
+				if model.Object != "model" {
+					t.Errorf("Expected model object 'model', got %s", model.Object)
+				}
+				if model.Created == 0 {
+					t.Error("Expected non-zero created timestamp")
+				}
+				if model.OwnedBy != "vllm-semantic-router" {
+					t.Errorf("Expected model owned_by 'vllm-semantic-router', got %s", model.OwnedBy)
+				}
+			}
+
+			for _, expectedModel := range tt.expectedModels {
+				if !modelMap[expectedModel] {
+					t.Errorf("Expected model %s not found in response", expectedModel)
+				}
+			}
+		})
+	}
+}
+
+func TestHandleRequestHeadersWithModelsEndpoint(t *testing.T) {
+	// Create a test router
+	cfg := &config.RouterConfig{
+		VLLMEndpoints: []config.VLLMEndpoint{
+			{
+				Name:    "primary",
+				Address: "127.0.0.1",
+				Port:    8000,
+				Models:  []string{"gpt-4o-mini"},
+				Weight:  1,
+			},
+		},
+	}
+
+	router := &OpenAIRouter{
+		Config: cfg,
+	}
+
+	tests := []struct {
+		name            string
+		method          string
+		path            string
+		expectImmediate bool
+	}{
+		{
+			name:            "GET /v1/models - should return immediate response",
+			method:          "GET",
+			path:            "/v1/models",
+			expectImmediate: true,
+		},
+		{
+			name:            "GET /v1/models?model=auto - should return immediate response",
+			method:          "GET",
+			path:            "/v1/models?model=auto",
+			expectImmediate: true,
+		},
+		{
+			name:            "POST /v1/chat/completions - should continue processing",
+			method:          "POST",
+			path:            "/v1/chat/completions",
+			expectImmediate: false,
+		},
+		{
+			name:            "POST /v1/models - should continue processing",
+			method:          "POST",
+			path:            "/v1/models",
+			expectImmediate: false,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			// Create request headers
+			requestHeaders := &ext_proc.ProcessingRequest_RequestHeaders{
+				RequestHeaders: &ext_proc.HttpHeaders{
+					Headers: &core.HeaderMap{
+						Headers: []*core.HeaderValue{
+							{
+								Key:   ":method",
+								Value: tt.method,
+							},
+							{
+								Key:   ":path",
+								Value: tt.path,
+							},
+							{
+								Key:   "content-type",
+								Value: "application/json",
+							},
+						},
+					},
+				},
+			}
+
+			ctx := &RequestContext{
+				Headers: make(map[string]string),
+			}
+
+			response, err := router.handleRequestHeaders(requestHeaders, ctx)
+			if err != nil {
+				t.Fatalf("handleRequestHeaders failed: %v", err)
+			}
+
+			if tt.expectImmediate {
+				// Should return immediate response
+				if response.GetImmediateResponse() == nil {
+					t.Error("Expected immediate response for /v1/models endpoint")
+				}
+			} else {
+				// Should return continue response
+				if response.GetRequestHeaders() == nil {
+					t.Error("Expected request headers response for non-models endpoint")
+				}
+				if response.GetRequestHeaders().Response.Status != ext_proc.CommonResponse_CONTINUE {
+					t.Error("Expected CONTINUE status for non-models endpoint")
+				}
+			}
+		})
+	}
+}
diff --git a/src/semantic-router/pkg/extproc/request_handler.go b/src/semantic-router/pkg/extproc/request_handler.go
index b6efea23..36aad0f3 100644
--- a/src/semantic-router/pkg/extproc/request_handler.go
+++ b/src/semantic-router/pkg/extproc/request_handler.go
@@ -7,6 +7,7 @@ import (
 
 	core "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
 	ext_proc "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3"
+	typev3 "github.com/envoyproxy/go-control-plane/envoy/type/v3"
 	"github.com/openai/openai-go"
 	"google.golang.org/grpc/codes"
 	"google.golang.org/grpc/status"
@@ -209,6 +210,15 @@ func (r *OpenAIRouter) handleRequestHeaders(v *ext_proc.ProcessingRequest_Reques
 		}
 	}
 
+	// Check if this is a GET request to /v1/models
+	method := ctx.Headers[":method"]
+	path := ctx.Headers[":path"]
+
+	if method == "GET" && strings.HasPrefix(path, "/v1/models") {
+		observability.Infof("Handling /v1/models request with path: %s", path)
+		return r.handleModelsRequest(path)
+	}
+
 	// Prepare base response
 	response := &ext_proc.ProcessingResponse{
 		Response: &ext_proc.ProcessingResponse_RequestHeaders{
@@ -821,3 +831,127 @@ func (r *OpenAIRouter) updateRequestWithTools(openAIRequest *openai.ChatCompleti
 
 	return nil
 }
+
+// OpenAIModel represents a single model in the OpenAI /v1/models response
+type OpenAIModel struct {
+	ID      string `json:"id"`
+	Object  string `json:"object"`
+	Created int64  `json:"created"`
+	OwnedBy string `json:"owned_by"`
+}
+
+// OpenAIModelList is the container for the models list response
+type OpenAIModelList struct {
+	Object string        `json:"object"`
+	Data   []OpenAIModel `json:"data"`
+}
+
+// handleModelsRequest handles GET /v1/models requests and returns a direct response
+func (r *OpenAIRouter) handleModelsRequest(path string) (*ext_proc.ProcessingResponse, error) {
+	now := time.Now().Unix()
+
+	// Start with the special "auto" model always available from the router
+	models := []OpenAIModel{
+		{
+			ID:      "auto",
+			Object:  "model",
+			Created: now,
+			OwnedBy: "vllm-semantic-router",
+		},
+	}
+
+	// Append underlying models from config (if available)
+	if r.Config != nil {
+		for _, m := range r.Config.GetAllModels() {
+			// Skip if already added as "auto" (or avoid duplicates in general)
+			if m == "auto" {
+				continue
+			}
+			models = append(models, OpenAIModel{
+				ID:      m,
+				Object:  "model",
+				Created: now,
+				OwnedBy: "vllm-semantic-router",
+			})
+		}
+	}
+
+	resp := OpenAIModelList{
+		Object: "list",
+		Data:   models,
+	}
+
+	return r.createJSONResponse(200, resp), nil
+}
+
+// statusCodeToEnum converts HTTP status code to typev3.StatusCode enum
+func statusCodeToEnum(statusCode int) typev3.StatusCode {
+	switch statusCode {
+	case 200:
+		return typev3.StatusCode_OK
+	case 400:
+		return typev3.StatusCode_BadRequest
+	case 404:
+		return typev3.StatusCode_NotFound
+	case 500:
+		return typev3.StatusCode_InternalServerError
+	default:
+		return typev3.StatusCode_OK
+	}
+}
+
+// createJSONResponseWithBody creates a direct response with pre-marshaled JSON body
+func (r *OpenAIRouter) createJSONResponseWithBody(statusCode int, jsonBody []byte) *ext_proc.ProcessingResponse {
+	return &ext_proc.ProcessingResponse{
+		Response: &ext_proc.ProcessingResponse_ImmediateResponse{
+			ImmediateResponse: &ext_proc.ImmediateResponse{
+				Status: &typev3.HttpStatus{
+					Code: statusCodeToEnum(statusCode),
+				},
+				Headers: &ext_proc.HeaderMutation{
+					SetHeaders: []*core.HeaderValueOption{
+						{
+							Header: &core.HeaderValue{
+								Key:      "content-type",
+								RawValue: []byte("application/json"),
+							},
+						},
+					},
+				},
+				Body: jsonBody,
+			},
+		},
+	}
+}
+
+// createJSONResponse creates a direct response with JSON content
+func (r *OpenAIRouter) createJSONResponse(statusCode int, data interface{}) *ext_proc.ProcessingResponse {
+	jsonData, err := json.Marshal(data)
+	if err != nil {
+		observability.Errorf("Failed to marshal JSON response: %v", err)
+		return r.createErrorResponse(500, "Internal server error")
+	}
+
+	return r.createJSONResponseWithBody(statusCode, jsonData)
+}
+
+// createErrorResponse creates a direct error response
+func (r *OpenAIRouter) createErrorResponse(statusCode int, message string) *ext_proc.ProcessingResponse {
+	errorResp := map[string]interface{}{
+		"error": map[string]interface{}{
+			"message": message,
+			"type":    "invalid_request_error",
+			"code":    statusCode,
+		},
+	}
+
+	jsonData, err := json.Marshal(errorResp)
+	if err != nil {
+		observability.Errorf("Failed to marshal error response: %v", err)
+		jsonData = []byte(`{"error":{"message":"Internal server error","type":"internal_error","code":500}}`)
+		// Use 500 status code for fallback error
+		statusCode = 500
+	}
+
+	return r.createJSONResponseWithBody(statusCode, jsonData)
+}
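For reviewers, a minimal sketch of how a client might exercise the new endpoint once the router is running behind Envoy. The listener address (http://localhost:8801) and the lowercase mirror structs are illustrative assumptions, not part of the change; the response shape follows the OpenAIModelList added in request_handler.go.

```go
package main

import (
	"encoding/json"
	"fmt"
	"net/http"
)

// Local mirrors of the response types added in request_handler.go.
type openAIModel struct {
	ID      string `json:"id"`
	Object  string `json:"object"`
	Created int64  `json:"created"`
	OwnedBy string `json:"owned_by"`
}

type openAIModelList struct {
	Object string        `json:"object"`
	Data   []openAIModel `json:"data"`
}

func main() {
	// Assumed Envoy listener address; adjust to the actual deployment.
	resp, err := http.Get("http://localhost:8801/v1/models")
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	var models openAIModelList
	if err := json.NewDecoder(resp.Body).Decode(&models); err != nil {
		panic(err)
	}

	// Expect "auto" plus every model declared under vllm_endpoints in the config.
	for _, m := range models.Data {
		fmt.Printf("%s (owned_by=%s)\n", m.ID, m.OwnedBy)
	}
}
```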