Skip to content

Commit 35c8ec8

Browse files
authored
Merge pull request #24 from OneZero-Y/feat/batch-classification-api
feat: implement batch classification API
2 parents 7004215 + 76e3916 commit 35c8ec8

File tree

5 files changed: +641 −39 lines

config/config.yaml

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -237,4 +237,11 @@ categories:
237237
- model: phi4
238238
score: 0.2
239239
default_model: mistral-small3.1
240-
default_reasoning_effort: medium # Default reasoning effort level (low, medium, high)
240+
default_reasoning_effort: medium # Default reasoning effort level (low, medium, high)
241+
242+
# API configuration
api:
  batch_classification:
    max_batch_size: 100 # Upper bound on the number of texts accepted in one batch request
    concurrency_threshold: 5 # Batches larger than this are processed concurrently
    max_concurrency: 8 # Cap on the number of goroutines used for concurrent processing

src/semantic-router/pkg/api/server.go

Lines changed: 198 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"log"
88
"net/http"
99
"runtime"
10+
"sync"
1011
"time"
1112

1213
"github.com/vllm-project/semantic-router/semantic-router/pkg/config"
@@ -46,6 +47,34 @@ type SystemInfo struct {
4647
GPUAvailable bool `json:"gpu_available"`
4748
}
4849

