Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ check-go-mod-tidy:
test-binding: rust
@echo "Running Go tests with static library..."
@export LD_LIBRARY_PATH=${PWD}/candle-binding/target/release && \
cd candle-binding && CGO_ENABLED=1 go test -v
cd candle-binding && CGO_ENABLED=1 go test -v -race

# Test with the candle-binding library
test-category-classifier: rust
Expand Down
96 changes: 96 additions & 0 deletions candle-binding/semantic-router_test.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
package candle_binding

import (
"context"
"fmt"
"math"
"runtime"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
)
Expand Down Expand Up @@ -476,6 +479,99 @@ func TestModernBERTClassifiers(t *testing.T) {
})
}

// TestModernBertClassifier_ConcurrentClassificationSafety verifies that
// ClassifyModernBertText can be called from many goroutines at once and still
// produce the same results as a single-threaded run.
//
// It first records a baseline result for every input text on the calling
// goroutine, then starts numGoroutines workers that each classify
// iterationsPerGoroutine texts. Every concurrent result must report the same
// class as the baseline and a confidence within confidenceTolerance of it.
// The first failing worker cancels the shared context so the others stop
// early. The test is skipped when the classifier cannot be initialized.
func TestModernBertClassifier_ConcurrentClassificationSafety(t *testing.T) {
	if err := InitModernBertClassifier(CategoryClassifierModelPath, true); err != nil {
		t.Skipf("ModernBERT classifier not available: %v", err)
	}

	const (
		numGoroutines          = 10
		iterationsPerGoroutine = 5
		// Largest confidence drift tolerated between a concurrent run and the
		// baseline; allows for small floating-point differences.
		confidenceTolerance = 0.05
	)

	texts := []string{
		"This is a test sentence for classification",
		"Another example text to classify with ModernBERT",
		"The quick brown fox jumps over the lazy dog",
		"Machine learning models are becoming more efficient",
		"Natural language processing is a fascinating field",
	}

	// Baseline (single-threaded). The map is only read once the workers start,
	// so it needs no locking.
	baseline := make(map[string]ClassResult, len(texts))
	for _, txt := range texts {
		res, err := ClassifyModernBertText(txt)
		if err != nil {
			t.Fatalf("baseline call failed for %q: %v", txt, err)
		}
		baseline[txt] = res
	}

	// classifyAndCompare runs one classification and compares it against the
	// baseline, returning a descriptive error on any divergence.
	classifyAndCompare := func(id, i int, txt string) error {
		res, err := ClassifyModernBertText(txt)
		if err != nil {
			// %w keeps the underlying error available to errors.Is / errors.As.
			return fmt.Errorf("gor %d iter %d classify error: %w", id, i, err)
		}

		base := baseline[txt]
		// Strict: class must match baseline.
		if res.Class != base.Class {
			return fmt.Errorf("gor %d iter %d: class mismatch for %q: got %d expected %d", id, i, txt, res.Class, base.Class)
		}
		if math.Abs(float64(res.Confidence)-float64(base.Confidence)) > confidenceTolerance {
			return fmt.Errorf("gor %d iter %d: confidence mismatch for %q: got %f expected %f", id, i, txt, res.Confidence, base.Confidence)
		}
		return nil
	}

	var (
		wg    sync.WaitGroup
		total atomic.Int64 // number of successful concurrent classifications
	)
	// Each worker sends at most one error before returning, so one slot per
	// worker guarantees sends never block.
	errCh := make(chan error, numGoroutines)

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	for g := range numGoroutines {
		wg.Add(1)
		go func(id int) {
			defer wg.Done()
			for i := range iterationsPerGoroutine {
				// Another worker already failed; stop early.
				if ctx.Err() != nil {
					return
				}

				// Offset by the worker id so that workers hit different texts
				// at the same time.
				txt := texts[(id+i)%len(texts)]
				if err := classifyAndCompare(id, i, txt); err != nil {
					errCh <- err
					cancel()
					return
				}
				total.Add(1)
			}
		}(g)
	}

	wg.Wait()
	close(errCh)

	errs := 0
	for e := range errCh {
		t.Error(e)
		errs++
	}
	if errs > 0 {
		t.Fatalf("concurrency test failed with %d errors", errs)
	}

	expected := int64(numGoroutines * iterationsPerGoroutine)
	if got := total.Load(); got != expected {
		t.Fatalf("expected %d successful results, got %d", expected, got)
	}

	t.Logf("concurrent test OK: goroutines=%d iterations=%d", numGoroutines, iterationsPerGoroutine)
}

