7 | 7 | "log" |
8 | 8 | "net/http" |
9 | 9 | "runtime" |
| 10 | + "sync" |
10 | 11 | "time" |
11 | 12 |
12 | 13 | "github.com/vllm-project/semantic-router/semantic-router/pkg/config" |
@@ -46,6 +47,34 @@ type SystemInfo struct { |
46 | 47 | GPUAvailable bool `json:"gpu_available"` |
47 | 48 | } |
48 | 49 |
| 50 | +// BatchClassificationRequest represents a batch classification request |
| 51 | +type BatchClassificationRequest struct { |
| 52 | + Texts []string `json:"texts"` |
| 53 | + Options *ClassificationOptions `json:"options,omitempty"` |
| 54 | +} |
| 55 | + |
| 56 | +// BatchClassificationResponse represents the response from batch classification |
| 57 | +type BatchClassificationResponse struct { |
| 58 | + Results []services.Classification `json:"results"` |
| 59 | + TotalCount int `json:"total_count"` |
| 60 | + ProcessingTimeMs int64 `json:"processing_time_ms"` |
| 61 | + Statistics CategoryClassificationStatistics `json:"statistics"` |
| 62 | +} |
| 63 | + |
| 64 | +// CategoryClassificationStatistics provides batch processing statistics |
| 65 | +type CategoryClassificationStatistics struct { |
| 66 | + CategoryDistribution map[string]int `json:"category_distribution"` |
| 67 | + AvgConfidence float64 `json:"avg_confidence"` |
| 68 | + LowConfidenceCount int `json:"low_confidence_count"` |
| 69 | +} |
| 70 | + |
| 71 | +// ClassificationOptions mirrors services.IntentOptions for API layer |
| 72 | +type ClassificationOptions struct { |
| 73 | + ReturnProbabilities bool `json:"return_probabilities,omitempty"` |
| 74 | + ConfidenceThreshold float64 `json:"confidence_threshold,omitempty"` |
| 75 | + IncludeExplanation bool `json:"include_explanation,omitempty"` |
| 76 | +} |
| 77 | + |
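
For reference, here is roughly what a request/response pair for the new endpoint looks like. The route registration is not part of this diff, so no path is shown; all values below are illustrative, and `results` entries are `services.Classification` objects whose full field set is defined in the services package:

```json
{
  "texts": ["What is the integral of x^2?", "Who won the 2018 World Cup?"],
  "options": {"confidence_threshold": 0.7}
}
```

```json
{
  "results": [
    {"category": "math", "confidence": 0.92},
    {"category": "sports", "confidence": 0.85}
  ],
  "total_count": 2,
  "processing_time_ms": 41,
  "statistics": {
    "category_distribution": {"math": 1, "sports": 1},
    "avg_confidence": 0.885,
    "low_confidence_count": 0
  }
}
```
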
49 | 78 | // StartClassificationAPI starts the Classification API server |
50 | 79 | func StartClassificationAPI(configPath string, port int) error { |
51 | 80 | // Load configuration |
@@ -192,7 +221,64 @@ func (s *ClassificationAPIServer) handleCombinedClassification(w http.ResponseWr |
192 | 221 | } |
193 | 222 |
194 | 223 | func (s *ClassificationAPIServer) handleBatchClassification(w http.ResponseWriter, r *http.Request) { |
195 | | - s.writeErrorResponse(w, http.StatusNotImplemented, "NOT_IMPLEMENTED", "Batch classification not implemented yet") |
| 224 | + start := time.Now() |
| 225 | + |
| 226 | + var req BatchClassificationRequest |
| 227 | + if err := s.parseJSONRequest(r, &req); err != nil { |
| 228 | + s.writeErrorResponse(w, http.StatusBadRequest, "INVALID_INPUT", err.Error()) |
| 229 | + return |
| 230 | + } |
| 231 | + |
| 232 | + // Input validation |
| 233 | + if len(req.Texts) == 0 { |
| 234 | + s.writeErrorResponse(w, http.StatusBadRequest, "INVALID_INPUT", "texts array cannot be empty") |
| 235 | + return |
| 236 | + } |
| 237 | + |
| 238 | + // Get max batch size from config, default to 100 |
| 239 | + maxBatchSize := 100 |
| 240 | + if s.config != nil && s.config.API.BatchClassification.MaxBatchSize > 0 { |
| 241 | + maxBatchSize = s.config.API.BatchClassification.MaxBatchSize |
| 242 | + } |
| 243 | + |
| 244 | + if len(req.Texts) > maxBatchSize { |
| 245 | + s.writeErrorResponse(w, http.StatusBadRequest, "BATCH_TOO_LARGE", |
| 246 | + fmt.Sprintf("batch size cannot exceed %d texts", maxBatchSize)) |
| 247 | + return |
| 248 | + } |
| 249 | + |
| 250 | + // Get concurrency threshold from config, default to 5 |
| 251 | + concurrencyThreshold := 5 |
| 252 | + if s.config != nil && s.config.API.BatchClassification.ConcurrencyThreshold > 0 { |
| 253 | + concurrencyThreshold = s.config.API.BatchClassification.ConcurrencyThreshold |
| 254 | + } |
| 255 | + |
| 256 | + // Choose sequential or concurrent processing based on batch size
| 257 | + var results []services.Classification |
| 258 | + var err error |
| 259 | + |
| 260 | + if len(req.Texts) <= concurrencyThreshold { |
| 261 | + results, err = s.processSequentially(req.Texts, req.Options) |
| 262 | + } else { |
| 263 | + results, err = s.processConcurrently(req.Texts, req.Options) |
| 264 | + } |
| 265 | + |
| 266 | + if err != nil { |
| 267 | + s.writeErrorResponse(w, http.StatusInternalServerError, "CLASSIFICATION_ERROR", err.Error()) |
| 268 | + return |
| 269 | + } |
| 270 | + |
| 271 | + // Calculate statistics |
| 272 | + statistics := s.calculateStatistics(results) |
| 273 | + |
| 274 | + response := BatchClassificationResponse{ |
| 275 | + Results: results, |
| 276 | + TotalCount: len(req.Texts), |
| 277 | + ProcessingTimeMs: time.Since(start).Milliseconds(), |
| 278 | + Statistics: statistics, |
| 279 | + } |
| 280 | + |
| 281 | + s.writeJSONResponse(w, http.StatusOK, response) |
196 | 282 | } |
197 | 283 |
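
The handler reads its three limits from the server config. Here is a minimal sketch of the corresponding config section, assuming the YAML keys are the snake_case forms of the Go field names (the actual struct tags live in the config package and are not shown in this diff):

```yaml
api:
  batch_classification:
    max_batch_size: 100       # reject batches larger than this (handler default: 100)
    concurrency_threshold: 5  # batches above this size are processed concurrently (default: 5)
    max_concurrency: 8        # cap on in-flight classification goroutines (default: 8)
```
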
198 | 284 | func (s *ClassificationAPIServer) handleModelsInfo(w http.ResponseWriter, r *http.Request) { |
@@ -405,3 +491,114 @@ func (s *ClassificationAPIServer) getSystemInfo() SystemInfo { |
405 | 491 | GPUAvailable: false, // TODO: Implement GPU detection |
406 | 492 | } |
407 | 493 | } |
| 494 | + |
| 495 | +// processSequentially handles small batches with sequential processing |
| 496 | +func (s *ClassificationAPIServer) processSequentially(texts []string, options *ClassificationOptions) ([]services.Classification, error) { |
| 497 | + results := make([]services.Classification, len(texts)) |
| 498 | + for i, text := range texts { |
| 499 | + result, err := s.classifySingleText(text, options) |
| 500 | + if err != nil { |
| 501 | + return nil, fmt.Errorf("failed to classify text at index %d: %w", i, err) |
| 502 | + } |
| 503 | + results[i] = result |
| 504 | + } |
| 505 | + return results, nil |
| 506 | +} |
| 507 | + |
| 508 | +// processConcurrently handles large batches with concurrent processing |
| 509 | +func (s *ClassificationAPIServer) processConcurrently(texts []string, options *ClassificationOptions) ([]services.Classification, error) { |
| 510 | + // Get max concurrency from config, default to 8 |
| 511 | + maxConcurrency := 8 |
| 512 | + if s.config != nil && s.config.API.BatchClassification.MaxConcurrency > 0 { |
| 513 | + maxConcurrency = s.config.API.BatchClassification.MaxConcurrency |
| 514 | + } |
| 515 | + |
| 516 | + results := make([]services.Classification, len(texts)) |
| 517 | + errs := make([]error, len(texts)) // named errs to avoid shadowing the standard errors package
| 518 | + |
| 519 | + semaphore := make(chan struct{}, maxConcurrency) |
| 520 | + var wg sync.WaitGroup |
| 521 | + |
| 522 | + for i, text := range texts { |
| 523 | + wg.Add(1) |
| 524 | + go func(index int, txt string) { |
| 525 | + defer wg.Done() |
| 526 | + semaphore <- struct{}{} |
| 527 | + defer func() { <-semaphore }() |
| 528 | + |
| 529 | + // TODO: Refactor candle-binding to support batch mode for better performance |
| 530 | + // This would allow processing multiple texts in a single model inference call |
| 531 | + // instead of individual calls, significantly improving throughput |
| 532 | + result, err := s.classifySingleText(txt, options) |
| 533 | + if err != nil { |
| 534 | + errs[index] = err
| 535 | + return |
| 536 | + } |
| 537 | + results[index] = result |
| 538 | + }(i, text) |
| 539 | + } |
| 540 | + |
| 541 | + wg.Wait() |
| 542 | + |
| 543 | + // Report the first classification error by index, if any
| 544 | + for i, err := range errs {
| 545 | + if err != nil { |
| 546 | + return nil, fmt.Errorf("failed to classify text at index %d: %w", i, err) |
| 547 | + } |
| 548 | + } |
| 549 | + |
| 550 | + return results, nil |
| 551 | +} |
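
Two behavioral notes on the concurrent path above: all goroutines are spawned up front and the buffered channel then gates how many classification calls run at once (acquiring the semaphore before `go` would also cap live goroutines, but with batches bounded by maxBatchSize the simpler form is fine), and unlike the sequential path, which fails fast, the concurrent path runs the whole batch to completion and then reports the first error by index.
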
| 552 | + |
| 553 | +// classifySingleText classifies a single text via the existing intent classification service
| 554 | +func (s *ClassificationAPIServer) classifySingleText(text string, options *ClassificationOptions) (services.Classification, error) { |
| 555 | + // Convert API options to service options |
| 556 | + var serviceOptions *services.IntentOptions |
| 557 | + if options != nil { |
| 558 | + serviceOptions = &services.IntentOptions{ |
| 559 | + ReturnProbabilities: options.ReturnProbabilities, |
| 560 | + ConfidenceThreshold: options.ConfidenceThreshold, |
| 561 | + IncludeExplanation: options.IncludeExplanation, |
| 562 | + } |
| 563 | + } |
| 564 | + |
| 565 | + individualReq := services.IntentRequest{ |
| 566 | + Text: text, |
| 567 | + Options: serviceOptions, |
| 568 | + } |
| 569 | + |
| 570 | + response, err := s.classificationSvc.ClassifyIntent(individualReq) |
| 571 | + if err != nil { |
| 572 | + return services.Classification{}, err |
| 573 | + } |
| 574 | + |
| 575 | + return response.Classification, nil |
| 576 | +} |
| 577 | + |
| 578 | +// calculateStatistics computes batch processing statistics |
| 579 | +func (s *ClassificationAPIServer) calculateStatistics(results []services.Classification) CategoryClassificationStatistics { |
| 580 | + categoryDistribution := make(map[string]int) |
| 581 | + var totalConfidence float64 |
| 582 | + lowConfidenceCount := 0 |
| 583 | + |
| 584 | + for _, result := range results { |
| 585 | + if result.Category != "" { |
| 586 | + categoryDistribution[result.Category]++ |
| 587 | + } |
| 588 | + totalConfidence += result.Confidence |
| 589 | + if result.Confidence < 0.7 { // low-confidence cutoff hardcoded at 0.7 for now
| 590 | + lowConfidenceCount++ |
| 591 | + } |
| 592 | + } |
| 593 | + |
| 594 | + avgConfidence := 0.0 |
| 595 | + if len(results) > 0 { |
| 596 | + avgConfidence = totalConfidence / float64(len(results)) |
| 597 | + } |
| 598 | + |
| 599 | + return CategoryClassificationStatistics{ |
| 600 | + CategoryDistribution: categoryDistribution, |
| 601 | + AvgConfidence: avgConfidence, |
| 602 | + LowConfidenceCount: lowConfidenceCount, |
| 603 | + } |
| 604 | +} |
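
As a quick sanity check of the statistics logic, here is a minimal test sketch. The package name is assumed to match the file's; `Category` and `Confidence` are used exactly as in the loop above, and the zero-value server suffices since `calculateStatistics` touches no server state:

```go
package api

import (
	"math"
	"testing"

	"github.com/vllm-project/semantic-router/semantic-router/pkg/services"
)

func TestCalculateStatistics(t *testing.T) {
	s := &ClassificationAPIServer{}
	stats := s.calculateStatistics([]services.Classification{
		{Category: "math", Confidence: 0.9},
		{Category: "math", Confidence: 0.8},
		{Category: "history", Confidence: 0.6},
	})
	if stats.CategoryDistribution["math"] != 2 || stats.CategoryDistribution["history"] != 1 {
		t.Errorf("unexpected distribution: %v", stats.CategoryDistribution)
	}
	// Only 0.6 falls below the hardcoded 0.7 threshold.
	if stats.LowConfidenceCount != 1 {
		t.Errorf("low-confidence count = %d, want 1", stats.LowConfidenceCount)
	}
	if want := (0.9 + 0.8 + 0.6) / 3; math.Abs(stats.AvgConfidence-want) > 1e-9 {
		t.Errorf("avg confidence = %f, want %f", stats.AvgConfidence, want)
	}
}
```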