Skip to content

Commit 3303e48

Browse files
committed
feat: implement batch classification API
Signed-off-by: OneZero-Y <[email protected]>
1 parent 3f32ad4 commit 3303e48

File tree

4 files changed

+528
-2
lines changed

4 files changed

+528
-2
lines changed

config/config.yaml

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -237,4 +237,11 @@ categories:
237237
- model: phi4
238238
score: 0.2
239239
default_model: mistral-small3.1
240-
default_reasoning_effort: medium # Default reasoning effort level (low, medium, high)
240+
default_reasoning_effort: medium # Default reasoning effort level (low, medium, high)
241+
242+
# API Configuration
243+
api:
244+
batch_classification:
245+
max_batch_size: 100 # Maximum number of texts in a single batch
246+
concurrency_threshold: 5 # Switch to concurrent processing when batch size > this value
247+
max_concurrency: 8 # Maximum number of concurrent goroutines

src/semantic-router/pkg/api/server.go

Lines changed: 195 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"log"
88
"net/http"
99
"runtime"
10+
"sync"
1011
"time"
1112

1213
"github.com/vllm-project/semantic-router/semantic-router/pkg/config"
@@ -46,6 +47,34 @@ type SystemInfo struct {
4647
GPUAvailable bool `json:"gpu_available"`
4748
}
4849