// TestModernBERTPIITokenClassification tests the PII token classification functionality
func TestModernBERTPIITokenClassification(t *testing.T) {
// Test data with various PII entities
Expand Down
78 changes: 42 additions & 36 deletions candle-binding/src/modernbert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ pub struct ModernBertClassifier {
}

lazy_static::lazy_static! {
static ref MODERNBERT_CLASSIFIER: Arc<Mutex<Option<ModernBertClassifier>>> = Arc::new(Mutex::new(None));
static ref MODERNBERT_CLASSIFIER: Arc<Mutex<Option<Arc<ModernBertClassifier>>>> = Arc::new(Mutex::new(None));
static ref MODERNBERT_PII_CLASSIFIER: Arc<Mutex<Option<ModernBertClassifier>>> = Arc::new(Mutex::new(None));
static ref MODERNBERT_JAILBREAK_CLASSIFIER: Arc<Mutex<Option<ModernBertClassifier>>> = Arc::new(Mutex::new(None));
}
Expand Down Expand Up @@ -830,7 +830,7 @@ pub extern "C" fn init_modernbert_classifier(model_id: *const c_char, use_cpu: b
match ModernBertClassifier::new(model_id, use_cpu) {
Ok(classifier) => {
let mut bert_opt = MODERNBERT_CLASSIFIER.lock().unwrap();
*bert_opt = Some(classifier);
*bert_opt = Some(Arc::new(classifier));
true
}
Err(e) => {
Expand Down Expand Up @@ -930,20 +930,23 @@ pub extern "C" fn classify_modernbert_text(text: *const c_char) -> ModernBertCla
}
};

let bert_opt = MODERNBERT_CLASSIFIER.lock().unwrap();
match &*bert_opt {
Some(classifier) => match classifier.classify_text(text) {
Ok((class_idx, confidence)) => ModernBertClassificationResult {
class: class_idx as i32,
confidence,
},
Err(e) => {
eprintln!("Error classifying text with ModernBERT: {e}");
default_result
}
},
None => {
let classifier_arc = {
let guard = MODERNBERT_CLASSIFIER.lock().unwrap();
if let Some(arc) = guard.as_ref() {
Arc::clone(arc)
} else {
eprintln!("ModernBERT classifier not initialized");
return default_result;
}
};

match classifier_arc.classify_text(text) {
Ok((class_idx, confidence)) => ModernBertClassificationResult {
class: class_idx as i32,
confidence,
},
Err(e) => {
eprintln!("Error classifying text with ModernBERT: {e}");
default_result
}
}
Expand All @@ -968,28 +971,31 @@ pub extern "C" fn classify_modernbert_text_with_probabilities(
}
};

let bert_opt = MODERNBERT_CLASSIFIER.lock().unwrap();
match &*bert_opt {
Some(classifier) => match classifier.classify_text_with_probs(text) {
Ok((class_idx, confidence, probabilities)) => {
// Allocate memory for probabilities array
let prob_len = probabilities.len();
let prob_ptr = Box::into_raw(probabilities.into_boxed_slice()) as *mut f32;

ModernBertClassificationResultWithProbs {
class: class_idx as i32,
confidence,
probabilities: prob_ptr,
num_classes: prob_len as i32,
}
}
Err(e) => {
eprintln!("Error classifying text with probabilities using ModernBERT: {e}");
default_result
}
},
None => {
let classifier_arc = {
let guard = MODERNBERT_CLASSIFIER.lock().unwrap();
if let Some(arc) = guard.as_ref() {
Arc::clone(arc)
} else {
eprintln!("ModernBERT classifier not initialized");
return default_result;
}
};

match classifier_arc.classify_text_with_probs(text) {
Ok((class_idx, confidence, probabilities)) => {
// Allocate memory for probabilities array
let prob_len = probabilities.len();
let prob_ptr = Box::into_raw(probabilities.into_boxed_slice()) as *mut f32;

ModernBertClassificationResultWithProbs {
class: class_idx as i32,
confidence,
probabilities: prob_ptr,
num_classes: prob_len as i32,
}
}
Err(e) => {
eprintln!("Error classifying text with probabilities using ModernBERT: {e}");
default_result
}
}
Expand Down
Loading