Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion config/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -244,4 +244,15 @@ api:
batch_classification:
max_batch_size: 100 # Maximum number of texts in a single batch
concurrency_threshold: 5 # Switch to concurrent processing when batch size > this value
max_concurrency: 8 # Maximum number of concurrent goroutines
max_concurrency: 8 # Maximum number of concurrent goroutines

# Metrics configuration for monitoring batch classification performance
metrics:
enabled: true # Enable comprehensive metrics collection
detailed_goroutine_tracking: true # Track individual goroutine lifecycle
high_resolution_timing: false # When true, use nanosecond-precision timing (disabled by default)
sample_rate: 1.0 # Collect metrics for all requests (1.0 = 100%, 0.5 = 50%)

# Histogram buckets for metrics (directly configure what you need)
duration_buckets: [0.001, 0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1, 2.5, 5, 10, 30]
size_buckets: [1, 2, 5, 10, 20, 50, 100, 200]
64 changes: 64 additions & 0 deletions src/semantic-router/pkg/api/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"time"

"github.com/vllm-project/semantic-router/semantic-router/pkg/config"
"github.com/vllm-project/semantic-router/semantic-router/pkg/metrics"
"github.com/vllm-project/semantic-router/semantic-router/pkg/services"
)

Expand Down Expand Up @@ -91,6 +92,20 @@ func StartClassificationAPI(configPath string, port int) error {
classificationSvc = services.NewPlaceholderClassificationService()
}

// Initialize batch metrics configuration
if cfg != nil && cfg.API.BatchClassification.Metrics.Enabled {
metricsConfig := metrics.BatchMetricsConfig{
Enabled: cfg.API.BatchClassification.Metrics.Enabled,
DetailedGoroutineTracking: cfg.API.BatchClassification.Metrics.DetailedGoroutineTracking,
DurationBuckets: cfg.API.BatchClassification.Metrics.DurationBuckets,
SizeBuckets: cfg.API.BatchClassification.Metrics.SizeBuckets,
BatchSizeRanges: cfg.API.BatchClassification.Metrics.BatchSizeRanges,
HighResolutionTiming: cfg.API.BatchClassification.Metrics.HighResolutionTiming,
SampleRate: cfg.API.BatchClassification.Metrics.SampleRate,
}
metrics.SetBatchMetricsConfig(metricsConfig)
}

