Skip to content

Commit d5e3229

Browse files
cryo-zd and rootfs authored
perf: enable concurrent classification via Arc+clone (#127)
* perf: Allow Concurrent Text Classification
  Signed-off-by: cryo <[email protected]>
* chore: add -race option for candle-binding test
  Signed-off-by: cryo <[email protected]>
* fix: release mutex early
  Signed-off-by: cryo <[email protected]>
---------
Signed-off-by: cryo <[email protected]>
Co-authored-by: Huamin Chen <[email protected]>
1 parent d6e2e77 commit d5e3229

File tree

3 files changed

+139
-37
lines changed

3 files changed

+139
-37
lines changed

Makefile

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -141,7 +141,7 @@ check-go-mod-tidy:
141141
test-binding: rust
142142
@echo "Running Go tests with static library..."
143143
@export LD_LIBRARY_PATH=${PWD}/candle-binding/target/release && \
144-
cd candle-binding && CGO_ENABLED=1 go test -v
144+
cd candle-binding && CGO_ENABLED=1 go test -v -race
145145

146146
# Test with the candle-binding library
147147
test-category-classifier: rust

candle-binding/semantic-router_test.go

Lines changed: 96 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,10 +1,13 @@
11
package candle_binding
22

33
import (
4+
"context"
5+
"fmt"
46
"math"
57
"runtime"
68
"strings"
79
"sync"
10+
"sync/atomic"
811
"testing"
912
"time"
1013
)
@@ -476,6 +479,99 @@ func TestModernBERTClassifiers(t *testing.T) {
476479
})
477480
}
478481

