Skip to content

Commit 9cd696b

Browse files
Copilotrootfs
andcommitted
Add task_type validation and API discovery endpoint
- Add validateTaskType helper function to validate task_type parameter - Reject invalid task_type values with 400 error and helpful message - Add GET /api/v1 endpoint for API discovery - Return comprehensive API overview with endpoints, task_types, and links - Add tests for invalid task_type values (jailbreak, invalid_type) - Add tests for valid task_types (intent, pii, security, all) - Add test for API overview endpoint Co-authored-by: rootfs <[email protected]>
1 parent da3fd5f commit 9cd696b

File tree

2 files changed

+261
-0
lines changed

2 files changed

+261
-0
lines changed

src/semantic-router/pkg/api/server.go

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,9 @@ func (s *ClassificationAPIServer) setupRoutes() *http.ServeMux {
184184
// Health check endpoint
185185
mux.HandleFunc("GET /health", s.handleHealth)
186186

187+
// API discovery endpoint
188+
mux.HandleFunc("GET /api/v1", s.handleAPIOverview)
189+
187190
// Classification endpoints
188191
mux.HandleFunc("POST /api/v1/classify/intent", s.handleIntentClassification)
189192
mux.HandleFunc("POST /api/v1/classify/pii", s.handlePIIDetection)
@@ -224,6 +227,105 @@ func (s *ClassificationAPIServer) handleHealth(w http.ResponseWriter, r *http.Re
224227
w.Write([]byte(`{"status": "healthy", "service": "classification-api"}`))
225228
}
226229

230+
// APIOverviewResponse represents the response for GET /api/v1
231+
type APIOverviewResponse struct {
232+
Service string `json:"service"`
233+
Version string `json:"version"`
234+
Description string `json:"description"`
235+
Endpoints []EndpointInfo `json:"endpoints"`
236+
TaskTypes []TaskTypeInfo `json:"task_types"`
237+
Links map[string]string `json:"links"`
238+
}
239+
240+
// EndpointInfo represents information about an API endpoint
241+
type EndpointInfo struct {
242+
Path string `json:"path"`
243+
Method string `json:"method"`
244+
Description string `json:"description"`
245+
}
246+
247+
// TaskTypeInfo represents information about a task type
248+
type TaskTypeInfo struct {
249+
Name string `json:"name"`
250+
Description string `json:"description"`
251+
}
252+
253+
// handleAPIOverview handles GET /api/v1 for API discovery
254+
func (s *ClassificationAPIServer) handleAPIOverview(w http.ResponseWriter, r *http.Request) {
255+
response := APIOverviewResponse{
256+
Service: "Semantic Router Classification API",
257+
Version: "v1",
258+
Description: "API for intent classification, PII detection, and security analysis",
259+
Endpoints: []EndpointInfo{
260+
{
261+
Path: "/api/v1/classify/intent",
262+
Method: "POST",
263+
Description: "Classify user queries into routing categories",
264+
},
265+
{
266+
Path: "/api/v1/classify/pii",
267+
Method: "POST",
268+
Description: "Detect personally identifiable information in text",
269+
},
270+
{
271+
Path: "/api/v1/classify/security",
272+
Method: "POST",
273+
Description: "Detect jailbreak attempts and security threats",
274+
},
275+
{
276+
Path: "/api/v1/classify/combined",
277+
Method: "POST",
278+
Description: "Perform combined classification (intent, PII, and security)",
279+
},
280+
{
281+
Path: "/api/v1/classify/batch",
282+
Method: "POST",
283+
Description: "Batch classification with configurable task_type parameter",
284+
},
285+
{
286+
Path: "/health",
287+
Method: "GET",
288+
Description: "Health check endpoint",
289+
},
290+
{
291+
Path: "/info/models",
292+
Method: "GET",
293+
Description: "Get information about loaded models",
294+
},
295+
{
296+
Path: "/v1/models",
297+
Method: "GET",
298+
Description: "OpenAI-compatible model listing",
299+
},
300+
},
301+
TaskTypes: []TaskTypeInfo{
302+
{
303+
Name: "intent",
304+
Description: "Intent/category classification (default for batch endpoint)",
305+
},
306+
{
307+
Name: "pii",
308+
Description: "Personally Identifiable Information detection",
309+
},
310+
{
311+
Name: "security",
312+
Description: "Jailbreak and security threat detection",
313+
},
314+
{
315+
Name: "all",
316+
Description: "All classification types combined",
317+
},
318+
},
319+
Links: map[string]string{
320+
"documentation": "https://vllm-project.github.io/semantic-router/",
321+
"models_info": "/info/models",
322+
"health": "/health",
323+
},
324+
}
325+
326+
s.writeJSONResponse(w, http.StatusOK, response)
327+
}
328+
227329
// handleIntentClassification handles intent classification requests
228330
func (s *ClassificationAPIServer) handleIntentClassification(w http.ResponseWriter, r *http.Request) {
229331
var req services.IntentRequest
@@ -335,6 +437,13 @@ func (s *ClassificationAPIServer) handleBatchClassification(w http.ResponseWrite
335437
return
336438
}
337439

440+
// Validate task_type if provided
441+
if err := validateTaskType(req.TaskType); err != nil {
442+
metrics.RecordBatchClassificationError("unified", "invalid_task_type")
443+
s.writeErrorResponse(w, http.StatusBadRequest, "INVALID_TASK_TYPE", err.Error())
444+
return
445+
}
446+
338447
// Record the number of texts being processed
339448
metrics.RecordBatchClassificationTexts("unified", len(req.Texts))
340449

@@ -622,6 +731,24 @@ func (s *ClassificationAPIServer) getSystemInfo() SystemInfo {
622731
}
623732
}
624733

734+
// validateTaskType validates the task_type parameter for batch classification
735+
// Returns an error if the task_type is invalid, nil if valid or empty
736+
func validateTaskType(taskType string) error {
737+
// Empty task_type defaults to "intent", so it's valid
738+
if taskType == "" {
739+
return nil
740+
}
741+
742+
validTaskTypes := []string{"intent", "pii", "security", "all"}
743+
for _, valid := range validTaskTypes {
744+
if taskType == valid {
745+
return nil
746+
}
747+
}
748+
749+
return fmt.Errorf("invalid task_type '%s'. Supported values: %v", taskType, validTaskTypes)
750+
}
751+
625752
// extractRequestedResults converts unified results to batch format based on task type
626753
func (s *ClassificationAPIServer) extractRequestedResults(unifiedResults *services.UnifiedBatchResponse, taskType string, options *ClassificationOptions) []BatchClassificationResult {
627754
// Determine the correct batch size based on task type

src/semantic-router/pkg/api/server_test.go

Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,59 @@ func TestHandleBatchClassification(t *testing.T) {
3434
expectedStatus: http.StatusServiceUnavailable,
3535
expectedError: "Batch classification requires unified classifier. Please ensure models are available in ./models/ directory.",
3636
},
37+
{
38+
name: "Invalid task_type - jailbreak",
39+
requestBody: `{
40+
"texts": ["test text"],
41+
"task_type": "jailbreak"
42+
}`,
43+
expectedStatus: http.StatusBadRequest,
44+
expectedError: "invalid task_type 'jailbreak'. Supported values: [intent pii security all]",
45+
},
46+
{
47+
name: "Invalid task_type - random",
48+
requestBody: `{
49+
"texts": ["test text"],
50+
"task_type": "invalid_type"
51+
}`,
52+
expectedStatus: http.StatusBadRequest,
53+
expectedError: "invalid task_type 'invalid_type'. Supported values: [intent pii security all]",
54+
},
55+
{
56+
name: "Valid task_type - pii",
57+
requestBody: `{
58+
"texts": ["test text"],
59+
"task_type": "pii"
60+
}`,
61+
expectedStatus: http.StatusServiceUnavailable,
62+
expectedError: "Batch classification requires unified classifier. Please ensure models are available in ./models/ directory.",
63+
},
64+
{
65+
name: "Valid task_type - security",
66+
requestBody: `{
67+
"texts": ["test text"],
68+
"task_type": "security"
69+
}`,
70+
expectedStatus: http.StatusServiceUnavailable,
71+
expectedError: "Batch classification requires unified classifier. Please ensure models are available in ./models/ directory.",
72+
},
73+
{
74+
name: "Valid task_type - all",
75+
requestBody: `{
76+
"texts": ["test text"],
77+
"task_type": "all"
78+
}`,
79+
expectedStatus: http.StatusServiceUnavailable,
80+
expectedError: "Batch classification requires unified classifier. Please ensure models are available in ./models/ directory.",
81+
},
82+
{
83+
name: "Empty task_type defaults to intent",
84+
requestBody: `{
85+
"texts": ["test text"]
86+
}`,
87+
expectedStatus: http.StatusServiceUnavailable,
88+
expectedError: "Batch classification requires unified classifier. Please ensure models are available in ./models/ directory.",
89+
},
3790
{
3891
name: "Valid large batch",
3992
requestBody: func() string {
@@ -731,3 +784,84 @@ func TestSetupRoutesSecurityBehavior(t *testing.T) {
731784
})
732785
}
733786
}
787+
788+
// TestAPIOverviewEndpoint tests the API discovery endpoint
789+
func TestAPIOverviewEndpoint(t *testing.T) {
790+
apiServer := &ClassificationAPIServer{
791+
classificationSvc: services.NewPlaceholderClassificationService(),
792+
config: &config.RouterConfig{},
793+
}
794+
795+
req := httptest.NewRequest("GET", "/api/v1", nil)
796+
rr := httptest.NewRecorder()
797+
798+
apiServer.handleAPIOverview(rr, req)
799+
800+
if rr.Code != http.StatusOK {
801+
t.Fatalf("Expected 200 OK, got %d", rr.Code)
802+
}
803+
804+
var response APIOverviewResponse
805+
if err := json.Unmarshal(rr.Body.Bytes(), &response); err != nil {
806+
t.Fatalf("Failed to unmarshal response: %v", err)
807+
}
808+
809+
// Verify the response structure
810+
if response.Service == "" {
811+
t.Error("Expected non-empty service name")
812+
}
813+
814+
if response.Version != "v1" {
815+
t.Errorf("Expected version 'v1', got '%s'", response.Version)
816+
}
817+
818+
// Check that we have endpoints listed
819+
if len(response.Endpoints) == 0 {
820+
t.Error("Expected at least one endpoint")
821+
}
822+
823+
// Check that we have task types listed
824+
expectedTaskTypes := map[string]bool{
825+
"intent": false,
826+
"pii": false,
827+
"security": false,
828+
"all": false,
829+
}
830+
831+
for _, taskType := range response.TaskTypes {
832+
if _, exists := expectedTaskTypes[taskType.Name]; exists {
833+
expectedTaskTypes[taskType.Name] = true
834+
}
835+
}
836+
837+
for taskType, found := range expectedTaskTypes {
838+
if !found {
839+
t.Errorf("Expected to find task_type '%s' in response", taskType)
840+
}
841+
}
842+
843+
// Check that we have links
844+
if len(response.Links) == 0 {
845+
t.Error("Expected at least one link")
846+
}
847+
848+
// Verify specific endpoints are present
849+
endpointPaths := make(map[string]bool)
850+
for _, endpoint := range response.Endpoints {
851+
endpointPaths[endpoint.Path] = true
852+
}
853+
854+
requiredPaths := []string{
855+
"/api/v1/classify/intent",
856+
"/api/v1/classify/pii",
857+
"/api/v1/classify/security",
858+
"/api/v1/classify/batch",
859+
"/health",
860+
}
861+
862+
for _, path := range requiredPaths {
863+
if !endpointPaths[path] {
864+
t.Errorf("Expected to find endpoint '%s' in response", path)
865+
}
866+
}
867+
}

0 commit comments

Comments
 (0)