9 changes: 8 additions & 1 deletion config/config.yaml
@@ -237,4 +237,11 @@ categories:
- model: phi4
score: 0.2
default_model: mistral-small3.1
default_reasoning_effort: medium # Default reasoning effort level (low, medium, high)

# API Configuration
api:
batch_classification:
max_batch_size: 100 # Maximum number of texts in a single batch
concurrency_threshold: 5 # Switch to concurrent processing when batch size > this value
max_concurrency: 8 # Maximum number of concurrent goroutines
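
For reference, the server code below reads these values as s.config.API.BatchClassification.MaxBatchSize, .ConcurrencyThreshold, and .MaxConcurrency. A minimal sketch of config structs that would map onto this YAML, assuming standard yaml tags (the real definitions live in the config package and are not shown in this diff):

type APIConfig struct {
	BatchClassification BatchClassificationConfig `yaml:"batch_classification"`
}

type BatchClassificationConfig struct {
	MaxBatchSize         int `yaml:"max_batch_size"`        // 0 falls back to the server default of 100
	ConcurrencyThreshold int `yaml:"concurrency_threshold"` // 0 falls back to the server default of 5
	MaxConcurrency       int `yaml:"max_concurrency"`       // 0 falls back to the server default of 8
}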
199 changes: 198 additions & 1 deletion src/semantic-router/pkg/api/server.go
@@ -7,6 +7,7 @@ import (
"log"
"net/http"
"runtime"
"sync"
"time"

"github.com/vllm-project/semantic-router/semantic-router/pkg/config"
@@ -46,6 +47,34 @@ type SystemInfo struct {
GPUAvailable bool `json:"gpu_available"`
}

// BatchClassificationRequest represents a batch classification request
type BatchClassificationRequest struct {
Texts []string `json:"texts"`
Options *ClassificationOptions `json:"options,omitempty"`
}

// BatchClassificationResponse represents the response from batch classification
type BatchClassificationResponse struct {
Results []services.Classification `json:"results"`
TotalCount int `json:"total_count"`
ProcessingTimeMs int64 `json:"processing_time_ms"`
Statistics CategoryClassificationStatistics `json:"statistics"`
}

// CategoryClassificationStatistics provides batch processing statistics
type CategoryClassificationStatistics struct {
CategoryDistribution map[string]int `json:"category_distribution"`
AvgConfidence float64 `json:"avg_confidence"`
LowConfidenceCount int `json:"low_confidence_count"`
}

// ClassificationOptions mirrors services.IntentOptions for API layer
type ClassificationOptions struct {
ReturnProbabilities bool `json:"return_probabilities,omitempty"`
ConfidenceThreshold float64 `json:"confidence_threshold,omitempty"`
IncludeExplanation bool `json:"include_explanation,omitempty"`
}
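
A hypothetical client-side sketch of how the request and response types above would be exercised; the route path is an assumption, since the handler registration is outside this diff (requires bytes, encoding/json, and net/http):

func exampleBatchCall() (*BatchClassificationResponse, error) {
	body, err := json.Marshal(BatchClassificationRequest{
		Texts:   []string{"book me a flight", "solve x^2 = 4"},
		Options: &ClassificationOptions{ConfidenceThreshold: 0.7},
	})
	if err != nil {
		return nil, err
	}
	// Assumed path; the actual route is registered elsewhere in the server.
	resp, err := http.Post("http://localhost:8080/api/v1/classify/batch",
		"application/json", bytes.NewReader(body))
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	var out BatchClassificationResponse
	if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
		return nil, err
	}
	return &out, nil
}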

// StartClassificationAPI starts the Classification API server
func StartClassificationAPI(configPath string, port int) error {
// Load configuration
@@ -192,7 +221,64 @@ func (s *ClassificationAPIServer) handleCombinedClassification(w http.ResponseWr
}

func (s *ClassificationAPIServer) handleBatchClassification(w http.ResponseWriter, r *http.Request) {
-	s.writeErrorResponse(w, http.StatusNotImplemented, "NOT_IMPLEMENTED", "Batch classification not implemented yet")
start := time.Now()

var req BatchClassificationRequest
if err := s.parseJSONRequest(r, &req); err != nil {
s.writeErrorResponse(w, http.StatusBadRequest, "INVALID_INPUT", err.Error())
return
}

// Input validation
if len(req.Texts) == 0 {
s.writeErrorResponse(w, http.StatusBadRequest, "INVALID_INPUT", "texts array cannot be empty")
return
}

// Get max batch size from config, default to 100
maxBatchSize := 100
if s.config != nil && s.config.API.BatchClassification.MaxBatchSize > 0 {
maxBatchSize = s.config.API.BatchClassification.MaxBatchSize
}

if len(req.Texts) > maxBatchSize {
s.writeErrorResponse(w, http.StatusBadRequest, "BATCH_TOO_LARGE",
fmt.Sprintf("batch size cannot exceed %d texts", maxBatchSize))
return
}

// Get concurrency threshold from config, default to 5
concurrencyThreshold := 5
if s.config != nil && s.config.API.BatchClassification.ConcurrencyThreshold > 0 {
concurrencyThreshold = s.config.API.BatchClassification.ConcurrencyThreshold
}

// Process texts based on batch size
var results []services.Classification
var err error

if len(req.Texts) <= concurrencyThreshold {
results, err = s.processSequentially(req.Texts, req.Options)
} else {
results, err = s.processConcurrently(req.Texts, req.Options)
}

if err != nil {
s.writeErrorResponse(w, http.StatusInternalServerError, "CLASSIFICATION_ERROR", err.Error())
return
}

// Calculate statistics
statistics := s.calculateStatistics(results)

response := BatchClassificationResponse{
Results: results,
TotalCount: len(req.Texts),
ProcessingTimeMs: time.Since(start).Milliseconds(),
Statistics: statistics,
}

s.writeJSONResponse(w, http.StatusOK, response)
}

func (s *ClassificationAPIServer) handleModelsInfo(w http.ResponseWriter, r *http.Request) {
@@ -405,3 +491,114 @@ func (s *ClassificationAPIServer) getSystemInfo() SystemInfo {
GPUAvailable: false, // TODO: Implement GPU detection
}
}

// processSequentially handles small batches with sequential processing
func (s *ClassificationAPIServer) processSequentially(texts []string, options *ClassificationOptions) ([]services.Classification, error) {
results := make([]services.Classification, len(texts))
for i, text := range texts {
result, err := s.classifySingleText(text, options)
if err != nil {
return nil, fmt.Errorf("failed to classify text at index %d: %w", i, err)
}
results[i] = result
}
return results, nil
}

// processConcurrently handles large batches with concurrent processing
func (s *ClassificationAPIServer) processConcurrently(texts []string, options *ClassificationOptions) ([]services.Classification, error) {
// Get max concurrency from config, default to 8
maxConcurrency := 8
if s.config != nil && s.config.API.BatchClassification.MaxConcurrency > 0 {
maxConcurrency = s.config.API.BatchClassification.MaxConcurrency
}

results := make([]services.Classification, len(texts))
errors := make([]error, len(texts))

semaphore := make(chan struct{}, maxConcurrency)
var wg sync.WaitGroup

for i, text := range texts {
wg.Add(1)
go func(index int, txt string) {
defer wg.Done()
semaphore <- struct{}{}
defer func() { <-semaphore }()

// TODO: Refactor candle-binding to support batch mode for better performance
// This would allow processing multiple texts in a single model inference call
// instead of individual calls, significantly improving throughput
result, err := s.classifySingleText(txt, options)
if err != nil {
errors[index] = err
return
}
results[index] = result
}(i, text)
}

wg.Wait()

// Check for errors
for i, err := range errors {
if err != nil {
return nil, fmt.Errorf("failed to classify text at index %d: %w", i, err)
}
}

return results, nil
}
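
As a design note, the manual semaphore above could also be written with golang.org/x/sync/errgroup, which is not a dependency this PR introduces; a hedged sketch, with the hard-coded limit standing in for the configured maxConcurrency:

func (s *ClassificationAPIServer) processConcurrentlyWithErrgroup(texts []string, options *ClassificationOptions) ([]services.Classification, error) {
	results := make([]services.Classification, len(texts))
	var g errgroup.Group
	g.SetLimit(8) // would come from config.API.BatchClassification.MaxConcurrency
	for i, text := range texts {
		i, text := i, text // loop-variable capture, needed before Go 1.22
		g.Go(func() error {
			result, err := s.classifySingleText(text, options)
			if err != nil {
				return fmt.Errorf("failed to classify text at index %d: %w", i, err)
			}
			results[i] = result
			return nil
		})
	}
	if err := g.Wait(); err != nil {
		return nil, err
	}
	return results, nil
}

One behavioral difference: Wait returns the first error encountered rather than the lowest-index one, whereas the slice-of-errors approach above always reports the earliest failing index.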

// classifySingleText processes a single text using the existing ClassifyIntent service
func (s *ClassificationAPIServer) classifySingleText(text string, options *ClassificationOptions) (services.Classification, error) {
// Convert API options to service options
var serviceOptions *services.IntentOptions
if options != nil {
serviceOptions = &services.IntentOptions{
ReturnProbabilities: options.ReturnProbabilities,
ConfidenceThreshold: options.ConfidenceThreshold,
IncludeExplanation: options.IncludeExplanation,
}
}

individualReq := services.IntentRequest{
Text: text,
Options: serviceOptions,
}

response, err := s.classificationSvc.ClassifyIntent(individualReq)
if err != nil {
return services.Classification{}, err
}

return response.Classification, nil
}

// calculateStatistics computes batch processing statistics
func (s *ClassificationAPIServer) calculateStatistics(results []services.Classification) CategoryClassificationStatistics {
categoryDistribution := make(map[string]int)
var totalConfidence float64
lowConfidenceCount := 0

for _, result := range results {
if result.Category != "" {
categoryDistribution[result.Category]++
}
totalConfidence += result.Confidence
		if result.Confidence < 0.7 { // hard-coded low-confidence threshold
lowConfidenceCount++
}
}

avgConfidence := 0.0
if len(results) > 0 {
avgConfidence = totalConfidence / float64(len(results))
}

return CategoryClassificationStatistics{
CategoryDistribution: categoryDistribution,
AvgConfidence: avgConfidence,
LowConfidenceCount: lowConfidenceCount,
}
}
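
A small test sketch for calculateStatistics; the Category and Confidence field names follow their use in the function above, and the 0.7 cutoff matches the hard-coded low-confidence threshold (assumes the standard testing package and the services import used above):

func TestCalculateStatistics(t *testing.T) {
	s := &ClassificationAPIServer{}
	stats := s.calculateStatistics([]services.Classification{
		{Category: "math", Confidence: 0.9},
		{Category: "math", Confidence: 0.6}, // counts as low confidence (< 0.7)
		{Category: "travel", Confidence: 0.9},
	})
	if stats.CategoryDistribution["math"] != 2 {
		t.Errorf("want 2 math results, got %d", stats.CategoryDistribution["math"])
	}
	if stats.LowConfidenceCount != 1 {
		t.Errorf("want 1 low-confidence result, got %d", stats.LowConfidenceCount)
	}
	if stats.AvgConfidence < 0.799 || stats.AvgConfidence > 0.801 {
		t.Errorf("want avg confidence ~0.8, got %f", stats.AvgConfidence)
	}
}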