diff --git a/Makefile b/Makefile index 026ed8c1..da64fca3 100644 --- a/Makefile +++ b/Makefile @@ -141,7 +141,7 @@ check-go-mod-tidy: test-binding: rust @echo "Running Go tests with static library..." @export LD_LIBRARY_PATH=${PWD}/candle-binding/target/release && \ - cd candle-binding && CGO_ENABLED=1 go test -v + cd candle-binding && CGO_ENABLED=1 go test -v -race # Test with the candle-binding library test-category-classifier: rust diff --git a/candle-binding/semantic-router_test.go b/candle-binding/semantic-router_test.go index 845aa4a3..f911769a 100644 --- a/candle-binding/semantic-router_test.go +++ b/candle-binding/semantic-router_test.go @@ -1,10 +1,13 @@ package candle_binding import ( + "context" + "fmt" "math" "runtime" "strings" "sync" + "sync/atomic" "testing" "time" ) @@ -476,6 +479,99 @@ func TestModernBERTClassifiers(t *testing.T) { }) } +func TestModernBertClassifier_ConcurrentClassificationSafety(t *testing.T) { + // init + if err := InitModernBertClassifier(CategoryClassifierModelPath, true); err != nil { + t.Skipf("ModernBERT classifier not available: %v", err) + } + + texts := []string{ + "This is a test sentence for classification", + "Another example text to classify with ModernBERT", + "The quick brown fox jumps over the lazy dog", + "Machine learning models are becoming more efficient", + "Natural language processing is a fascinating field", + } + + // Baseline (single-threaded) + baseline := make(map[string]ClassResult, len(texts)) + for _, txt := range texts { + res, err := ClassifyModernBertText(txt) + if err != nil { + t.Fatalf("baseline call failed for %q: %v", txt, err) + } + baseline[txt] = res + } + + const numGoroutines = 10 + const iterationsPerGoroutine = 5 + + var wg sync.WaitGroup + errCh := make(chan error, numGoroutines*iterationsPerGoroutine) + var total int64 + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + for g := range numGoroutines { + wg.Add(1) + go func(id int) { + defer wg.Done() + for i := range iterationsPerGoroutine { + select { + case <-ctx.Done(): + return + default: + } + + txt := texts[(id+i)%len(texts)] + res, err := ClassifyModernBertText(txt) + if err != nil { + errCh <- fmt.Errorf("gor %d iter %d classify error: %v", id, i, err) + cancel() // stop early + return + } + + // Strict: class must match baseline + base := baseline[txt] + if res.Class != base.Class { + errCh <- fmt.Errorf("gor %d iter %d: class mismatch for %q: got %d expected %d", id, i, txt, res.Class, base.Class) + cancel() + return + } + + // Allow small FP differences + if math.Abs(float64(res.Confidence)-float64(base.Confidence)) > 0.05 { + errCh <- fmt.Errorf("gor %d iter %d: confidence mismatch for %q: got %f expected %f", id, i, txt, res.Confidence, base.Confidence) + cancel() + return + } + + atomic.AddInt64(&total, 1) + } + }(g) + } + + wg.Wait() + close(errCh) + + errs := 0 + for e := range errCh { + t.Error(e) + errs++ + } + if errs > 0 { + t.Fatalf("concurrency test failed with %d errors", errs) + } + + expected := int64(numGoroutines * iterationsPerGoroutine) + if total != expected { + t.Fatalf("expected %d successful results, got %d", expected, total) + } + + t.Logf("concurrent test OK: goroutines=%d iterations=%d", numGoroutines, iterationsPerGoroutine) +} + // TestModernBERTPIITokenClassification tests the PII token classification functionality func TestModernBERTPIITokenClassification(t *testing.T) { // Test data with various PII entities diff --git a/candle-binding/src/modernbert.rs b/candle-binding/src/modernbert.rs index c81277d7..16120717 100644 --- a/candle-binding/src/modernbert.rs +++ b/candle-binding/src/modernbert.rs @@ -268,7 +268,7 @@ pub struct ModernBertClassifier { } lazy_static::lazy_static! { - static ref MODERNBERT_CLASSIFIER: Arc>> = Arc::new(Mutex::new(None)); + static ref MODERNBERT_CLASSIFIER: Arc>>> = Arc::new(Mutex::new(None)); static ref MODERNBERT_PII_CLASSIFIER: Arc>> = Arc::new(Mutex::new(None)); static ref MODERNBERT_JAILBREAK_CLASSIFIER: Arc>> = Arc::new(Mutex::new(None)); } @@ -830,7 +830,7 @@ pub extern "C" fn init_modernbert_classifier(model_id: *const c_char, use_cpu: b match ModernBertClassifier::new(model_id, use_cpu) { Ok(classifier) => { let mut bert_opt = MODERNBERT_CLASSIFIER.lock().unwrap(); - *bert_opt = Some(classifier); + *bert_opt = Some(Arc::new(classifier)); true } Err(e) => { @@ -930,20 +930,23 @@ pub extern "C" fn classify_modernbert_text(text: *const c_char) -> ModernBertCla } }; - let bert_opt = MODERNBERT_CLASSIFIER.lock().unwrap(); - match &*bert_opt { - Some(classifier) => match classifier.classify_text(text) { - Ok((class_idx, confidence)) => ModernBertClassificationResult { - class: class_idx as i32, - confidence, - }, - Err(e) => { - eprintln!("Error classifying text with ModernBERT: {e}"); - default_result - } - }, - None => { + let classifier_arc = { + let guard = MODERNBERT_CLASSIFIER.lock().unwrap(); + if let Some(arc) = guard.as_ref() { + Arc::clone(arc) + } else { eprintln!("ModernBERT classifier not initialized"); + return default_result; + } + }; + + match classifier_arc.classify_text(text) { + Ok((class_idx, confidence)) => ModernBertClassificationResult { + class: class_idx as i32, + confidence, + }, + Err(e) => { + eprintln!("Error classifying text with ModernBERT: {e}"); default_result } } @@ -968,28 +971,31 @@ pub extern "C" fn classify_modernbert_text_with_probabilities( } }; - let bert_opt = MODERNBERT_CLASSIFIER.lock().unwrap(); - match &*bert_opt { - Some(classifier) => match classifier.classify_text_with_probs(text) { - Ok((class_idx, confidence, probabilities)) => { - // Allocate memory for probabilities array - let prob_len = probabilities.len(); - let prob_ptr = Box::into_raw(probabilities.into_boxed_slice()) as *mut f32; - - ModernBertClassificationResultWithProbs { - class: class_idx as i32, - confidence, - probabilities: prob_ptr, - num_classes: prob_len as i32, - } - } - Err(e) => { - eprintln!("Error classifying text with probabilities using ModernBERT: {e}"); - default_result - } - }, - None => { + let classifier_arc = { + let guard = MODERNBERT_CLASSIFIER.lock().unwrap(); + if let Some(arc) = guard.as_ref() { + Arc::clone(arc) + } else { eprintln!("ModernBERT classifier not initialized"); + return default_result; + } + }; + + match classifier_arc.classify_text_with_probs(text) { + Ok((class_idx, confidence, probabilities)) => { + // Allocate memory for probabilities array + let prob_len = probabilities.len(); + let prob_ptr = Box::into_raw(probabilities.into_boxed_slice()) as *mut f32; + + ModernBertClassificationResultWithProbs { + class: class_idx as i32, + confidence, + probabilities: prob_ptr, + num_classes: prob_len as i32, + } + } + Err(e) => { + eprintln!("Error classifying text with probabilities using ModernBERT: {e}"); default_result } }