// Create server instance
apiServer := &ClassificationAPIServer{
classificationSvc: classificationSvc,
Expand Down Expand Up @@ -231,6 +246,8 @@ func (s *ClassificationAPIServer) handleBatchClassification(w http.ResponseWrite

// Input validation
if len(req.Texts) == 0 {
// Record validation error in metrics
metrics.RecordBatchClassificationError("validation", "empty_texts")
s.writeErrorResponse(w, http.StatusBadRequest, "INVALID_INPUT", "texts array cannot be empty")
return
}
Expand All @@ -242,6 +259,8 @@ func (s *ClassificationAPIServer) handleBatchClassification(w http.ResponseWrite
}

if len(req.Texts) > maxBatchSize {
// Record validation error in metrics
metrics.RecordBatchClassificationError("validation", "batch_too_large")
s.writeErrorResponse(w, http.StatusBadRequest, "BATCH_TOO_LARGE",
fmt.Sprintf("batch size cannot exceed %d texts", maxBatchSize))
return
Expand Down Expand Up @@ -494,10 +513,26 @@ func (s *ClassificationAPIServer) getSystemInfo() SystemInfo {

// processSequentially handles small batches with sequential processing
func (s *ClassificationAPIServer) processSequentially(texts []string, options *ClassificationOptions) ([]services.Classification, error) {
start := time.Now()
processingType := "sequential"
batchSize := len(texts)

// Record request and batch size metrics
metrics.RecordBatchClassificationRequest(processingType)
metrics.RecordBatchSizeDistribution(processingType, batchSize)

// Defer recording processing time and text count
defer func() {
duration := time.Since(start).Seconds()
metrics.RecordBatchClassificationDuration(processingType, batchSize, duration)
metrics.RecordBatchClassificationTexts(processingType, batchSize)
}()

results := make([]services.Classification, len(texts))
for i, text := range texts {
result, err := s.classifySingleText(text, options)
if err != nil {
metrics.RecordBatchClassificationError(processingType, "classification_failed")
return nil, fmt.Errorf("failed to classify text at index %d: %w", i, err)
}
results[i] = result
Expand All @@ -507,6 +542,22 @@ func (s *ClassificationAPIServer) processSequentially(texts []string, options *C

// processConcurrently handles large batches with concurrent processing
func (s *ClassificationAPIServer) processConcurrently(texts []string, options *ClassificationOptions) ([]services.Classification, error) {
start := time.Now()
processingType := "concurrent"
batchSize := len(texts)
batchID := fmt.Sprintf("batch_%d", time.Now().UnixNano())

// Record request and batch size metrics
metrics.RecordBatchClassificationRequest(processingType)
metrics.RecordBatchSizeDistribution(processingType, batchSize)

// Defer recording processing time and text count
defer func() {
duration := time.Since(start).Seconds()
metrics.RecordBatchClassificationDuration(processingType, batchSize, duration)
metrics.RecordBatchClassificationTexts(processingType, batchSize)
}()

// Get max concurrency from config, default to 8
maxConcurrency := 8
if s.config != nil && s.config.API.BatchClassification.MaxConcurrency > 0 {
Expand All @@ -523,6 +574,18 @@ func (s *ClassificationAPIServer) processConcurrently(texts []string, options *C
wg.Add(1)
go func(index int, txt string) {
defer wg.Done()

// Record goroutine start (if detailed tracking is enabled)
metricsConfig := metrics.GetBatchMetricsConfig()
if metricsConfig.DetailedGoroutineTracking {
metrics.ConcurrentGoroutines.WithLabelValues(batchID).Inc()

defer func() {
// Record goroutine end
metrics.ConcurrentGoroutines.WithLabelValues(batchID).Dec()
}()
}

semaphore <- struct{}{}
defer func() { <-semaphore }()

Expand All @@ -532,6 +595,7 @@ func (s *ClassificationAPIServer) processConcurrently(texts []string, options *C
result, err := s.classifySingleText(txt, options)
if err != nil {
errors[index] = err
metrics.RecordBatchClassificationError(processingType, "classification_failed")
return
}
results[index] = result
Expand Down
20 changes: 14 additions & 6 deletions src/semantic-router/pkg/api/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -217,13 +217,17 @@ func TestBatchClassificationConfiguration(t *testing.T) {
config: &config.RouterConfig{
API: config.APIConfig{
BatchClassification: struct {
MaxBatchSize int `yaml:"max_batch_size,omitempty"`
ConcurrencyThreshold int `yaml:"concurrency_threshold,omitempty"`
MaxConcurrency int `yaml:"max_concurrency,omitempty"`
MaxBatchSize int `yaml:"max_batch_size,omitempty"`
ConcurrencyThreshold int `yaml:"concurrency_threshold,omitempty"`
MaxConcurrency int `yaml:"max_concurrency,omitempty"`
Metrics config.BatchClassificationMetricsConfig `yaml:"metrics,omitempty"`
}{
MaxBatchSize: 3, // Custom small limit
ConcurrencyThreshold: 2,
MaxConcurrency: 4,
Metrics: config.BatchClassificationMetricsConfig{
Enabled: true,
},
},
},
},
Expand Down Expand Up @@ -253,13 +257,17 @@ func TestBatchClassificationConfiguration(t *testing.T) {
config: &config.RouterConfig{
API: config.APIConfig{
BatchClassification: struct {
MaxBatchSize int `yaml:"max_batch_size,omitempty"`
ConcurrencyThreshold int `yaml:"concurrency_threshold,omitempty"`
MaxConcurrency int `yaml:"max_concurrency,omitempty"`
MaxBatchSize int `yaml:"max_batch_size,omitempty"`
ConcurrencyThreshold int `yaml:"concurrency_threshold,omitempty"`
MaxConcurrency int `yaml:"max_concurrency,omitempty"`
Metrics config.BatchClassificationMetricsConfig `yaml:"metrics,omitempty"`
}{
MaxBatchSize: 10,
ConcurrencyThreshold: 3,
MaxConcurrency: 2,
Metrics: config.BatchClassificationMetricsConfig{
Enabled: true,
},
},
},
},
Expand Down
33 changes: 33 additions & 0 deletions src/semantic-router/pkg/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,9 +94,42 @@ type APIConfig struct {

// Maximum number of concurrent goroutines for batch processing
MaxConcurrency int `yaml:"max_concurrency,omitempty"`

// Metrics configuration for batch classification monitoring
Metrics BatchClassificationMetricsConfig `yaml:"metrics,omitempty"`
} `yaml:"batch_classification"`
}

// BatchClassificationMetricsConfig represents configuration for batch classification metrics.
// Values are loaded from the `metrics` section under `batch_classification` in the YAML config
// and copied into a metrics.BatchMetricsConfig at API-server startup.
type BatchClassificationMetricsConfig struct {
	// Sample rate for metrics collection (0.0-1.0, 1.0 means collect all metrics)
	SampleRate float64 `yaml:"sample_rate,omitempty"`

	// Batch size range labels for metrics (optional - uses sensible defaults if not specified)
	// Default ranges: "1", "2-5", "6-10", "11-20", "21-50", "50+"
	BatchSizeRanges []BatchSizeRangeConfig `yaml:"batch_size_ranges,omitempty"`

	// Histogram buckets for metrics (directly configured).
	// DurationBuckets are in seconds; SizeBuckets are batch-size counts.
	DurationBuckets []float64 `yaml:"duration_buckets,omitempty"`
	SizeBuckets     []float64 `yaml:"size_buckets,omitempty"`

	// Enable metrics collection for batch classification; when false, the
	// metrics configuration is not applied at server startup.
	Enabled bool `yaml:"enabled,omitempty"`

	// Enable detailed goroutine tracking (may impact performance)
	DetailedGoroutineTracking bool `yaml:"detailed_goroutine_tracking,omitempty"`

	// Enable high-resolution timing (nanosecond precision)
	HighResolutionTiming bool `yaml:"high_resolution_timing,omitempty"`
}

// BatchSizeRangeConfig defines a batch size range with its boundaries and label.
// Ranges are used to group batch sizes into labeled buckets for metrics
// (e.g. "2-5", "50+"); boundary inclusivity is determined by the consumer — TODO confirm.
type BatchSizeRangeConfig struct {
	Min   int    `yaml:"min"`   // lower bound of the range
	Max   int    `yaml:"max"`   // upper bound; -1 means no upper limit
	Label string `yaml:"label"` // label emitted on metrics for batch sizes in this range
}

// PromptGuardConfig represents configuration for the prompt guard jailbreak detection
type PromptGuardConfig struct {
// Enable prompt guard jailbreak detection
Expand Down
Loading
Loading