50+
// BatchClassificationRequest represents a batch classification request
51+
type BatchClassificationRequest struct {
52+
Texts []string `json:"texts"`
53+
Options *ClassificationOptions `json:"options,omitempty"`
54+
}
55+
56+
// BatchClassificationResponse represents the response from batch classification
57+
type BatchClassificationResponse struct {
58+
Results []services.Classification `json:"results"`
59+
TotalCount int `json:"total_count"`
60+
ProcessingTimeMs int64 `json:"processing_time_ms"`
61+
Statistics CategoryClassificationStatistics `json:"statistics"`
62+
}
63+
64+
// CategoryClassificationStatistics provides batch processing statistics
65+
type CategoryClassificationStatistics struct {
66+
CategoryDistribution map[string]int `json:"category_distribution"`
67+
AvgConfidence float64 `json:"avg_confidence"`
68+
LowConfidenceCount int `json:"low_confidence_count"`
69+
}
70+
71+
// ClassificationOptions mirrors services.IntentOptions for API layer
72+
type ClassificationOptions struct {
73+
ReturnProbabilities bool `json:"return_probabilities,omitempty"`
74+
ConfidenceThreshold float64 `json:"confidence_threshold,omitempty"`
75+
IncludeExplanation bool `json:"include_explanation,omitempty"`
76+
}
77+
4978
// StartClassificationAPI starts the Classification API server
5079
func StartClassificationAPI(configPath string, port int) error {
5180
// Load configuration
@@ -192,7 +221,64 @@ func (s *ClassificationAPIServer) handleCombinedClassification(w http.ResponseWr
192221
}
193222

194223
func (s *ClassificationAPIServer) handleBatchClassification(w http.ResponseWriter, r *http.Request) {
195-
s.writeErrorResponse(w, http.StatusNotImplemented, "NOT_IMPLEMENTED", "Batch classification not implemented yet")
224+
start := time.Now()
225+
226+
var req BatchClassificationRequest
227+
if err := s.parseJSONRequest(r, &req); err != nil {
228+
s.writeErrorResponse(w, http.StatusBadRequest, "INVALID_INPUT", err.Error())
229+
return
230+
}
231+
232+
// Input validation
233+
if len(req.Texts) == 0 {
234+
s.writeErrorResponse(w, http.StatusBadRequest, "INVALID_INPUT", "texts array cannot be empty")
235+
return
236+
}
237+
238+
// Get max batch size from config, default to 100
239+
maxBatchSize := 100
240+
if s.config != nil && s.config.API.BatchClassification.MaxBatchSize > 0 {
241+
maxBatchSize = s.config.API.BatchClassification.MaxBatchSize
242+
}
243+
244+
if len(req.Texts) > maxBatchSize {
245+
s.writeErrorResponse(w, http.StatusBadRequest, "BATCH_TOO_LARGE",
246+
fmt.Sprintf("batch size cannot exceed %d texts", maxBatchSize))
247+
return
248+
}
249+
250+
// Get concurrency threshold from config, default to 5
251+
concurrencyThreshold := 5
252+
if s.config != nil && s.config.API.BatchClassification.ConcurrencyThreshold > 0 {
253+
concurrencyThreshold = s.config.API.BatchClassification.ConcurrencyThreshold
254+
}
255+
256+
// Process texts based on batch size
257+
var results []services.Classification
258+
var err error
259+
260+
if len(req.Texts) <= concurrencyThreshold {
261+
results, err = s.processSequentially(req.Texts, req.Options)
262+
} else {
263+
results, err = s.processConcurrently(req.Texts, req.Options)
264+
}
265+
266+
if err != nil {
267+
s.writeErrorResponse(w, http.StatusInternalServerError, "CLASSIFICATION_ERROR", err.Error())
268+
return
269+
}
270+
271+
// Calculate statistics
272+
statistics := s.calculateStatistics(results)
273+
274+
response := BatchClassificationResponse{
275+
Results: results,
276+
TotalCount: len(req.Texts),
277+
ProcessingTimeMs: time.Since(start).Milliseconds(),
278+
Statistics: statistics,
279+
}
280+
281+
s.writeJSONResponse(w, http.StatusOK, response)
196282
}
197283

198284
func (s *ClassificationAPIServer) handleModelsInfo(w http.ResponseWriter, r *http.Request) {
@@ -405,3 +491,114 @@ func (s *ClassificationAPIServer) getSystemInfo() SystemInfo {
405491
GPUAvailable: false, // TODO: Implement GPU detection
406492
}
407493
}
494+
495+
// processSequentially handles small batches with sequential processing
496+
func (s *ClassificationAPIServer) processSequentially(texts []string, options *ClassificationOptions) ([]services.Classification, error) {
497+
results := make([]services.Classification, len(texts))
498+
for i, text := range texts {
499+
result, err := s.classifySingleText(text, options)
500+
if err != nil {
501+
return nil, fmt.Errorf("failed to classify text at index %d: %w", i, err)
502+
}
503+
results[i] = result
504+
}
505+
return results, nil
506+
}
507+
508+
// processConcurrently handles large batches with concurrent processing
509+
func (s *ClassificationAPIServer) processConcurrently(texts []string, options *ClassificationOptions) ([]services.Classification, error) {
510+
// Get max concurrency from config, default to 8
511+
maxConcurrency := 8
512+
if s.config != nil && s.config.API.BatchClassification.MaxConcurrency > 0 {
513+
maxConcurrency = s.config.API.BatchClassification.MaxConcurrency
514+
}
515+
516+
results := make([]services.Classification, len(texts))
517+
errors := make([]error, len(texts))
518+
519+
semaphore := make(chan struct{}, maxConcurrency)
520+
var wg sync.WaitGroup
521+
522+
for i, text := range texts {
523+
wg.Add(1)
524+
go func(index int, txt string) {
525+
defer wg.Done()
526+
semaphore <- struct{}{}
527+
defer func() { <-semaphore }()
528+
529+
// TODO: Refactor candle-binding to support batch mode for better performance
530+
// This would allow processing multiple texts in a single model inference call
531+
// instead of individual calls, significantly improving throughput
532+
result, err := s.classifySingleText(txt, options)
533+
if err != nil {
534+
errors[index] = err
535+
return
536+
}
537+
results[index] = result
538+
}(i, text)
539+
}
540+
541+
wg.Wait()
542+
543+
// Check for errors
544+
for i, err := range errors {
545+
if err != nil {
546+
return nil, fmt.Errorf("failed to classify text at index %d: %w", i, err)
547+
}
548+
}
549+
550+
return results, nil
551+
}
552+
553+
// classifySingleText processes a single text using existing service
554+
func (s *ClassificationAPIServer) classifySingleText(text string, options *ClassificationOptions) (services.Classification, error) {
555+
// Convert API options to service options
556+
var serviceOptions *services.IntentOptions
557+
if options != nil {
558+
serviceOptions = &services.IntentOptions{
559+
ReturnProbabilities: options.ReturnProbabilities,
560+
ConfidenceThreshold: options.ConfidenceThreshold,
561+
IncludeExplanation: options.IncludeExplanation,
562+
}
563+
}
564+
565+
individualReq := services.IntentRequest{
566+
Text: text,
567+
Options: serviceOptions,
568+
}
569+
570+
response, err := s.classificationSvc.ClassifyIntent(individualReq)
571+
if err != nil {
572+
return services.Classification{}, err
573+
}
574+
575+
return response.Classification, nil
576+
}
577+
578+
// calculateStatistics computes batch processing statistics
579+
func (s *ClassificationAPIServer) calculateStatistics(results []services.Classification) CategoryClassificationStatistics {
580+
categoryDistribution := make(map[string]int)
581+
var totalConfidence float64
582+
lowConfidenceCount := 0
583+
584+
for _, result := range results {
585+
if result.Category != "" {
586+
categoryDistribution[result.Category]++
587+
}
588+
totalConfidence += result.Confidence
589+
if result.Confidence < 0.7 {
590+
lowConfidenceCount++
591+
}
592+
}
593+
594+
avgConfidence := 0.0
595+
if len(results) > 0 {
596+
avgConfidence = totalConfidence / float64(len(results))
597+
}
598+
599+
return CategoryClassificationStatistics{
600+
CategoryDistribution: categoryDistribution,
601+
AvgConfidence: avgConfidence,
602+
LowConfidenceCount: lowConfidenceCount,
603+
}
604+
}

0 commit comments

Comments (0)