50+
// BatchClassificationRequest represents a batch classification request.
type BatchClassificationRequest struct {
	// Texts holds the inputs to classify; the handler rejects an empty
	// array and batches larger than the configured max_batch_size.
	Texts []string `json:"texts"`
	// Options, when present, is applied to every text in the batch.
	Options *ClassificationOptions `json:"options,omitempty"`
}
55+
56+
// BatchClassificationResponse represents the response from batch classification.
type BatchClassificationResponse struct {
	// Results are in the same order as the request's Texts.
	Results []services.Classification `json:"results"`
	// TotalCount is the number of input texts in the request.
	TotalCount int `json:"total_count"`
	// ProcessingTimeMs is the wall-clock time for the whole batch, in milliseconds.
	ProcessingTimeMs int64 `json:"processing_time_ms"`
	// Statistics aggregates the batch results (distribution, mean confidence, ...).
	Statistics Statistics `json:"statistics"`
}
63+
64+
// Statistics provides batch processing statistics.
type Statistics struct {
	// CategoryDistribution maps category name to result count;
	// results with an empty category are omitted.
	CategoryDistribution map[string]int `json:"category_distribution"`
	// AvgConfidence is the mean confidence across all results (0 for an empty batch).
	AvgConfidence float64 `json:"avg_confidence"`
	// LowConfidenceCount is the number of results whose confidence fell
	// below the low-confidence threshold (0.7 in calculateStatistics).
	LowConfidenceCount int `json:"low_confidence_count"`
}
70+
71+
// ClassificationOptions mirrors services.IntentOptions for API layer.
// The fields are copied verbatim into services.IntentOptions by
// classifySingleText; their semantics are defined by the service layer.
type ClassificationOptions struct {
	ReturnProbabilities bool    `json:"return_probabilities,omitempty"`
	ConfidenceThreshold float64 `json:"confidence_threshold,omitempty"`
	IncludeExplanation  bool    `json:"include_explanation,omitempty"`
}
77+
4978
// StartClassificationAPI starts the Classification API server
5079
func StartClassificationAPI(configPath string, port int) error {
5180
// Load configuration
@@ -192,7 +221,64 @@ func (s *ClassificationAPIServer) handleCombinedClassification(w http.ResponseWr
192221
}
193222

194223
func (s *ClassificationAPIServer) handleBatchClassification(w http.ResponseWriter, r *http.Request) {
195-
s.writeErrorResponse(w, http.StatusNotImplemented, "NOT_IMPLEMENTED", "Batch classification not implemented yet")
224+
start := time.Now()
225+
226+
var req BatchClassificationRequest
227+
if err := s.parseJSONRequest(r, &req); err != nil {
228+
s.writeErrorResponse(w, http.StatusBadRequest, "INVALID_INPUT", err.Error())
229+
return
230+
}
231+
232+
// Input validation
233+
if len(req.Texts) == 0 {
234+
s.writeErrorResponse(w, http.StatusBadRequest, "INVALID_INPUT", "texts array cannot be empty")
235+
return
236+
}
237+
238+
// Get max batch size from config, default to 100
239+
maxBatchSize := 100
240+
if s.config != nil && s.config.API.BatchClassification.MaxBatchSize > 0 {
241+
maxBatchSize = s.config.API.BatchClassification.MaxBatchSize
242+
}
243+
244+
if len(req.Texts) > maxBatchSize {
245+
s.writeErrorResponse(w, http.StatusBadRequest, "BATCH_TOO_LARGE",
246+
fmt.Sprintf("batch size cannot exceed %d texts", maxBatchSize))
247+
return
248+
}
249+
250+
// Get concurrency threshold from config, default to 5
251+
concurrencyThreshold := 5
252+
if s.config != nil && s.config.API.BatchClassification.ConcurrencyThreshold > 0 {
253+
concurrencyThreshold = s.config.API.BatchClassification.ConcurrencyThreshold
254+
}
255+
256+
// Process texts based on batch size
257+
var results []services.Classification
258+
var err error
259+
260+
if len(req.Texts) <= concurrencyThreshold {
261+
results, err = s.processSequentially(req.Texts, req.Options)
262+
} else {
263+
results, err = s.processConcurrently(req.Texts, req.Options)
264+
}
265+
266+
if err != nil {
267+
s.writeErrorResponse(w, http.StatusInternalServerError, "CLASSIFICATION_ERROR", err.Error())
268+
return
269+
}
270+
271+
// Calculate statistics
272+
statistics := s.calculateStatistics(results)
273+
274+
response := BatchClassificationResponse{
275+
Results: results,
276+
TotalCount: len(req.Texts),
277+
ProcessingTimeMs: time.Since(start).Milliseconds(),
278+
Statistics: statistics,
279+
}
280+
281+
s.writeJSONResponse(w, http.StatusOK, response)
196282
}
197283

198284
func (s *ClassificationAPIServer) handleModelsInfo(w http.ResponseWriter, r *http.Request) {
@@ -405,3 +491,111 @@ func (s *ClassificationAPIServer) getSystemInfo() SystemInfo {
405491
GPUAvailable: false, // TODO: Implement GPU detection
406492
}
407493
}
494+
495+
// processSequentially handles small batches with sequential processing
496+
func (s *ClassificationAPIServer) processSequentially(texts []string, options *ClassificationOptions) ([]services.Classification, error) {
497+
results := make([]services.Classification, len(texts))
498+
for i, text := range texts {
499+
result, err := s.classifySingleText(text, options)
500+
if err != nil {
501+
return nil, fmt.Errorf("failed to classify text at index %d: %w", i, err)
502+
}
503+
results[i] = result
504+
}
505+
return results, nil
506+
}
507+
508+
// processConcurrently handles large batches with concurrent processing
509+
func (s *ClassificationAPIServer) processConcurrently(texts []string, options *ClassificationOptions) ([]services.Classification, error) {
510+
// Get max concurrency from config, default to 8
511+
maxConcurrency := 8
512+
if s.config != nil && s.config.API.BatchClassification.MaxConcurrency > 0 {
513+
maxConcurrency = s.config.API.BatchClassification.MaxConcurrency
514+
}
515+
516+
results := make([]services.Classification, len(texts))
517+
errors := make([]error, len(texts))
518+
519+
semaphore := make(chan struct{}, maxConcurrency)
520+
var wg sync.WaitGroup
521+
522+
for i, text := range texts {
523+
wg.Add(1)
524+
go func(index int, txt string) {
525+
defer wg.Done()
526+
semaphore <- struct{}{}
527+
defer func() { <-semaphore }()
528+
529+
result, err := s.classifySingleText(txt, options)
530+
if err != nil {
531+
errors[index] = err
532+
return
533+
}
534+
results[index] = result
535+
}(i, text)
536+
}
537+
538+
wg.Wait()
539+
540+
// Check for errors
541+
for i, err := range errors {
542+
if err != nil {
543+
return nil, fmt.Errorf("failed to classify text at index %d: %w", i, err)
544+
}
545+
}
546+
547+
return results, nil
548+
}
549+
550+
// classifySingleText processes a single text using existing service
551+
func (s *ClassificationAPIServer) classifySingleText(text string, options *ClassificationOptions) (services.Classification, error) {
552+
// Convert API options to service options
553+
var serviceOptions *services.IntentOptions
554+
if options != nil {
555+
serviceOptions = &services.IntentOptions{
556+
ReturnProbabilities: options.ReturnProbabilities,
557+
ConfidenceThreshold: options.ConfidenceThreshold,
558+
IncludeExplanation: options.IncludeExplanation,
559+
}
560+
}
561+
562+
individualReq := services.IntentRequest{
563+
Text: text,
564+
Options: serviceOptions,
565+
}
566+
567+
response, err := s.classificationSvc.ClassifyIntent(individualReq)
568+
if err != nil {
569+
return services.Classification{}, err
570+
}
571+
572+
return response.Classification, nil
573+
}
574+
575+
// calculateStatistics computes batch processing statistics
576+
func (s *ClassificationAPIServer) calculateStatistics(results []services.Classification) Statistics {
577+
categoryDistribution := make(map[string]int)
578+
var totalConfidence float64
579+
lowConfidenceCount := 0
580+
581+
for _, result := range results {
582+
if result.Category != "" {
583+
categoryDistribution[result.Category]++
584+
}
585+
totalConfidence += result.Confidence
586+
if result.Confidence < 0.7 {
587+
lowConfidenceCount++
588+
}
589+
}
590+
591+
avgConfidence := 0.0
592+
if len(results) > 0 {
593+
avgConfidence = totalConfidence / float64(len(results))
594+
}
595+
596+
return Statistics{
597+
CategoryDistribution: categoryDistribution,
598+
AvgConfidence: avgConfidence,
599+
LowConfidenceCount: lowConfidenceCount,
600+
}
601+
}

0 commit comments

Comments
 (0)