482+
func TestModernBertClassifier_ConcurrentClassificationSafety(t *testing.T) {
483+
// init
484+
if err := InitModernBertClassifier(CategoryClassifierModelPath, true); err != nil {
485+
t.Skipf("ModernBERT classifier not available: %v", err)
486+
}
487+
488+
texts := []string{
489+
"This is a test sentence for classification",
490+
"Another example text to classify with ModernBERT",
491+
"The quick brown fox jumps over the lazy dog",
492+
"Machine learning models are becoming more efficient",
493+
"Natural language processing is a fascinating field",
494+
}
495+
496+
// Baseline (single-threaded)
497+
baseline := make(map[string]ClassResult, len(texts))
498+
for _, txt := range texts {
499+
res, err := ClassifyModernBertText(txt)
500+
if err != nil {
501+
t.Fatalf("baseline call failed for %q: %v", txt, err)
502+
}
503+
baseline[txt] = res
504+
}
505+
506+
const numGoroutines = 10
507+
const iterationsPerGoroutine = 5
508+
509+
var wg sync.WaitGroup
510+
errCh := make(chan error, numGoroutines*iterationsPerGoroutine)
511+
var total int64
512+
513+
ctx, cancel := context.WithCancel(context.Background())
514+
defer cancel()
515+
516+
for g := range numGoroutines {
517+
wg.Add(1)
518+
go func(id int) {
519+
defer wg.Done()
520+
for i := range iterationsPerGoroutine {
521+
select {
522+
case <-ctx.Done():
523+
return
524+
default:
525+
}
526+
527+
txt := texts[(id+i)%len(texts)]
528+
res, err := ClassifyModernBertText(txt)
529+
if err != nil {
530+
errCh <- fmt.Errorf("gor %d iter %d classify error: %v", id, i, err)
531+
cancel() // stop early
532+
return
533+
}
534+
535+
// Strict: class must match baseline
536+
base := baseline[txt]
537+
if res.Class != base.Class {
538+
errCh <- fmt.Errorf("gor %d iter %d: class mismatch for %q: got %d expected %d", id, i, txt, res.Class, base.Class)
539+
cancel()
540+
return
541+
}
542+
543+
// Allow small FP differences
544+
if math.Abs(float64(res.Confidence)-float64(base.Confidence)) > 0.05 {
545+
errCh <- fmt.Errorf("gor %d iter %d: confidence mismatch for %q: got %f expected %f", id, i, txt, res.Confidence, base.Confidence)
546+
cancel()
547+
return
548+
}
549+
550+
atomic.AddInt64(&total, 1)
551+
}
552+
}(g)
553+
}
554+
555+
wg.Wait()
556+
close(errCh)
557+
558+
errs := 0
559+
for e := range errCh {
560+
t.Error(e)
561+
errs++
562+
}
563+
if errs > 0 {
564+
t.Fatalf("concurrency test failed with %d errors", errs)
565+
}
566+
567+
expected := int64(numGoroutines * iterationsPerGoroutine)
568+
if total != expected {
569+
t.Fatalf("expected %d successful results, got %d", expected, total)
570+
}
571+
572+
t.Logf("concurrent test OK: goroutines=%d iterations=%d", numGoroutines, iterationsPerGoroutine)
573+
}
574+
479575
// TestModernBERTPIITokenClassification tests the PII token classification functionality
480576
func TestModernBERTPIITokenClassification(t *testing.T) {
481577
// Test data with various PII entities

candle-binding/src/modernbert.rs

Lines changed: 42 additions & 36 deletions
Original file line number | Diff line number | Diff line change
@@ -268,7 +268,7 @@ pub struct ModernBertClassifier {
268268
}
269269

270270
lazy_static::lazy_static! {
271-
static ref MODERNBERT_CLASSIFIER: Arc<Mutex<Option<ModernBertClassifier>>> = Arc::new(Mutex::new(None));
271+
static ref MODERNBERT_CLASSIFIER: Arc<Mutex<Option<Arc<ModernBertClassifier>>>> = Arc::new(Mutex::new(None));
272272
static ref MODERNBERT_PII_CLASSIFIER: Arc<Mutex<Option<ModernBertClassifier>>> = Arc::new(Mutex::new(None));
273273
static ref MODERNBERT_JAILBREAK_CLASSIFIER: Arc<Mutex<Option<ModernBertClassifier>>> = Arc::new(Mutex::new(None));
274274
}
@@ -830,7 +830,7 @@ pub extern "C" fn init_modernbert_classifier(model_id: *const c_char, use_cpu: b
830830
match ModernBertClassifier::new(model_id, use_cpu) {
831831
Ok(classifier) => {
832832
let mut bert_opt = MODERNBERT_CLASSIFIER.lock().unwrap();
833-
*bert_opt = Some(classifier);
833+
*bert_opt = Some(Arc::new(classifier));
834834
true
835835
}
836836
Err(e) => {
@@ -930,20 +930,23 @@ pub extern "C" fn classify_modernbert_text(text: *const c_char) -> ModernBertCla
930930
}
931931
};
932932

933-
let bert_opt = MODERNBERT_CLASSIFIER.lock().unwrap();
934-
match &*bert_opt {
935-
Some(classifier) => match classifier.classify_text(text) {
936-
Ok((class_idx, confidence)) => ModernBertClassificationResult {
937-
class: class_idx as i32,
938-
confidence,
939-
},
940-
Err(e) => {
941-
eprintln!("Error classifying text with ModernBERT: {e}");
942-
default_result
943-
}
944-
},
945-
None => {
933+
let classifier_arc = {
934+
let guard = MODERNBERT_CLASSIFIER.lock().unwrap();
935+
if let Some(arc) = guard.as_ref() {
936+
Arc::clone(arc)
937+
} else {
946938
eprintln!("ModernBERT classifier not initialized");
939+
return default_result;
940+
}
941+
};
942+
943+
match classifier_arc.classify_text(text) {
944+
Ok((class_idx, confidence)) => ModernBertClassificationResult {
945+
class: class_idx as i32,
946+
confidence,
947+
},
948+
Err(e) => {
949+
eprintln!("Error classifying text with ModernBERT: {e}");
947950
default_result
948951
}
949952
}
@@ -968,28 +971,31 @@ pub extern "C" fn classify_modernbert_text_with_probabilities(
968971
}
969972
};
970973

971-
let bert_opt = MODERNBERT_CLASSIFIER.lock().unwrap();
972-
match &*bert_opt {
973-
Some(classifier) => match classifier.classify_text_with_probs(text) {
974-
Ok((class_idx, confidence, probabilities)) => {
975-
// Allocate memory for probabilities array
976-
let prob_len = probabilities.len();
977-
let prob_ptr = Box::into_raw(probabilities.into_boxed_slice()) as *mut f32;
978-
979-
ModernBertClassificationResultWithProbs {
980-
class: class_idx as i32,
981-
confidence,
982-
probabilities: prob_ptr,
983-
num_classes: prob_len as i32,
984-
}
985-
}
986-
Err(e) => {
987-
eprintln!("Error classifying text with probabilities using ModernBERT: {e}");
988-
default_result
989-
}
990-
},
991-
None => {
974+
let classifier_arc = {
975+
let guard = MODERNBERT_CLASSIFIER.lock().unwrap();
976+
if let Some(arc) = guard.as_ref() {
977+
Arc::clone(arc)
978+
} else {
992979
eprintln!("ModernBERT classifier not initialized");
980+
return default_result;
981+
}
982+
};
983+
984+
match classifier_arc.classify_text_with_probs(text) {
985+
Ok((class_idx, confidence, probabilities)) => {
986+
// Allocate memory for probabilities array
987+
let prob_len = probabilities.len();
988+
let prob_ptr = Box::into_raw(probabilities.into_boxed_slice()) as *mut f32;
989+
990+
ModernBertClassificationResultWithProbs {
991+
class: class_idx as i32,
992+
confidence,
993+
probabilities: prob_ptr,
994+
num_classes: prob_len as i32,
995+
}
996+
}
997+
Err(e) => {
998+
eprintln!("Error classifying text with probabilities using ModernBERT: {e}");
993999
default_result
9941000
}
9951001
}

0 commit comments

Comments
 (0)