Skip to content

Commit 2570877

Browse files
committed
perf: Allow Concurrent Text Classification
Signed-off-by: cryo <[email protected]>
1 parent ea63386 commit 2570877

File tree

2 files changed

+137
-31
lines changed

2 files changed

+137
-31
lines changed

candle-binding/semantic-router_test.go

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11
package candle_binding
22

33
import (
4+
"context"
5+
"fmt"
46
"math"
57
"runtime"
68
"strings"
79
"sync"
10+
"sync/atomic"
811
"testing"
912
"time"
1013
)
@@ -406,6 +409,99 @@ func TestModernBERTClassifiers(t *testing.T) {
406409
})
407410
}
408411

412+
func TestModernBertClassifier_ConcurrentClassificationSafety(t *testing.T) {
413+
// init
414+
if err := InitModernBertClassifier(CategoryClassifierModelPath, true); err != nil {
415+
t.Skipf("ModernBERT classifier not available: %v", err)
416+
}
417+
418+
texts := []string{
419+
"This is a test sentence for classification",
420+
"Another example text to classify with ModernBERT",
421+
"The quick brown fox jumps over the lazy dog",
422+
"Machine learning models are becoming more efficient",
423+
"Natural language processing is a fascinating field",
424+
}
425+
426+
// Baseline (single-threaded)
427+
baseline := make(map[string]ClassResult, len(texts))
428+
for _, txt := range texts {
429+
res, err := ClassifyModernBertText(txt)
430+
if err != nil {
431+
t.Fatalf("baseline call failed for %q: %v", txt, err)
432+
}
433+
baseline[txt] = res
434+
}
435+
436+
const numGoroutines = 10
437+
const iterationsPerGoroutine = 5
438+
439+
var wg sync.WaitGroup
440+
errCh := make(chan error, numGoroutines*iterationsPerGoroutine)
441+
var total int64
442+
443+
ctx, cancel := context.WithCancel(context.Background())
444+
defer cancel()
445+
446+
for g := range numGoroutines {
447+
wg.Add(1)
448+
go func(id int) {
449+
defer wg.Done()
450+
for i := range iterationsPerGoroutine {
451+
select {
452+
case <-ctx.Done():
453+
return
454+
default:
455+
}
456+
457+
txt := texts[(id+i)%len(texts)]
458+
res, err := ClassifyModernBertText(txt)
459+
if err != nil {
460+
errCh <- fmt.Errorf("gor %d iter %d classify error: %v", id, i, err)
461+
cancel() // stop early
462+
return
463+
}
464+
465+
// Strict: class must match baseline
466+
base := baseline[txt]
467+
if res.Class != base.Class {
468+
errCh <- fmt.Errorf("gor %d iter %d: class mismatch for %q: got %d expected %d", id, i, txt, res.Class, base.Class)
469+
cancel()
470+
return
471+
}
472+
473+
// Allow small FP differences
474+
if math.Abs(float64(res.Confidence)-float64(base.Confidence)) > 0.05 {
475+
errCh <- fmt.Errorf("gor %d iter %d: confidence mismatch for %q: got %f expected %f", id, i, txt, res.Confidence, base.Confidence)
476+
cancel()
477+
return
478+
}
479+
480+
atomic.AddInt64(&total, 1)
481+
}
482+
}(g)
483+
}
484+
485+
wg.Wait()
486+
close(errCh)
487+
488+
errs := 0
489+
for e := range errCh {
490+
t.Error(e)
491+
errs++
492+
}
493+
if errs > 0 {
494+
t.Fatalf("concurrency test failed with %d errors", errs)
495+
}
496+
497+
expected := int64(numGoroutines * iterationsPerGoroutine)
498+
if total != expected {
499+
t.Fatalf("expected %d successful results, got %d", expected, total)
500+
}
501+
502+
t.Logf("concurrent test OK: goroutines=%d iterations=%d", numGoroutines, iterationsPerGoroutine)
503+
}
504+
409505
// TestModernBERTPIITokenClassification tests the PII token classification functionality
410506
func TestModernBERTPIITokenClassification(t *testing.T) {
411507
// Test data with various PII entities

candle-binding/src/modernbert.rs

Lines changed: 41 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -268,7 +268,7 @@ pub struct ModernBertClassifier {
268268
}
269269

270270
lazy_static::lazy_static! {
271-
static ref MODERNBERT_CLASSIFIER: Arc<Mutex<Option<ModernBertClassifier>>> = Arc::new(Mutex::new(None));
271+
static ref MODERNBERT_CLASSIFIER: Arc<Mutex<Option<Arc<ModernBertClassifier>>>> = Arc::new(Mutex::new(None));
272272
static ref MODERNBERT_PII_CLASSIFIER: Arc<Mutex<Option<ModernBertClassifier>>> = Arc::new(Mutex::new(None));
273273
static ref MODERNBERT_JAILBREAK_CLASSIFIER: Arc<Mutex<Option<ModernBertClassifier>>> = Arc::new(Mutex::new(None));
274274
}
@@ -830,7 +830,7 @@ pub extern "C" fn init_modernbert_classifier(model_id: *const c_char, use_cpu: b
830830
match ModernBertClassifier::new(model_id, use_cpu) {
831831
Ok(classifier) => {
832832
let mut bert_opt = MODERNBERT_CLASSIFIER.lock().unwrap();
833-
*bert_opt = Some(classifier);
833+
*bert_opt = Some(Arc::new(classifier));
834834
true
835835
}
836836
Err(e) => {
@@ -930,18 +930,23 @@ pub extern "C" fn classify_modernbert_text(text: *const c_char) -> ModernBertCla
930930
}
931931
};
932932

933-
let bert_opt = MODERNBERT_CLASSIFIER.lock().unwrap();
934-
match &*bert_opt {
935-
Some(classifier) => match classifier.classify_text(text) {
936-
Ok((class_idx, confidence)) => ModernBertClassificationResult {
937-
class: class_idx as i32,
938-
confidence,
939-
},
940-
Err(e) => {
941-
eprintln!("Error classifying text with ModernBERT: {e}");
942-
default_result
933+
let classifier_arc_opt = MODERNBERT_CLASSIFIER.lock().unwrap();
934+
match &*classifier_arc_opt {
935+
Some(classifier_arc) => {
936+
let classifier = classifier_arc.clone();
937+
drop(classifier_arc_opt);
938+
939+
match classifier.classify_text(text) {
940+
Ok((class_idx, confidence)) => ModernBertClassificationResult {
941+
class: class_idx as i32,
942+
confidence,
943+
},
944+
Err(e) => {
945+
eprintln!("Error classifying text with ModernBERT: {e}");
946+
default_result
947+
}
943948
}
944-
},
949+
}
945950
None => {
946951
eprintln!("ModernBERT classifier not initialized");
947952
default_result
@@ -968,26 +973,31 @@ pub extern "C" fn classify_modernbert_text_with_probabilities(
968973
}
969974
};
970975

971-
let bert_opt = MODERNBERT_CLASSIFIER.lock().unwrap();
972-
match &*bert_opt {
973-
Some(classifier) => match classifier.classify_text_with_probs(text) {
974-
Ok((class_idx, confidence, probabilities)) => {
975-
// Allocate memory for probabilities array
976-
let prob_len = probabilities.len();
977-
let prob_ptr = Box::into_raw(probabilities.into_boxed_slice()) as *mut f32;
978-
979-
ModernBertClassificationResultWithProbs {
980-
class: class_idx as i32,
981-
confidence,
982-
probabilities: prob_ptr,
983-
num_classes: prob_len as i32,
976+
let classifier_arc_opt = MODERNBERT_CLASSIFIER.lock().unwrap();
977+
match &*classifier_arc_opt {
978+
Some(classifier_arc) => {
979+
let classifier = classifier_arc.clone();
980+
drop(classifier_arc_opt);
981+
982+
match classifier.classify_text_with_probs(text) {
983+
Ok((class_idx, confidence, probabilities)) => {
984+
// Allocate memory for probabilities array
985+
let prob_len = probabilities.len();
986+
let prob_ptr = Box::into_raw(probabilities.into_boxed_slice()) as *mut f32;
987+
988+
ModernBertClassificationResultWithProbs {
989+
class: class_idx as i32,
990+
confidence,
991+
probabilities: prob_ptr,
992+
num_classes: prob_len as i32,
993+
}
994+
}
995+
Err(e) => {
996+
eprintln!("Error classifying text with probabilities using ModernBERT: {e}");
997+
default_result
984998
}
985999
}
986-
Err(e) => {
987-
eprintln!("Error classifying text with probabilities using ModernBERT: {e}");
988-
default_result
989-
}
990-
},
1000+
}
9911001
None => {
9921002
eprintln!("ModernBERT classifier not initialized");
9931003
default_result

0 commit comments

Comments
 (0)