From 2f5ba58dfc897ac484dc90fdce54517838b4a8ac Mon Sep 17 00:00:00 2001 From: OneZero-Y Date: Wed, 3 Sep 2025 22:14:55 +0800 Subject: [PATCH 1/3] feat: implement batch classification API Signed-off-by: OneZero-Y --- config/config.yaml | 9 +- src/semantic-router/pkg/api/server.go | 196 ++++++++++++- src/semantic-router/pkg/api/server_test.go | 307 +++++++++++++++++++++ src/semantic-router/pkg/config/config.go | 18 ++ 4 files changed, 528 insertions(+), 2 deletions(-) create mode 100644 src/semantic-router/pkg/api/server_test.go diff --git a/config/config.yaml b/config/config.yaml index fadbf30a..95df80fa 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -237,4 +237,11 @@ categories: - model: phi4 score: 0.2 default_model: mistral-small3.1 -default_reasoning_effort: medium # Default reasoning effort level (low, medium, high) \ No newline at end of file +default_reasoning_effort: medium # Default reasoning effort level (low, medium, high) + +# API Configuration +api: + batch_classification: + max_batch_size: 100 # Maximum number of texts in a single batch + concurrency_threshold: 5 # Switch to concurrent processing when batch size > this value + max_concurrency: 8 # Maximum number of concurrent goroutines \ No newline at end of file diff --git a/src/semantic-router/pkg/api/server.go b/src/semantic-router/pkg/api/server.go index d4635632..f3fc86de 100644 --- a/src/semantic-router/pkg/api/server.go +++ b/src/semantic-router/pkg/api/server.go @@ -7,6 +7,7 @@ import ( "log" "net/http" "runtime" + "sync" "time" "github.com/vllm-project/semantic-router/semantic-router/pkg/config" @@ -46,6 +47,34 @@ type SystemInfo struct { GPUAvailable bool `json:"gpu_available"` } +// BatchClassificationRequest represents a batch classification request +type BatchClassificationRequest struct { + Texts []string `json:"texts"` + Options *ClassificationOptions `json:"options,omitempty"` +} + +// BatchClassificationResponse represents the response from batch classification +type BatchClassificationResponse struct { + Results []services.Classification `json:"results"` + TotalCount int `json:"total_count"` + ProcessingTimeMs int64 `json:"processing_time_ms"` + Statistics Statistics `json:"statistics"` +} + +// Statistics provides batch processing statistics +type Statistics struct { + CategoryDistribution map[string]int `json:"category_distribution"` + AvgConfidence float64 `json:"avg_confidence"` + LowConfidenceCount int `json:"low_confidence_count"` +} + +// ClassificationOptions mirrors services.IntentOptions for API layer +type ClassificationOptions struct { + ReturnProbabilities bool `json:"return_probabilities,omitempty"` + ConfidenceThreshold float64 `json:"confidence_threshold,omitempty"` + IncludeExplanation bool `json:"include_explanation,omitempty"` +} + // StartClassificationAPI starts the Classification API server func StartClassificationAPI(configPath string, port int) error { // Load configuration @@ -192,7 +221,64 @@ func (s *ClassificationAPIServer) handleCombinedClassification(w http.ResponseWr } func (s *ClassificationAPIServer) handleBatchClassification(w http.ResponseWriter, r *http.Request) { - s.writeErrorResponse(w, http.StatusNotImplemented, "NOT_IMPLEMENTED", "Batch classification not implemented yet") + start := time.Now() + + var req BatchClassificationRequest + if err := s.parseJSONRequest(r, &req); err != nil { + s.writeErrorResponse(w, http.StatusBadRequest, "INVALID_INPUT", err.Error()) + return + } + + // Input validation + if len(req.Texts) == 0 { + s.writeErrorResponse(w, 
http.StatusBadRequest, "INVALID_INPUT", "texts array cannot be empty") + return + } + + // Get max batch size from config, default to 100 + maxBatchSize := 100 + if s.config != nil && s.config.API.BatchClassification.MaxBatchSize > 0 { + maxBatchSize = s.config.API.BatchClassification.MaxBatchSize + } + + if len(req.Texts) > maxBatchSize { + s.writeErrorResponse(w, http.StatusBadRequest, "BATCH_TOO_LARGE", + fmt.Sprintf("batch size cannot exceed %d texts", maxBatchSize)) + return + } + + // Get concurrency threshold from config, default to 5 + concurrencyThreshold := 5 + if s.config != nil && s.config.API.BatchClassification.ConcurrencyThreshold > 0 { + concurrencyThreshold = s.config.API.BatchClassification.ConcurrencyThreshold + } + + // Process texts based on batch size + var results []services.Classification + var err error + + if len(req.Texts) <= concurrencyThreshold { + results, err = s.processSequentially(req.Texts, req.Options) + } else { + results, err = s.processConcurrently(req.Texts, req.Options) + } + + if err != nil { + s.writeErrorResponse(w, http.StatusInternalServerError, "CLASSIFICATION_ERROR", err.Error()) + return + } + + // Calculate statistics + statistics := s.calculateStatistics(results) + + response := BatchClassificationResponse{ + Results: results, + TotalCount: len(req.Texts), + ProcessingTimeMs: time.Since(start).Milliseconds(), + Statistics: statistics, + } + + s.writeJSONResponse(w, http.StatusOK, response) } func (s *ClassificationAPIServer) handleModelsInfo(w http.ResponseWriter, r *http.Request) { @@ -405,3 +491,111 @@ func (s *ClassificationAPIServer) getSystemInfo() SystemInfo { GPUAvailable: false, // TODO: Implement GPU detection } } + +// processSequentially handles small batches with sequential processing +func (s *ClassificationAPIServer) processSequentially(texts []string, options *ClassificationOptions) ([]services.Classification, error) { + results := make([]services.Classification, len(texts)) + for i, text := range texts { + result, err := s.classifySingleText(text, options) + if err != nil { + return nil, fmt.Errorf("failed to classify text at index %d: %w", i, err) + } + results[i] = result + } + return results, nil +} + +// processConcurrently handles large batches with concurrent processing +func (s *ClassificationAPIServer) processConcurrently(texts []string, options *ClassificationOptions) ([]services.Classification, error) { + // Get max concurrency from config, default to 8 + maxConcurrency := 8 + if s.config != nil && s.config.API.BatchClassification.MaxConcurrency > 0 { + maxConcurrency = s.config.API.BatchClassification.MaxConcurrency + } + + results := make([]services.Classification, len(texts)) + errors := make([]error, len(texts)) + + semaphore := make(chan struct{}, maxConcurrency) + var wg sync.WaitGroup + + for i, text := range texts { + wg.Add(1) + go func(index int, txt string) { + defer wg.Done() + semaphore <- struct{}{} + defer func() { <-semaphore }() + + result, err := s.classifySingleText(txt, options) + if err != nil { + errors[index] = err + return + } + results[index] = result + }(i, text) + } + + wg.Wait() + + // Check for errors + for i, err := range errors { + if err != nil { + return nil, fmt.Errorf("failed to classify text at index %d: %w", i, err) + } + } + + return results, nil +} + +// classifySingleText processes a single text using existing service +func (s *ClassificationAPIServer) classifySingleText(text string, options *ClassificationOptions) (services.Classification, error) { + // Convert API options 
to service options + var serviceOptions *services.IntentOptions + if options != nil { + serviceOptions = &services.IntentOptions{ + ReturnProbabilities: options.ReturnProbabilities, + ConfidenceThreshold: options.ConfidenceThreshold, + IncludeExplanation: options.IncludeExplanation, + } + } + + individualReq := services.IntentRequest{ + Text: text, + Options: serviceOptions, + } + + response, err := s.classificationSvc.ClassifyIntent(individualReq) + if err != nil { + return services.Classification{}, err + } + + return response.Classification, nil +} + +// calculateStatistics computes batch processing statistics +func (s *ClassificationAPIServer) calculateStatistics(results []services.Classification) Statistics { + categoryDistribution := make(map[string]int) + var totalConfidence float64 + lowConfidenceCount := 0 + + for _, result := range results { + if result.Category != "" { + categoryDistribution[result.Category]++ + } + totalConfidence += result.Confidence + if result.Confidence < 0.7 { + lowConfidenceCount++ + } + } + + avgConfidence := 0.0 + if len(results) > 0 { + avgConfidence = totalConfidence / float64(len(results)) + } + + return Statistics{ + CategoryDistribution: categoryDistribution, + AvgConfidence: avgConfidence, + LowConfidenceCount: lowConfidenceCount, + } +} diff --git a/src/semantic-router/pkg/api/server_test.go b/src/semantic-router/pkg/api/server_test.go new file mode 100644 index 00000000..5f65d3e8 --- /dev/null +++ b/src/semantic-router/pkg/api/server_test.go @@ -0,0 +1,307 @@ +package api + +import ( + "bytes" + "encoding/json" + "fmt" + "math" + "net/http" + "net/http/httptest" + "testing" + + "github.com/vllm-project/semantic-router/semantic-router/pkg/config" + "github.com/vllm-project/semantic-router/semantic-router/pkg/services" +) + +func TestHandleBatchClassification(t *testing.T) { + // Create a test server with placeholder service + apiServer := &ClassificationAPIServer{ + classificationSvc: services.NewPlaceholderClassificationService(), + config: &config.RouterConfig{}, + } + + tests := []struct { + name string + requestBody string + expectedStatus int + expectedError string + }{ + { + name: "Valid small batch", + requestBody: `{ + "texts": ["solve math equation", "write business plan", "chemistry experiment"] + }`, + expectedStatus: http.StatusOK, + }, + { + name: "Valid large batch", + requestBody: `{ + "texts": [ + "solve differential equation", + "business strategy analysis", + "chemistry reaction", + "physics calculation", + "market research", + "mathematical modeling", + "financial planning", + "scientific experiment" + ] + }`, + expectedStatus: http.StatusOK, + }, + { + name: "Valid batch with options", + requestBody: `{ + "texts": ["solve math equation", "write business plan"], + "options": {"return_probabilities": true} + }`, + expectedStatus: http.StatusOK, + }, + { + name: "Empty texts array", + requestBody: `{ + "texts": [] + }`, + expectedStatus: http.StatusBadRequest, + expectedError: "texts array cannot be empty", + }, + { + name: "Missing texts field", + requestBody: `{}`, + expectedStatus: http.StatusBadRequest, + expectedError: "texts array cannot be empty", + }, + { + name: "Batch too large", + requestBody: func() string { + texts := make([]string, 101) + for i := range texts { + texts[i] = fmt.Sprintf("test query %d", i) + } + data := map[string]interface{}{"texts": texts} + b, _ := json.Marshal(data) + return string(b) + }(), + expectedStatus: http.StatusBadRequest, + expectedError: "batch size cannot exceed 100 texts", + }, + { + 
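			// Malformed JSON is rejected by parseJSONRequest with 400 before any classification runs.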
name: "Invalid JSON", + requestBody: `{"texts": [invalid json`, + expectedStatus: http.StatusBadRequest, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest("POST", "/api/v1/classify/batch", bytes.NewBufferString(tt.requestBody)) + req.Header.Set("Content-Type", "application/json") + + rr := httptest.NewRecorder() + + apiServer.handleBatchClassification(rr, req) + + if rr.Code != tt.expectedStatus { + t.Errorf("Expected status %d, got %d", tt.expectedStatus, rr.Code) + } + + if tt.expectedStatus == http.StatusOK { + // For successful requests, check response structure + var response BatchClassificationResponse + if err := json.Unmarshal(rr.Body.Bytes(), &response); err != nil { + t.Errorf("Failed to unmarshal response: %v", err) + } + + // Validate response structure + if response.TotalCount == 0 { + t.Error("Expected non-zero total count") + } + if len(response.Results) == 0 { + t.Error("Expected non-empty results") + } + if response.ProcessingTimeMs < 0 { + t.Error("Expected non-negative processing time") + } + + // Check statistics + if response.Statistics.AvgConfidence < 0 || response.Statistics.AvgConfidence > 1 { + t.Error("Expected confidence between 0 and 1") + } + } else if tt.expectedError != "" { + // For error responses, check error message + var errorResponse map[string]interface{} + if err := json.Unmarshal(rr.Body.Bytes(), &errorResponse); err != nil { + t.Errorf("Failed to unmarshal error response: %v", err) + } + + if errorData, ok := errorResponse["error"].(map[string]interface{}); ok { + if message, ok := errorData["message"].(string); ok { + if message != tt.expectedError { + t.Errorf("Expected error message '%s', got '%s'", tt.expectedError, message) + } + } + } + } + }) + } +} + +func TestCalculateStatistics(t *testing.T) { + apiServer := &ClassificationAPIServer{} + + tests := []struct { + name string + results []services.Classification + expected Statistics + }{ + { + name: "Mixed categories", + results: []services.Classification{ + {Category: "math", Confidence: 0.9}, + {Category: "math", Confidence: 0.8}, + {Category: "business", Confidence: 0.6}, + {Category: "science", Confidence: 0.5}, + }, + expected: Statistics{ + CategoryDistribution: map[string]int{ + "math": 2, + "business": 1, + "science": 1, + }, + AvgConfidence: 0.7, + LowConfidenceCount: 2, // 0.6 and 0.5 are below 0.7 + }, + }, + { + name: "Empty results", + results: []services.Classification{}, + expected: Statistics{ + CategoryDistribution: map[string]int{}, + AvgConfidence: 0.0, + LowConfidenceCount: 0, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + stats := apiServer.calculateStatistics(tt.results) + + if math.Abs(stats.AvgConfidence-tt.expected.AvgConfidence) > 0.001 { + t.Errorf("Expected avg confidence %.3f, got %.3f", tt.expected.AvgConfidence, stats.AvgConfidence) + } + + if stats.LowConfidenceCount != tt.expected.LowConfidenceCount { + t.Errorf("Expected low confidence count %d, got %d", tt.expected.LowConfidenceCount, stats.LowConfidenceCount) + } + + for category, expectedCount := range tt.expected.CategoryDistribution { + if actualCount, exists := stats.CategoryDistribution[category]; !exists || actualCount != expectedCount { + t.Errorf("Expected category %s count %d, got %d", category, expectedCount, actualCount) + } + } + }) + } +} + +func TestBatchClassificationConfiguration(t *testing.T) { + tests := []struct { + name string + config *config.RouterConfig + requestBody string + expectedStatus 
int + expectedError string + }{ + { + name: "Custom max batch size", + config: &config.RouterConfig{ + API: config.APIConfig{ + BatchClassification: struct { + MaxBatchSize int `yaml:"max_batch_size,omitempty"` + ConcurrencyThreshold int `yaml:"concurrency_threshold,omitempty"` + MaxConcurrency int `yaml:"max_concurrency,omitempty"` + }{ + MaxBatchSize: 3, // Custom small limit + ConcurrencyThreshold: 2, + MaxConcurrency: 4, + }, + }, + }, + requestBody: `{ + "texts": ["text1", "text2", "text3", "text4"] + }`, + expectedStatus: http.StatusBadRequest, + expectedError: "batch size cannot exceed 3 texts", + }, + { + name: "Default config when config is nil", + config: nil, + requestBody: func() string { + texts := make([]string, 101) + for i := range texts { + texts[i] = fmt.Sprintf("test query %d", i) + } + data := map[string]interface{}{"texts": texts} + b, _ := json.Marshal(data) + return string(b) + }(), + expectedStatus: http.StatusBadRequest, + expectedError: "batch size cannot exceed 100 texts", // Default limit + }, + { + name: "Valid request within custom limits", + config: &config.RouterConfig{ + API: config.APIConfig{ + BatchClassification: struct { + MaxBatchSize int `yaml:"max_batch_size,omitempty"` + ConcurrencyThreshold int `yaml:"concurrency_threshold,omitempty"` + MaxConcurrency int `yaml:"max_concurrency,omitempty"` + }{ + MaxBatchSize: 10, + ConcurrencyThreshold: 3, + MaxConcurrency: 2, + }, + }, + }, + requestBody: `{ + "texts": ["text1", "text2"] + }`, + expectedStatus: http.StatusOK, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + apiServer := &ClassificationAPIServer{ + classificationSvc: services.NewPlaceholderClassificationService(), + config: tt.config, + } + + req := httptest.NewRequest("POST", "/api/v1/classify/batch", bytes.NewBufferString(tt.requestBody)) + req.Header.Set("Content-Type", "application/json") + + rr := httptest.NewRecorder() + + apiServer.handleBatchClassification(rr, req) + + if rr.Code != tt.expectedStatus { + t.Errorf("Expected status %d, got %d", tt.expectedStatus, rr.Code) + } + + if tt.expectedError != "" { + var errorResponse map[string]interface{} + if err := json.Unmarshal(rr.Body.Bytes(), &errorResponse); err != nil { + t.Errorf("Failed to unmarshal error response: %v", err) + } + + if errorData, ok := errorResponse["error"].(map[string]interface{}); ok { + if message, ok := errorData["message"].(string); ok { + if message != tt.expectedError { + t.Errorf("Expected error message '%s', got '%s'", tt.expectedError, message) + } + } + } + } + }) + } +} diff --git a/src/semantic-router/pkg/config/config.go b/src/semantic-router/pkg/config/config.go index fa7cc7f5..d95d921f 100644 --- a/src/semantic-router/pkg/config/config.go +++ b/src/semantic-router/pkg/config/config.go @@ -61,6 +61,9 @@ type RouterConfig struct { // vLLM endpoints configuration for multiple backend support VLLMEndpoints []VLLMEndpoint `yaml:"vllm_endpoints"` + + // API configuration for classification endpoints + API APIConfig `yaml:"api"` } // SemanticCacheConfig represents configuration for the semantic cache @@ -79,6 +82,21 @@ type SemanticCacheConfig struct { TTLSeconds int `yaml:"ttl_seconds,omitempty"` } +// APIConfig represents configuration for API endpoints +type APIConfig struct { + // Batch classification configuration + BatchClassification struct { + // Maximum number of texts allowed in a single batch request + MaxBatchSize int `yaml:"max_batch_size,omitempty"` + + // Threshold for switching from sequential to concurrent 
processing + ConcurrencyThreshold int `yaml:"concurrency_threshold,omitempty"` + + // Maximum number of concurrent goroutines for batch processing + MaxConcurrency int `yaml:"max_concurrency,omitempty"` + } `yaml:"batch_classification"` +} + // PromptGuardConfig represents configuration for the prompt guard jailbreak detection type PromptGuardConfig struct { // Enable prompt guard jailbreak detection From e393dfc507ed8958210a5bf951565f675112dad1 Mon Sep 17 00:00:00 2001 From: OneZero-Y Date: Thu, 4 Sep 2025 11:10:39 +0800 Subject: [PATCH 2/3] update the classify API docs Signed-off-by: OneZero-Y --- src/semantic-router/pkg/api/server.go | 19 +-- src/semantic-router/pkg/api/server_test.go | 8 +- website/docs/api/classification.md | 154 ++++++++++++++++----- 3 files changed, 132 insertions(+), 49 deletions(-) diff --git a/src/semantic-router/pkg/api/server.go b/src/semantic-router/pkg/api/server.go index f3fc86de..5520b2ba 100644 --- a/src/semantic-router/pkg/api/server.go +++ b/src/semantic-router/pkg/api/server.go @@ -55,14 +55,14 @@ type BatchClassificationRequest struct { // BatchClassificationResponse represents the response from batch classification type BatchClassificationResponse struct { - Results []services.Classification `json:"results"` - TotalCount int `json:"total_count"` - ProcessingTimeMs int64 `json:"processing_time_ms"` - Statistics Statistics `json:"statistics"` + Results []services.Classification `json:"results"` + TotalCount int `json:"total_count"` + ProcessingTimeMs int64 `json:"processing_time_ms"` + Statistics CategoryClassificationStatistics `json:"statistics"` } -// Statistics provides batch processing statistics -type Statistics struct { +// CategoryClassificationStatistics provides batch processing statistics +type CategoryClassificationStatistics struct { CategoryDistribution map[string]int `json:"category_distribution"` AvgConfidence float64 `json:"avg_confidence"` LowConfidenceCount int `json:"low_confidence_count"` @@ -526,6 +526,9 @@ func (s *ClassificationAPIServer) processConcurrently(texts []string, options *C semaphore <- struct{}{} defer func() { <-semaphore }() + // TODO: Refactor candle-binding to support batch mode for better performance + // This would allow processing multiple texts in a single model inference call + // instead of individual calls, significantly improving throughput result, err := s.classifySingleText(txt, options) if err != nil { errors[index] = err @@ -573,7 +576,7 @@ func (s *ClassificationAPIServer) classifySingleText(text string, options *Class } // calculateStatistics computes batch processing statistics -func (s *ClassificationAPIServer) calculateStatistics(results []services.Classification) Statistics { +func (s *ClassificationAPIServer) calculateStatistics(results []services.Classification) CategoryClassificationStatistics { categoryDistribution := make(map[string]int) var totalConfidence float64 lowConfidenceCount := 0 @@ -593,7 +596,7 @@ func (s *ClassificationAPIServer) calculateStatistics(results []services.Classif avgConfidence = totalConfidence / float64(len(results)) } - return Statistics{ + return CategoryClassificationStatistics{ CategoryDistribution: categoryDistribution, AvgConfidence: avgConfidence, LowConfidenceCount: lowConfidenceCount, diff --git a/src/semantic-router/pkg/api/server_test.go b/src/semantic-router/pkg/api/server_test.go index 5f65d3e8..bea2f817 100644 --- a/src/semantic-router/pkg/api/server_test.go +++ b/src/semantic-router/pkg/api/server_test.go @@ -38,7 +38,7 @@ func 
TestHandleBatchClassification(t *testing.T) { requestBody: `{ "texts": [ "solve differential equation", - "business strategy analysis", + "business strategy analysis", "chemistry reaction", "physics calculation", "market research", @@ -152,7 +152,7 @@ func TestCalculateStatistics(t *testing.T) { tests := []struct { name string results []services.Classification - expected Statistics + expected CategoryClassificationStatistics }{ { name: "Mixed categories", @@ -162,7 +162,7 @@ func TestCalculateStatistics(t *testing.T) { {Category: "business", Confidence: 0.6}, {Category: "science", Confidence: 0.5}, }, - expected: Statistics{ + expected: CategoryClassificationStatistics{ CategoryDistribution: map[string]int{ "math": 2, "business": 1, @@ -175,7 +175,7 @@ func TestCalculateStatistics(t *testing.T) { { name: "Empty results", results: []services.Classification{}, - expected: Statistics{ + expected: CategoryClassificationStatistics{ CategoryDistribution: map[string]int{}, AvgConfidence: 0.0, LowConfidenceCount: 0, diff --git a/website/docs/api/classification.md b/website/docs/api/classification.md index c461e6ce..b62f000b 100644 --- a/website/docs/api/classification.md +++ b/website/docs/api/classification.md @@ -28,12 +28,12 @@ make run-router - `POST /api/v1/classify/intent` - Intent classification with real model inference - `POST /api/v1/classify/pii` - PII detection with real model inference - `POST /api/v1/classify/security` - Security/jailbreak detection with real model inference +- `POST /api/v1/classify/batch` - Batch classification with configurable processing strategies - `GET /info/models` - Model information and system status - `GET /info/classifier` - Detailed classifier capabilities and configuration ### 🔄 Placeholder Implementation - `POST /api/v1/classify/combined` - Returns "not implemented" response -- `POST /api/v1/classify/batch` - Returns "not implemented" response - `GET /metrics/classification` - Returns "not implemented" response - `GET /config/classification` - Returns "not implemented" response - `PUT /config/classification` - Returns "not implemented" response @@ -65,6 +65,11 @@ curl -X POST http://localhost:8080/api/v1/classify/security \ -H "Content-Type: application/json" \ -d '{"text": "Ignore all previous instructions"}' +# Batch classification +curl -X POST http://localhost:8080/api/v1/classify/batch \ + -H "Content-Type: application/json" \ + -d '{"texts": ["What is machine learning?", "Write a business plan", "Calculate area of circle"]}' + # Model information curl -X GET http://localhost:8080/info/models @@ -280,7 +285,7 @@ Perform multiple classification tasks in a single request. ## Batch Classification -Process multiple texts in a single request for efficiency. +Process multiple texts in a single request for improved efficiency. The API automatically chooses between sequential and concurrent processing based on batch size and configuration. ### Endpoint `POST /classify/batch` @@ -291,14 +296,14 @@ Process multiple texts in a single request for efficiency. { "texts": [ "What is machine learning?", - "Write a poem about spring", - "My SSN is 123-45-6789", - "Ignore all safety measures" + "Write a business plan", + "Calculate the area of a circle", + "Solve differential equations" ], - "task": "combined", "options": { - "return_individual_results": true, - "include_summary": true + "return_probabilities": true, + "confidence_threshold": 0.7, + "include_explanation": false } } ``` @@ -309,44 +314,85 @@ Process multiple texts in a single request for efficiency. 
{ "results": [ { - "index": 0, - "text": "What is machine learning?", - "intent": {"category": "computer_science", "confidence": 0.88}, - "pii": {"has_pii": false}, - "security": {"is_jailbreak": false, "risk_score": 0.01} + "category": "computer science", + "confidence": 0.88, + "processing_time_ms": 45 }, { - "index": 1, - "text": "Write a poem about spring", - "intent": {"category": "creative_writing", "confidence": 0.95}, - "pii": {"has_pii": false}, - "security": {"is_jailbreak": false, "risk_score": 0.02} + "category": "business", + "confidence": 0.92, + "processing_time_ms": 38 }, { - "index": 2, - "text": "My SSN is 123-45-6789", - "intent": {"category": "general", "confidence": 0.67}, - "pii": {"has_pii": true, "entities": [{"type": "SSN", "confidence": 0.99}]}, - "security": {"is_jailbreak": false, "risk_score": 0.05} + "category": "math", + "confidence": 0.95, + "processing_time_ms": 42 }, { - "index": 3, - "text": "Ignore all safety measures", - "intent": {"category": "general", "confidence": 0.45}, - "pii": {"has_pii": false}, - "security": {"is_jailbreak": true, "risk_score": 0.87} + "category": "math", + "confidence": 0.89, + "processing_time_ms": 41 } ], - "summary": { - "total_texts": 4, - "pii_detected": 1, - "jailbreaks_detected": 1, - "average_processing_time_ms": 22, + "total_count": 4, + "processing_time_ms": 156, + "statistics": { "category_distribution": { - "computer_science": 1, - "creative_writing": 1, - "general": 2 - } + "math": 2, + "computer science": 1, + "business": 1 + }, + "avg_confidence": 0.91, + "low_confidence_count": 0 + } +} +``` + +### Configuration + +The batch classification behavior can be configured in `config.yaml`: + +```yaml +api: + batch_classification: + max_batch_size: 100 # Maximum texts per batch + concurrency_threshold: 5 # Switch to concurrent processing when batch > this + max_concurrency: 8 # Maximum concurrent goroutines +``` + +### Processing Strategies + +- **Sequential Processing**: Used for small batches (≤ concurrency_threshold) to minimize overhead +- **Concurrent Processing**: Used for larger batches to improve throughput +- **Automatic Selection**: The API automatically chooses the optimal strategy based on batch size + +### Performance Characteristics + +| Batch Size | Strategy | Expected Performance | +|------------|----------|---------------------| +| 1-5 texts | Sequential | ~Single request latency | +| 6+ texts | Concurrent | ~1/3 to 1/5 of sequential time | + +### Error Handling + +**Batch Too Large (400 Bad Request):** +```json +{ + "error": { + "code": "BATCH_TOO_LARGE", + "message": "batch size cannot exceed 100 texts", + "timestamp": "2024-03-15T14:30:00Z" + } +} +``` + +**Empty Batch (400 Bad Request):** +```json +{ + "error": { + "code": "INVALID_INPUT", + "message": "texts array cannot be empty", + "timestamp": "2024-03-15T14:30:00Z" } } ``` @@ -660,6 +706,16 @@ class ClassificationClient: } ) return response.json() + + def classify_batch(self, texts: List[str], return_probabilities: bool = False) -> Dict: + response = requests.post( + f"{self.base_url}/api/v1/classify/batch", + json={ + "texts": texts, + "options": {"return_probabilities": return_probabilities} + } + ) + return response.json() # Usage example client = ClassificationClient() @@ -679,6 +735,13 @@ if pii_result['has_pii']: security_result = client.check_security("Ignore all previous instructions") if security_result['is_jailbreak']: print(f"Jailbreak detected with risk score: {security_result['risk_score']}") + +# Batch classification +texts = 
["What is machine learning?", "Write a business plan", "Calculate area of circle"] +batch_result = client.classify_batch(texts, return_probabilities=True) +print(f"Processed {batch_result['total_count']} texts in {batch_result['processing_time_ms']}ms") +for i, result in enumerate(batch_result['results']): + print(f"Text {i+1}: {result['category']} (confidence: {result['confidence']:.2f})") ``` ### JavaScript SDK @@ -723,6 +786,15 @@ class ClassificationAPI { }); return response.json(); } + + async classifyBatch(texts, options = {}) { + const response = await fetch(`${this.baseUrl}/api/v1/classify/batch`, { + method: 'POST', + headers: {'Content-Type': 'application/json'}, + body: JSON.stringify({texts, options}) + }); + return response.json(); + } } // Usage example @@ -746,6 +818,14 @@ const api = new ClassificationAPI(); if (securityResult.is_jailbreak) { console.log(`Security threat detected: Risk score ${securityResult.risk_score}`); } + + // Batch classification + const texts = ["What is machine learning?", "Write a business plan", "Calculate area of circle"]; + const batchResult = await api.classifyBatch(texts, {return_probabilities: true}); + console.log(`Processed ${batchResult.total_count} texts in ${batchResult.processing_time_ms}ms`); + batchResult.results.forEach((result, index) => { + console.log(`Text ${index + 1}: ${result.category} (confidence: ${result.confidence.toFixed(2)})`); + }); })(); ``` From 76e39161de94d03ac5a034513139220c0ff004ad Mon Sep 17 00:00:00 2001 From: OneZero-Y Date: Thu, 4 Sep 2025 12:31:14 +0800 Subject: [PATCH 3/3] update the classify API docs Signed-off-by: OneZero-Y --- website/docs/api/classification.md | 7 ------- 1 file changed, 7 deletions(-) diff --git a/website/docs/api/classification.md b/website/docs/api/classification.md index b62f000b..cda2344f 100644 --- a/website/docs/api/classification.md +++ b/website/docs/api/classification.md @@ -366,13 +366,6 @@ api: - **Concurrent Processing**: Used for larger batches to improve throughput - **Automatic Selection**: The API automatically chooses the optimal strategy based on batch size -### Performance Characteristics - -| Batch Size | Strategy | Expected Performance | -|------------|----------|---------------------| -| 1-5 texts | Sequential | ~Single request latency | -| 6+ texts | Concurrent | ~1/3 to 1/5 of sequential time | - ### Error Handling **Batch Too Large (400 Bad Request):**