diff --git a/src/semantic-router/pkg/api/server.go b/src/semantic-router/pkg/api/server.go index d0611ef1..b41429b6 100644 --- a/src/semantic-router/pkg/api/server.go +++ b/src/semantic-router/pkg/api/server.go @@ -184,6 +184,13 @@ func (s *ClassificationAPIServer) setupRoutes() *http.ServeMux { // Health check endpoint mux.HandleFunc("GET /health", s.handleHealth) + // API discovery endpoint + mux.HandleFunc("GET /api/v1", s.handleAPIOverview) + + // OpenAPI and documentation endpoints + mux.HandleFunc("GET /openapi.json", s.handleOpenAPISpec) + mux.HandleFunc("GET /docs", s.handleSwaggerUI) + // Classification endpoints mux.HandleFunc("POST /api/v1/classify/intent", s.handleIntentClassification) mux.HandleFunc("POST /api/v1/classify/pii", s.handlePIIDetection) @@ -224,6 +231,323 @@ func (s *ClassificationAPIServer) handleHealth(w http.ResponseWriter, r *http.Re w.Write([]byte(`{"status": "healthy", "service": "classification-api"}`)) } +// APIOverviewResponse represents the response for GET /api/v1 +type APIOverviewResponse struct { + Service string `json:"service"` + Version string `json:"version"` + Description string `json:"description"` + Endpoints []EndpointInfo `json:"endpoints"` + TaskTypes []TaskTypeInfo `json:"task_types"` + Links map[string]string `json:"links"` +} + +// EndpointInfo represents information about an API endpoint +type EndpointInfo struct { + Path string `json:"path"` + Method string `json:"method"` + Description string `json:"description"` +} + +// TaskTypeInfo represents information about a task type +type TaskTypeInfo struct { + Name string `json:"name"` + Description string `json:"description"` +} + +// EndpointMetadata stores metadata about an endpoint for API documentation +type EndpointMetadata struct { + Path string + Method string + Description string +} + +// endpointRegistry is a centralized registry of all API endpoints with their metadata +var endpointRegistry = []EndpointMetadata{ + {Path: "/health", Method: "GET", Description: "Health check endpoint"}, + {Path: "/api/v1", Method: "GET", Description: "API discovery and documentation"}, + {Path: "/openapi.json", Method: "GET", Description: "OpenAPI 3.0 specification"}, + {Path: "/docs", Method: "GET", Description: "Interactive Swagger UI documentation"}, + {Path: "/api/v1/classify/intent", Method: "POST", Description: "Classify user queries into routing categories"}, + {Path: "/api/v1/classify/pii", Method: "POST", Description: "Detect personally identifiable information in text"}, + {Path: "/api/v1/classify/security", Method: "POST", Description: "Detect jailbreak attempts and security threats"}, + {Path: "/api/v1/classify/combined", Method: "POST", Description: "Perform combined classification (intent, PII, and security)"}, + {Path: "/api/v1/classify/batch", Method: "POST", Description: "Batch classification with configurable task_type parameter"}, + {Path: "/info/models", Method: "GET", Description: "Get information about loaded models"}, + {Path: "/info/classifier", Method: "GET", Description: "Get classifier information and status"}, + {Path: "/v1/models", Method: "GET", Description: "OpenAI-compatible model listing"}, + {Path: "/metrics/classification", Method: "GET", Description: "Get classification metrics and statistics"}, + {Path: "/config/classification", Method: "GET", Description: "Get classification configuration"}, + {Path: "/config/classification", Method: "PUT", Description: "Update classification configuration"}, + {Path: "/config/system-prompts", Method: "GET", Description: "Get system prompt configuration (requires explicit enablement)"}, + {Path: "/config/system-prompts", Method: "PUT", Description: "Update system prompt configuration (requires explicit enablement)"}, +} + +// taskTypeRegistry is a centralized registry of all supported task types +var taskTypeRegistry = []TaskTypeInfo{ + {Name: "intent", Description: "Intent/category classification (default for batch endpoint)"}, + {Name: "pii", Description: "Personally Identifiable Information detection"}, + {Name: "security", Description: "Jailbreak and security threat detection"}, + {Name: "all", Description: "All classification types combined"}, +} + +// OpenAPI 3.0 spec structures + +// OpenAPISpec represents an OpenAPI 3.0 specification +type OpenAPISpec struct { + OpenAPI string `json:"openapi"` + Info OpenAPIInfo `json:"info"` + Servers []OpenAPIServer `json:"servers"` + Paths map[string]OpenAPIPath `json:"paths"` + Components OpenAPIComponents `json:"components,omitempty"` +} + +// OpenAPIInfo contains API metadata +type OpenAPIInfo struct { + Title string `json:"title"` + Description string `json:"description"` + Version string `json:"version"` +} + +// OpenAPIServer describes a server +type OpenAPIServer struct { + URL string `json:"url"` + Description string `json:"description"` +} + +// OpenAPIPath represents operations for a path +type OpenAPIPath struct { + Get *OpenAPIOperation `json:"get,omitempty"` + Post *OpenAPIOperation `json:"post,omitempty"` + Put *OpenAPIOperation `json:"put,omitempty"` + Delete *OpenAPIOperation `json:"delete,omitempty"` +} + +// OpenAPIOperation describes an API operation +type OpenAPIOperation struct { + Summary string `json:"summary"` + Description string `json:"description,omitempty"` + OperationID string `json:"operationId,omitempty"` + Responses map[string]OpenAPIResponse `json:"responses"` + RequestBody *OpenAPIRequestBody `json:"requestBody,omitempty"` +} + +// OpenAPIResponse describes a response +type OpenAPIResponse struct { + Description string `json:"description"` + Content map[string]OpenAPIMedia `json:"content,omitempty"` +} + +// OpenAPIRequestBody describes a request body +type OpenAPIRequestBody struct { + Description string `json:"description,omitempty"` + Required bool `json:"required,omitempty"` + Content map[string]OpenAPIMedia `json:"content"` +} + +// OpenAPIMedia describes media type content +type OpenAPIMedia struct { + Schema *OpenAPISchema `json:"schema,omitempty"` +} + +// OpenAPISchema describes a schema +type OpenAPISchema struct { + Type string `json:"type,omitempty"` + Properties map[string]OpenAPISchema `json:"properties,omitempty"` + Items *OpenAPISchema `json:"items,omitempty"` + Ref string `json:"$ref,omitempty"` +} + +// OpenAPIComponents contains reusable components +type OpenAPIComponents struct { + Schemas map[string]OpenAPISchema `json:"schemas,omitempty"` +} + +// handleAPIOverview handles GET /api/v1 for API discovery +func (s *ClassificationAPIServer) handleAPIOverview(w http.ResponseWriter, r *http.Request) { + // Build endpoints list from registry, filtering out disabled endpoints + endpoints := make([]EndpointInfo, 0, len(endpointRegistry)) + for _, metadata := range endpointRegistry { + // Filter out system prompt endpoints if they are disabled + if !s.enableSystemPromptAPI && (metadata.Path == "/config/system-prompts") { + continue + } + endpoints = append(endpoints, EndpointInfo{ + Path: metadata.Path, + Method: metadata.Method, + Description: metadata.Description, + }) + } + + response := APIOverviewResponse{ + Service: "Semantic Router Classification API", + Version: "v1", + Description: "API for intent classification, PII detection, and security analysis", + Endpoints: endpoints, + TaskTypes: taskTypeRegistry, + Links: map[string]string{ + "documentation": "https://vllm-project.github.io/semantic-router/", + "openapi_spec": "/openapi.json", + "swagger_ui": "/docs", + "models_info": "/info/models", + "health": "/health", + }, + } + + s.writeJSONResponse(w, http.StatusOK, response) +} + +// generateOpenAPISpec generates an OpenAPI 3.0 specification from the endpoint registry +func (s *ClassificationAPIServer) generateOpenAPISpec() OpenAPISpec { + spec := OpenAPISpec{ + OpenAPI: "3.0.0", + Info: OpenAPIInfo{ + Title: "Semantic Router Classification API", + Description: "API for intent classification, PII detection, and security analysis", + Version: "v1", + }, + Servers: []OpenAPIServer{ + { + URL: "/", + Description: "Classification API Server", + }, + }, + Paths: make(map[string]OpenAPIPath), + } + + // Generate paths from endpoint registry + for _, endpoint := range endpointRegistry { + // Filter out system prompt endpoints if they are disabled + if !s.enableSystemPromptAPI && endpoint.Path == "/config/system-prompts" { + continue + } + + path, ok := spec.Paths[endpoint.Path] + if !ok { + path = OpenAPIPath{} + } + + operation := &OpenAPIOperation{ + Summary: endpoint.Description, + Description: endpoint.Description, + OperationID: fmt.Sprintf("%s_%s", endpoint.Method, endpoint.Path), + Responses: map[string]OpenAPIResponse{ + "200": { + Description: "Successful response", + Content: map[string]OpenAPIMedia{ + "application/json": { + Schema: &OpenAPISchema{ + Type: "object", + }, + }, + }, + }, + "400": { + Description: "Bad request", + Content: map[string]OpenAPIMedia{ + "application/json": { + Schema: &OpenAPISchema{ + Type: "object", + Properties: map[string]OpenAPISchema{ + "error": { + Type: "object", + Properties: map[string]OpenAPISchema{ + "code": {Type: "string"}, + "message": {Type: "string"}, + "timestamp": {Type: "string"}, + }, + }, + }, + }, + }, + }, + }, + }, + } + + // Add request body for POST and PUT methods + if endpoint.Method == "POST" || endpoint.Method == "PUT" { + operation.RequestBody = &OpenAPIRequestBody{ + Required: true, + Content: map[string]OpenAPIMedia{ + "application/json": { + Schema: &OpenAPISchema{ + Type: "object", + }, + }, + }, + } + } + + // Map operation to the appropriate method + switch endpoint.Method { + case "GET": + path.Get = operation + case "POST": + path.Post = operation + case "PUT": + path.Put = operation + case "DELETE": + path.Delete = operation + } + + spec.Paths[endpoint.Path] = path + } + + return spec +} + +// handleOpenAPISpec serves the OpenAPI 3.0 specification at /openapi.json +func (s *ClassificationAPIServer) handleOpenAPISpec(w http.ResponseWriter, r *http.Request) { + spec := s.generateOpenAPISpec() + s.writeJSONResponse(w, http.StatusOK, spec) +} + +// handleSwaggerUI serves the Swagger UI at /docs +func (s *ClassificationAPIServer) handleSwaggerUI(w http.ResponseWriter, r *http.Request) { + // Serve a simple HTML page that loads Swagger UI from CDN + html := ` + + + + + Semantic Router API Documentation + + + + +
+ + + + +` + + w.Header().Set("Content-Type", "text/html; charset=utf-8") + w.WriteHeader(http.StatusOK) + w.Write([]byte(html)) +} + // handleIntentClassification handles intent classification requests func (s *ClassificationAPIServer) handleIntentClassification(w http.ResponseWriter, r *http.Request) { var req services.IntentRequest @@ -335,6 +659,13 @@ func (s *ClassificationAPIServer) handleBatchClassification(w http.ResponseWrite return } + // Validate task_type if provided + if err := validateTaskType(req.TaskType); err != nil { + metrics.RecordBatchClassificationError("unified", "invalid_task_type") + s.writeErrorResponse(w, http.StatusBadRequest, "INVALID_TASK_TYPE", err.Error()) + return + } + // Record the number of texts being processed metrics.RecordBatchClassificationTexts("unified", len(req.Texts)) @@ -622,6 +953,24 @@ func (s *ClassificationAPIServer) getSystemInfo() SystemInfo { } } +// validateTaskType validates the task_type parameter for batch classification +// Returns an error if the task_type is invalid, nil if valid or empty +func validateTaskType(taskType string) error { + // Empty task_type defaults to "intent", so it's valid + if taskType == "" { + return nil + } + + validTaskTypes := []string{"intent", "pii", "security", "all"} + for _, valid := range validTaskTypes { + if taskType == valid { + return nil + } + } + + return fmt.Errorf("invalid task_type '%s'. Supported values: %v", taskType, validTaskTypes) +} + // extractRequestedResults converts unified results to batch format based on task type func (s *ClassificationAPIServer) extractRequestedResults(unifiedResults *services.UnifiedBatchResponse, taskType string, options *ClassificationOptions) []BatchClassificationResult { // Determine the correct batch size based on task type diff --git a/src/semantic-router/pkg/api/server_test.go b/src/semantic-router/pkg/api/server_test.go index 450b3d20..aaf4e005 100644 --- a/src/semantic-router/pkg/api/server_test.go +++ b/src/semantic-router/pkg/api/server_test.go @@ -34,6 +34,59 @@ func TestHandleBatchClassification(t *testing.T) { expectedStatus: http.StatusServiceUnavailable, expectedError: "Batch classification requires unified classifier. Please ensure models are available in ./models/ directory.", }, + { + name: "Invalid task_type - jailbreak", + requestBody: `{ + "texts": ["test text"], + "task_type": "jailbreak" + }`, + expectedStatus: http.StatusBadRequest, + expectedError: "invalid task_type 'jailbreak'. Supported values: [intent pii security all]", + }, + { + name: "Invalid task_type - random", + requestBody: `{ + "texts": ["test text"], + "task_type": "invalid_type" + }`, + expectedStatus: http.StatusBadRequest, + expectedError: "invalid task_type 'invalid_type'. Supported values: [intent pii security all]", + }, + { + name: "Valid task_type - pii", + requestBody: `{ + "texts": ["test text"], + "task_type": "pii" + }`, + expectedStatus: http.StatusServiceUnavailable, + expectedError: "Batch classification requires unified classifier. Please ensure models are available in ./models/ directory.", + }, + { + name: "Valid task_type - security", + requestBody: `{ + "texts": ["test text"], + "task_type": "security" + }`, + expectedStatus: http.StatusServiceUnavailable, + expectedError: "Batch classification requires unified classifier. Please ensure models are available in ./models/ directory.", + }, + { + name: "Valid task_type - all", + requestBody: `{ + "texts": ["test text"], + "task_type": "all" + }`, + expectedStatus: http.StatusServiceUnavailable, + expectedError: "Batch classification requires unified classifier. Please ensure models are available in ./models/ directory.", + }, + { + name: "Empty task_type defaults to intent", + requestBody: `{ + "texts": ["test text"] + }`, + expectedStatus: http.StatusServiceUnavailable, + expectedError: "Batch classification requires unified classifier. Please ensure models are available in ./models/ directory.", + }, { name: "Valid large batch", requestBody: func() string { @@ -731,3 +784,299 @@ func TestSetupRoutesSecurityBehavior(t *testing.T) { }) } } + +// TestAPIOverviewEndpoint tests the API discovery endpoint +func TestAPIOverviewEndpoint(t *testing.T) { + apiServer := &ClassificationAPIServer{ + classificationSvc: services.NewPlaceholderClassificationService(), + config: &config.RouterConfig{}, + } + + req := httptest.NewRequest("GET", "/api/v1", nil) + rr := httptest.NewRecorder() + + apiServer.handleAPIOverview(rr, req) + + if rr.Code != http.StatusOK { + t.Fatalf("Expected 200 OK, got %d", rr.Code) + } + + var response APIOverviewResponse + if err := json.Unmarshal(rr.Body.Bytes(), &response); err != nil { + t.Fatalf("Failed to unmarshal response: %v", err) + } + + // Verify the response structure + if response.Service == "" { + t.Error("Expected non-empty service name") + } + + if response.Version != "v1" { + t.Errorf("Expected version 'v1', got '%s'", response.Version) + } + + // Check that we have endpoints listed + if len(response.Endpoints) == 0 { + t.Error("Expected at least one endpoint") + } + + // Check that we have task types listed + expectedTaskTypes := map[string]bool{ + "intent": false, + "pii": false, + "security": false, + "all": false, + } + + for _, taskType := range response.TaskTypes { + if _, exists := expectedTaskTypes[taskType.Name]; exists { + expectedTaskTypes[taskType.Name] = true + } + } + + for taskType, found := range expectedTaskTypes { + if !found { + t.Errorf("Expected to find task_type '%s' in response", taskType) + } + } + + // Check that we have links + if len(response.Links) == 0 { + t.Error("Expected at least one link") + } + + // Verify specific endpoints are present + endpointPaths := make(map[string]bool) + for _, endpoint := range response.Endpoints { + endpointPaths[endpoint.Path] = true + } + + requiredPaths := []string{ + "/api/v1/classify/intent", + "/api/v1/classify/pii", + "/api/v1/classify/security", + "/api/v1/classify/batch", + "/health", + } + + for _, path := range requiredPaths { + if !endpointPaths[path] { + t.Errorf("Expected to find endpoint '%s' in response", path) + } + } + + // Verify system prompt endpoints are not included when disabled (default) + if endpointPaths["/config/system-prompts"] { + t.Error("Expected system prompt endpoints to be excluded when enableSystemPromptAPI is false") + } +} + +// TestAPIOverviewEndpointWithSystemPrompts tests API discovery with system prompts enabled +func TestAPIOverviewEndpointWithSystemPrompts(t *testing.T) { + apiServer := &ClassificationAPIServer{ + classificationSvc: services.NewPlaceholderClassificationService(), + config: &config.RouterConfig{}, + enableSystemPromptAPI: true, + } + + req := httptest.NewRequest("GET", "/api/v1", nil) + rr := httptest.NewRecorder() + + apiServer.handleAPIOverview(rr, req) + + if rr.Code != http.StatusOK { + t.Fatalf("Expected 200 OK, got %d", rr.Code) + } + + var response APIOverviewResponse + if err := json.Unmarshal(rr.Body.Bytes(), &response); err != nil { + t.Fatalf("Failed to unmarshal response: %v", err) + } + + // Verify system prompt endpoints are included when enabled + endpointPaths := make(map[string]bool) + for _, endpoint := range response.Endpoints { + endpointPaths[endpoint.Path] = true + } + + if !endpointPaths["/config/system-prompts"] { + t.Error("Expected system prompt endpoints to be included when enableSystemPromptAPI is true") + } +} + +// TestOpenAPISpecEndpoint tests the OpenAPI specification endpoint +func TestOpenAPISpecEndpoint(t *testing.T) { + apiServer := &ClassificationAPIServer{ + classificationSvc: services.NewPlaceholderClassificationService(), + config: &config.RouterConfig{}, + } + + req := httptest.NewRequest("GET", "/openapi.json", nil) + rr := httptest.NewRecorder() + + apiServer.handleOpenAPISpec(rr, req) + + if rr.Code != http.StatusOK { + t.Fatalf("Expected 200 OK, got %d", rr.Code) + } + + // Check Content-Type + contentType := rr.Header().Get("Content-Type") + if contentType != "application/json" { + t.Errorf("Expected Content-Type 'application/json', got '%s'", contentType) + } + + var spec OpenAPISpec + if err := json.Unmarshal(rr.Body.Bytes(), &spec); err != nil { + t.Fatalf("Failed to unmarshal OpenAPI spec: %v", err) + } + + // Verify the OpenAPI version + if spec.OpenAPI != "3.0.0" { + t.Errorf("Expected OpenAPI version '3.0.0', got '%s'", spec.OpenAPI) + } + + // Verify the info + if spec.Info.Title == "" { + t.Error("Expected non-empty title") + } + + if spec.Info.Version != "v1" { + t.Errorf("Expected version 'v1', got '%s'", spec.Info.Version) + } + + // Verify paths are present + if len(spec.Paths) == 0 { + t.Error("Expected at least one path in OpenAPI spec") + } + + // Check that key endpoints are documented + requiredPaths := []string{ + "/health", + "/api/v1", + "/api/v1/classify/batch", + "/openapi.json", + "/docs", + } + + for _, path := range requiredPaths { + if _, exists := spec.Paths[path]; !exists { + t.Errorf("Expected path '%s' to be in OpenAPI spec", path) + } + } + + // Verify system prompt endpoints are not included when disabled + if _, exists := spec.Paths["/config/system-prompts"]; exists { + t.Error("Expected system prompt endpoints to be excluded from OpenAPI spec when disabled") + } +} + +// TestOpenAPISpecWithSystemPrompts tests OpenAPI spec generation with system prompts enabled +func TestOpenAPISpecWithSystemPrompts(t *testing.T) { + apiServer := &ClassificationAPIServer{ + classificationSvc: services.NewPlaceholderClassificationService(), + config: &config.RouterConfig{}, + enableSystemPromptAPI: true, + } + + req := httptest.NewRequest("GET", "/openapi.json", nil) + rr := httptest.NewRecorder() + + apiServer.handleOpenAPISpec(rr, req) + + if rr.Code != http.StatusOK { + t.Fatalf("Expected 200 OK, got %d", rr.Code) + } + + var spec OpenAPISpec + if err := json.Unmarshal(rr.Body.Bytes(), &spec); err != nil { + t.Fatalf("Failed to unmarshal OpenAPI spec: %v", err) + } + + // Verify system prompt endpoints are included when enabled + if _, exists := spec.Paths["/config/system-prompts"]; !exists { + t.Error("Expected system prompt endpoints to be included in OpenAPI spec when enabled") + } +} + +// TestSwaggerUIEndpoint tests the Swagger UI endpoint +func TestSwaggerUIEndpoint(t *testing.T) { + apiServer := &ClassificationAPIServer{ + classificationSvc: services.NewPlaceholderClassificationService(), + config: &config.RouterConfig{}, + } + + req := httptest.NewRequest("GET", "/docs", nil) + rr := httptest.NewRecorder() + + apiServer.handleSwaggerUI(rr, req) + + if rr.Code != http.StatusOK { + t.Fatalf("Expected 200 OK, got %d", rr.Code) + } + + // Check Content-Type + contentType := rr.Header().Get("Content-Type") + if contentType != "text/html; charset=utf-8" { + t.Errorf("Expected Content-Type 'text/html; charset=utf-8', got '%s'", contentType) + } + + // Check that the HTML contains Swagger UI references + html := rr.Body.String() + if !bytes.Contains([]byte(html), []byte("swagger-ui")) { + t.Error("Expected HTML to contain 'swagger-ui'") + } + + if !bytes.Contains([]byte(html), []byte("/openapi.json")) { + t.Error("Expected HTML to reference '/openapi.json'") + } + + if !bytes.Contains([]byte(html), []byte("SwaggerUIBundle")) { + t.Error("Expected HTML to contain 'SwaggerUIBundle'") + } +} + +// TestAPIOverviewIncludesNewEndpoints tests that API overview includes new documentation endpoints +func TestAPIOverviewIncludesNewEndpoints(t *testing.T) { + apiServer := &ClassificationAPIServer{ + classificationSvc: services.NewPlaceholderClassificationService(), + config: &config.RouterConfig{}, + } + + req := httptest.NewRequest("GET", "/api/v1", nil) + rr := httptest.NewRecorder() + + apiServer.handleAPIOverview(rr, req) + + if rr.Code != http.StatusOK { + t.Fatalf("Expected 200 OK, got %d", rr.Code) + } + + var response APIOverviewResponse + if err := json.Unmarshal(rr.Body.Bytes(), &response); err != nil { + t.Fatalf("Failed to unmarshal response: %v", err) + } + + // Verify new documentation endpoints are included + endpointPaths := make(map[string]bool) + for _, endpoint := range response.Endpoints { + endpointPaths[endpoint.Path] = true + } + + if !endpointPaths["/openapi.json"] { + t.Error("Expected '/openapi.json' to be in API overview") + } + + if !endpointPaths["/docs"] { + t.Error("Expected '/docs' to be in API overview") + } + + // Verify links include new documentation endpoints + if response.Links["openapi_spec"] != "/openapi.json" { + t.Error("Expected 'openapi_spec' link to '/openapi.json'") + } + + if response.Links["swagger_ui"] != "/docs" { + t.Error("Expected 'swagger_ui' link to '/docs'") + } +}