 	"log"
 	"net/http"
 	"runtime"
+	"sync"
 	"time"
 
 	"github.com/vllm-project/semantic-router/semantic-router/pkg/config"
@@ -46,6 +47,34 @@ type SystemInfo struct {
 	GPUAvailable bool   `json:"gpu_available"`
 }
 
+// BatchClassificationRequest represents a batch classification request
+type BatchClassificationRequest struct {
+	Texts   []string               `json:"texts"`
+	Options *ClassificationOptions `json:"options,omitempty"`
+}
+
+// BatchClassificationResponse represents the response from batch classification
+type BatchClassificationResponse struct {
+	Results          []services.Classification `json:"results"`
+	TotalCount       int                       `json:"total_count"`
+	ProcessingTimeMs int64                     `json:"processing_time_ms"`
+	Statistics       Statistics                `json:"statistics"`
+}
+
+// Statistics provides batch processing statistics
+type Statistics struct {
+	CategoryDistribution map[string]int `json:"category_distribution"`
+	AvgConfidence        float64        `json:"avg_confidence"`
+	LowConfidenceCount   int            `json:"low_confidence_count"`
+}
+
+// ClassificationOptions mirrors services.IntentOptions for the API layer
+type ClassificationOptions struct {
+	ReturnProbabilities bool    `json:"return_probabilities,omitempty"`
+	ConfidenceThreshold float64 `json:"confidence_threshold,omitempty"`
+	IncludeExplanation  bool    `json:"include_explanation,omitempty"`
+}
+
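+// Illustrative request/response bodies for the batch endpoint, derived from
+// the JSON tags on the structs above (all values are made up):
+//
+//	request:  {"texts": ["refund my order", "what is 2+2"],
+//	           "options": {"confidence_threshold": 0.8}}
+//	response: {"results": [...], "total_count": 2, "processing_time_ms": 12,
+//	           "statistics": {"category_distribution": {"billing": 1, "math": 1},
+//	                          "avg_confidence": 0.91, "low_confidence_count": 0}}
+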
 // StartClassificationAPI starts the Classification API server
 func StartClassificationAPI(configPath string, port int) error {
 	// Load configuration
@@ -192,7 +221,64 @@ func (s *ClassificationAPIServer) handleCombinedClassification(w http.ResponseWr
 }
 
 func (s *ClassificationAPIServer) handleBatchClassification(w http.ResponseWriter, r *http.Request) {
-	s.writeErrorResponse(w, http.StatusNotImplemented, "NOT_IMPLEMENTED", "Batch classification not implemented yet")
+	start := time.Now()
+
+	var req BatchClassificationRequest
+	if err := s.parseJSONRequest(r, &req); err != nil {
+		s.writeErrorResponse(w, http.StatusBadRequest, "INVALID_INPUT", err.Error())
+		return
+	}
+
+	// Input validation
+	if len(req.Texts) == 0 {
+		s.writeErrorResponse(w, http.StatusBadRequest, "INVALID_INPUT", "texts array cannot be empty")
+		return
+	}
+
+	// Get max batch size from config, default to 100
+	maxBatchSize := 100
+	if s.config != nil && s.config.API.BatchClassification.MaxBatchSize > 0 {
+		maxBatchSize = s.config.API.BatchClassification.MaxBatchSize
+	}
+
+	if len(req.Texts) > maxBatchSize {
+		s.writeErrorResponse(w, http.StatusBadRequest, "BATCH_TOO_LARGE",
+			fmt.Sprintf("batch size cannot exceed %d texts", maxBatchSize))
+		return
+	}
+
+	// Get concurrency threshold from config, default to 5
+	concurrencyThreshold := 5
+	if s.config != nil && s.config.API.BatchClassification.ConcurrencyThreshold > 0 {
+		concurrencyThreshold = s.config.API.BatchClassification.ConcurrencyThreshold
+	}
+
+	// Choose sequential or concurrent processing based on batch size
+	var results []services.Classification
+	var err error
+
+	if len(req.Texts) <= concurrencyThreshold {
+		results, err = s.processSequentially(req.Texts, req.Options)
+	} else {
+		results, err = s.processConcurrently(req.Texts, req.Options)
+	}
+
+	// A failure on any single text fails the whole batch
+	if err != nil {
+		s.writeErrorResponse(w, http.StatusInternalServerError, "CLASSIFICATION_ERROR", err.Error())
+		return
+	}
+
+	// Calculate statistics
+	statistics := s.calculateStatistics(results)
+
+	response := BatchClassificationResponse{
+		Results:          results,
+		TotalCount:       len(req.Texts),
+		ProcessingTimeMs: time.Since(start).Milliseconds(),
+		Statistics:       statistics,
+	}
+
+	s.writeJSONResponse(w, http.StatusOK, response)
 }
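+
+// Illustrative invocation (the address and mount path are hypothetical; route
+// registration is not part of this diff):
+//
+//	curl -X POST http://localhost:8080/classify/batch \
+//	  -H 'Content-Type: application/json' \
+//	  -d '{"texts": ["refund my order", "what is 2+2"]}'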
 
 func (s *ClassificationAPIServer) handleModelsInfo(w http.ResponseWriter, r *http.Request) {
@@ -405,3 +491,111 @@ func (s *ClassificationAPIServer) getSystemInfo() SystemInfo {
 		GPUAvailable: false, // TODO: Implement GPU detection
 	}
 }
+
+// processSequentially handles small batches with sequential processing
+func (s *ClassificationAPIServer) processSequentially(texts []string, options *ClassificationOptions) ([]services.Classification, error) {
+	results := make([]services.Classification, len(texts))
+	for i, text := range texts {
+		result, err := s.classifySingleText(text, options)
+		if err != nil {
+			return nil, fmt.Errorf("failed to classify text at index %d: %w", i, err)
+		}
+		results[i] = result
+	}
+	return results, nil
+}
+
+// processConcurrently handles large batches with concurrent processing
+func (s *ClassificationAPIServer) processConcurrently(texts []string, options *ClassificationOptions) ([]services.Classification, error) {
+	// Get max concurrency from config, default to 8
+	maxConcurrency := 8
+	if s.config != nil && s.config.API.BatchClassification.MaxConcurrency > 0 {
+		maxConcurrency = s.config.API.BatchClassification.MaxConcurrency
+	}
+
+	// Per-index slots let each goroutine write its result without locking
+	results := make([]services.Classification, len(texts))
+	errors := make([]error, len(texts))
+
+	// Buffered channel used as a semaphore to cap in-flight classifications
+	semaphore := make(chan struct{}, maxConcurrency)
+	var wg sync.WaitGroup
+
+	for i, text := range texts {
+		wg.Add(1)
+		go func(index int, txt string) {
+			defer wg.Done()
+			semaphore <- struct{}{}
+			defer func() { <-semaphore }()
+
+			result, err := s.classifySingleText(txt, options)
+			if err != nil {
+				errors[index] = err
+				return
+			}
+			results[index] = result
+		}(i, text)
+	}
+
+	wg.Wait()
+
+	// Check for errors
+	for i, err := range errors {
+		if err != nil {
+			return nil, fmt.Errorf("failed to classify text at index %d: %w", i, err)
+		}
+	}
+
+	return results, nil
+}
+
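+// Design note: spawning one goroutine per text and gating them with a
+// semaphore keeps the code simple; for very large batches a fixed worker
+// pool would avoid the per-item goroutine overhead, at the cost of more
+// plumbing.
+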
+// classifySingleText processes a single text using the existing intent service
+func (s *ClassificationAPIServer) classifySingleText(text string, options *ClassificationOptions) (services.Classification, error) {
+	// Convert API options to service options
+	var serviceOptions *services.IntentOptions
+	if options != nil {
+		serviceOptions = &services.IntentOptions{
+			ReturnProbabilities: options.ReturnProbabilities,
+			ConfidenceThreshold: options.ConfidenceThreshold,
+			IncludeExplanation:  options.IncludeExplanation,
+		}
+	}
+
+	individualReq := services.IntentRequest{
+		Text:    text,
+		Options: serviceOptions,
+	}
+
+	response, err := s.classificationSvc.ClassifyIntent(individualReq)
+	if err != nil {
+		return services.Classification{}, err
+	}
+
+	return response.Classification, nil
+}
+
+// calculateStatistics computes batch processing statistics
+func (s *ClassificationAPIServer) calculateStatistics(results []services.Classification) Statistics {
+	categoryDistribution := make(map[string]int)
+	var totalConfidence float64
+	lowConfidenceCount := 0
+
+	for _, result := range results {
+		if result.Category != "" {
+			categoryDistribution[result.Category]++
+		}
+		totalConfidence += result.Confidence
+		// Results below a fixed 0.7 confidence threshold count as low confidence
+		if result.Confidence < 0.7 {
+			lowConfidenceCount++
+		}
+	}
+
+	avgConfidence := 0.0
+	if len(results) > 0 {
+		avgConfidence = totalConfidence / float64(len(results))
+	}
+
+	return Statistics{
+		CategoryDistribution: categoryDistribution,
+		AvgConfidence:        avgConfidence,
+		LowConfidenceCount:   lowConfidenceCount,
+	}
+}
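+
+// Worked example (hypothetical inputs): three results with categories
+// ["math", "math", ""] and confidences [0.9, 0.8, 0.4] yield
+// CategoryDistribution {"math": 2}, AvgConfidence 0.7, and
+// LowConfidenceCount 1 (only 0.4 falls below the 0.7 threshold).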