Skip to content

Commit 35f0d70

Browse files
Authored by cryo
Refactor: use worker pool for batch classification concurrency (#115)
Signed-off-by: cryo <[email protected]>
1 parent 352e1a5 commit 35f0d70

File tree

1 file changed

+29
-21
lines changed

1 file changed

+29
-21
lines changed

src/semantic-router/pkg/api/server.go

Lines changed: 29 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -563,45 +563,53 @@ func (s *ClassificationAPIServer) processConcurrently(texts []string, options *C
563563
if s.config != nil && s.config.API.BatchClassification.MaxConcurrency > 0 {
564564
maxConcurrency = s.config.API.BatchClassification.MaxConcurrency
565565
}
566+
// Get the actual number of workers to start
567+
numWorkers := min(len(texts), maxConcurrency)
566568

567569
results := make([]services.Classification, len(texts))
568570
errors := make([]error, len(texts))
569571

570-
semaphore := make(chan struct{}, maxConcurrency)
572+
// Create a channel for tasks
573+
taskChan := make(chan int, batchSize)
571574
var wg sync.WaitGroup
572575

573-
for i, text := range texts {
576+
// Start a fixed number of worker goroutines
577+
for i := range numWorkers {
574578
wg.Add(1)
575-
go func(index int, txt string) {
579+
go func(workerID int) {
576580
defer wg.Done()
577581

578582
// Record goroutine start (if detailed tracking is enabled)
579583
metricsConfig := metrics.GetBatchMetricsConfig()
580584
if metricsConfig.DetailedGoroutineTracking {
581585
metrics.ConcurrentGoroutines.WithLabelValues(batchID).Inc()
582-
583-
defer func() {
584-
// Record goroutine end
585-
metrics.ConcurrentGoroutines.WithLabelValues(batchID).Dec()
586-
}()
586+
// Record goroutine end
587+
defer metrics.ConcurrentGoroutines.WithLabelValues(batchID).Dec()
587588
}
588589

589-
semaphore <- struct{}{}
590-
defer func() { <-semaphore }()
591-
592-
// TODO: Refactor candle-binding to support batch mode for better performance
593-
// This would allow processing multiple texts in a single model inference call
594-
// instead of individual calls, significantly improving throughput
595-
result, err := s.classifySingleText(txt, options)
596-
if err != nil {
597-
errors[index] = err
598-
metrics.RecordBatchClassificationError(processingType, "classification_failed")
599-
return
590+
// Worker goroutine loops to process tasks from the channel
591+
for taskIndex := range taskChan {
592+
// TODO: Refactor candle-binding to support batch mode for better performance
593+
// This would allow processing multiple texts in a single model inference call
594+
// instead of individual calls, significantly improving throughput
595+
result, err := s.classifySingleText(texts[taskIndex], options)
596+
if err != nil {
597+
errors[taskIndex] = err
598+
metrics.RecordBatchClassificationError(processingType, "classification_failed")
599+
continue
600+
}
601+
results[taskIndex] = result
600602
}
601-
results[index] = result
602-
}(i, text)
603+
}(i)
604+
}
605+
606+
// Send tasks to the channel
607+
for i := range texts {
608+
taskChan <- i
603609
}
610+
close(taskChan)
604611

612+
// Wait for all workers to finish processing
605613
wg.Wait()
606614

607615
// Check for errors

0 commit comments

Comments
 (0)