Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ check-go-mod-tidy:
test-binding: rust
@echo "Running Go tests with static library..."
@export LD_LIBRARY_PATH=${PWD}/candle-binding/target/release && \
cd candle-binding && CGO_ENABLED=1 go test -v
cd candle-binding && CGO_ENABLED=1 go test -v -race

# Test with the candle-binding library
test-category-classifier: rust
Expand Down
96 changes: 96 additions & 0 deletions candle-binding/semantic-router_test.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
package candle_binding

import (
"context"
"fmt"
"math"
"runtime"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
)
Expand Down Expand Up @@ -406,6 +409,99 @@ func TestModernBERTClassifiers(t *testing.T) {
})
}

func TestModernBertClassifier_ConcurrentClassificationSafety(t *testing.T) {
// init
if err := InitModernBertClassifier(CategoryClassifierModelPath, true); err != nil {
t.Skipf("ModernBERT classifier not available: %v", err)
}

texts := []string{
"This is a test sentence for classification",
"Another example text to classify with ModernBERT",
"The quick brown fox jumps over the lazy dog",
"Machine learning models are becoming more efficient",
"Natural language processing is a fascinating field",
}

// Baseline (single-threaded)
baseline := make(map[string]ClassResult, len(texts))
for _, txt := range texts {
res, err := ClassifyModernBertText(txt)
if err != nil {
t.Fatalf("baseline call failed for %q: %v", txt, err)
}
baseline[txt] = res
}

const numGoroutines = 10
const iterationsPerGoroutine = 5

var wg sync.WaitGroup
errCh := make(chan error, numGoroutines*iterationsPerGoroutine)
var total int64

ctx, cancel := context.WithCancel(context.Background())
defer cancel()

for g := range numGoroutines {
wg.Add(1)
go func(id int) {
defer wg.Done()
for i := range iterationsPerGoroutine {
select {
case <-ctx.Done():
return
default:
}

txt := texts[(id+i)%len(texts)]
res, err := ClassifyModernBertText(txt)
if err != nil {
errCh <- fmt.Errorf("gor %d iter %d classify error: %v", id, i, err)
cancel() // stop early
return
}

// Strict: class must match baseline
base := baseline[txt]
if res.Class != base.Class {
errCh <- fmt.Errorf("gor %d iter %d: class mismatch for %q: got %d expected %d", id, i, txt, res.Class, base.Class)
cancel()
return
}

// Allow small FP differences
if math.Abs(float64(res.Confidence)-float64(base.Confidence)) > 0.05 {
errCh <- fmt.Errorf("gor %d iter %d: confidence mismatch for %q: got %f expected %f", id, i, txt, res.Confidence, base.Confidence)
cancel()
return
}

atomic.AddInt64(&total, 1)
}
}(g)
}

wg.Wait()
close(errCh)

errs := 0
for e := range errCh {
t.Error(e)
errs++
}
if errs > 0 {
t.Fatalf("concurrency test failed with %d errors", errs)
}

expected := int64(numGoroutines * iterationsPerGoroutine)
if total != expected {
t.Fatalf("expected %d successful results, got %d", expected, total)
}

t.Logf("concurrent test OK: goroutines=%d iterations=%d", numGoroutines, iterationsPerGoroutine)
}

// TestModernBERTPIITokenClassification tests the PII token classification functionality
func TestModernBERTPIITokenClassification(t *testing.T) {
// Test data with various PII entities
Expand Down
72 changes: 41 additions & 31 deletions candle-binding/src/modernbert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ pub struct ModernBertClassifier {
}

lazy_static::lazy_static! {
static ref MODERNBERT_CLASSIFIER: Arc<Mutex<Option<ModernBertClassifier>>> = Arc::new(Mutex::new(None));
static ref MODERNBERT_CLASSIFIER: Arc<Mutex<Option<Arc<ModernBertClassifier>>>> = Arc::new(Mutex::new(None));
static ref MODERNBERT_PII_CLASSIFIER: Arc<Mutex<Option<ModernBertClassifier>>> = Arc::new(Mutex::new(None));
static ref MODERNBERT_JAILBREAK_CLASSIFIER: Arc<Mutex<Option<ModernBertClassifier>>> = Arc::new(Mutex::new(None));
}
Expand Down Expand Up @@ -830,7 +830,7 @@ pub extern "C" fn init_modernbert_classifier(model_id: *const c_char, use_cpu: b
match ModernBertClassifier::new(model_id, use_cpu) {
Ok(classifier) => {
let mut bert_opt = MODERNBERT_CLASSIFIER.lock().unwrap();
*bert_opt = Some(classifier);
*bert_opt = Some(Arc::new(classifier));
true
}
Err(e) => {
Expand Down Expand Up @@ -930,18 +930,23 @@ pub extern "C" fn classify_modernbert_text(text: *const c_char) -> ModernBertCla
}
};

let bert_opt = MODERNBERT_CLASSIFIER.lock().unwrap();
match &*bert_opt {
Some(classifier) => match classifier.classify_text(text) {
Ok((class_idx, confidence)) => ModernBertClassificationResult {
class: class_idx as i32,
confidence,
},
Err(e) => {
eprintln!("Error classifying text with ModernBERT: {e}");
default_result
let classifier_arc_opt = MODERNBERT_CLASSIFIER.lock().unwrap();
match &*classifier_arc_opt {
Some(classifier_arc) => {
let classifier = classifier_arc.clone();
drop(classifier_arc_opt);

match classifier.classify_text(text) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this still under the lock

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Aha, let me try to fix it

Ok((class_idx, confidence)) => ModernBertClassificationResult {
class: class_idx as i32,
confidence,
},
Err(e) => {
eprintln!("Error classifying text with ModernBERT: {e}");
default_result
}
}
},
}
None => {
eprintln!("ModernBERT classifier not initialized");
default_result
Expand All @@ -968,26 +973,31 @@ pub extern "C" fn classify_modernbert_text_with_probabilities(
}
};

let bert_opt = MODERNBERT_CLASSIFIER.lock().unwrap();
match &*bert_opt {
Some(classifier) => match classifier.classify_text_with_probs(text) {
Ok((class_idx, confidence, probabilities)) => {
// Allocate memory for probabilities array
let prob_len = probabilities.len();
let prob_ptr = Box::into_raw(probabilities.into_boxed_slice()) as *mut f32;

ModernBertClassificationResultWithProbs {
class: class_idx as i32,
confidence,
probabilities: prob_ptr,
num_classes: prob_len as i32,
let classifier_arc_opt = MODERNBERT_CLASSIFIER.lock().unwrap();
match &*classifier_arc_opt {
Some(classifier_arc) => {
let classifier = classifier_arc.clone();
drop(classifier_arc_opt);

match classifier.classify_text_with_probs(text) {
Ok((class_idx, confidence, probabilities)) => {
// Allocate memory for probabilities array
let prob_len = probabilities.len();
let prob_ptr = Box::into_raw(probabilities.into_boxed_slice()) as *mut f32;

ModernBertClassificationResultWithProbs {
class: class_idx as i32,
confidence,
probabilities: prob_ptr,
num_classes: prob_len as i32,
}
}
Err(e) => {
eprintln!("Error classifying text with probabilities using ModernBERT: {e}");
default_result
}
}
Err(e) => {
eprintln!("Error classifying text with probabilities using ModernBERT: {e}");
default_result
}
},
}
None => {
eprintln!("ModernBERT classifier not initialized");
default_result
Expand Down
Loading