Skip to content

Commit 2b72f27

Browse files
OneZero-Y authored and rootfs committed
fix: Fix duplicate UNIFIED_CLASSIFIER definition and optimize lock contention (#516)
- Remove duplicate UNIFIED_CLASSIFIER global state
- Optimize PARALLEL_LORA_ENGINE lock contention by using Arc clone

Signed-off-by: OneZero-Y <[email protected]>
1 parent a81c29c commit 2b72f27

File tree

2 files changed

+74
-75
lines changed

2 files changed

+74
-75
lines changed

candle-binding/src/ffi/classify.rs

Lines changed: 67 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,6 @@ use crate::ffi::memory::{
1111
allocate_pii_result_array, allocate_security_result_array,
1212
};
1313
use crate::ffi::types::*;
14-
use crate::BertClassifier;
15-
use lazy_static::lazy_static;
16-
use std::ffi::{c_char, CStr};
17-
use std::sync::{Arc, Mutex};
18-
19-
use crate::classifiers::unified::DualPathUnifiedClassifier;
2014
use crate::model_architectures::traditional::bert::{
2115
TRADITIONAL_BERT_CLASSIFIER, TRADITIONAL_BERT_TOKEN_CLASSIFIER,
2216
};
@@ -25,9 +19,13 @@ use crate::model_architectures::traditional::modernbert::{
2519
TRADITIONAL_MODERNBERT_PII_CLASSIFIER, TRADITIONAL_MODERNBERT_TOKEN_CLASSIFIER,
2620
};
2721
use crate::model_architectures::traits::TaskType;
22+
use crate::BertClassifier;
23+
use lazy_static::lazy_static;
24+
use std::ffi::{c_char, CStr};
25+
use std::sync::{Arc, Mutex};
2826
extern crate lazy_static;
2927

30-
use crate::ffi::init::PARALLEL_LORA_ENGINE;
28+
use crate::ffi::init::{PARALLEL_LORA_ENGINE, UNIFIED_CLASSIFIER};
3129

3230
/// Load id2label mapping from model config.json file
3331
/// Returns HashMap mapping class index (as string) to label name
@@ -42,8 +40,9 @@ pub fn load_id2label_from_config(
4240

4341
// Global state for classification using dual-path architecture
4442
lazy_static! {
45-
static ref UNIFIED_CLASSIFIER: Arc<Mutex<Option<DualPathUnifiedClassifier>>> = Arc::new(Mutex::new(None));
46-
// Legacy classifiers for backward compatibility
43+
// NOTE: UNIFIED_CLASSIFIER is defined in ffi/init.rs and re-exported
44+
// We import it here to avoid duplicate definitions
45+
// Legacy classifiers for backward compatibility (still needed for old API paths)
4746
static ref BERT_CLASSIFIER: Arc<Mutex<Option<BertClassifier>>> = Arc::new(Mutex::new(None));
4847
static ref BERT_PII_CLASSIFIER: Arc<Mutex<Option<BertClassifier>>> = Arc::new(Mutex::new(None));
4948
static ref BERT_JAILBREAK_CLASSIFIER: Arc<Mutex<Option<BertClassifier>>> = Arc::new(Mutex::new(None));
@@ -654,71 +653,70 @@ pub extern "C" fn classify_batch_with_lora(
654653
}
655654

656655
let start_time = std::time::Instant::now();
657-
let engine_guard = PARALLEL_LORA_ENGINE.lock().unwrap();
658-
match engine_guard.as_ref() {
659-
Some(engine) => {
660-
let text_refs: Vec<&str> = text_vec.iter().map(|s| s.as_ref()).collect();
661-
match engine.parallel_classify(&text_refs) {
662-
Ok(parallel_result) => {
663-
let _processing_time_ms = start_time.elapsed().as_millis() as f32;
664-
665-
// Allocate C arrays for LoRA results
666-
let intent_results_ptr =
667-
unsafe { allocate_lora_intent_array(&parallel_result.intent_results) };
668-
let pii_results_ptr =
669-
unsafe { allocate_lora_pii_array(&parallel_result.pii_results) };
670-
let security_results_ptr =
671-
unsafe { allocate_lora_security_array(&parallel_result.security_results) };
672656

673-
LoRABatchResult {
674-
intent_results: intent_results_ptr,
675-
pii_results: pii_results_ptr,
676-
security_results: security_results_ptr,
677-
batch_size: texts_count as i32,
678-
avg_confidence: {
679-
let mut total_confidence = 0.0f32;
680-
let mut count = 0;
681-
682-
// Sum intent confidences
683-
for intent in &parallel_result.intent_results {
684-
total_confidence += intent.confidence;
685-
count += 1;
686-
}
687-
688-
// Sum PII confidences
689-
for pii in &parallel_result.pii_results {
690-
total_confidence += pii.confidence;
691-
count += 1;
692-
}
693-
694-
// Sum security confidences
695-
for security in &parallel_result.security_results {
696-
total_confidence += security.confidence;
697-
count += 1;
698-
}
699-
700-
if count > 0 {
701-
total_confidence / count as f32
702-
} else {
703-
0.0
704-
}
705-
},
657+
// Optimization: Clone Arc to minimize lock holding time
658+
// Lock is only held during the clone operation (~nanoseconds), not during inference
659+
let engine: Arc<crate::classifiers::lora::parallel_engine::ParallelLoRAEngine> = {
660+
let engine_guard = PARALLEL_LORA_ENGINE.lock().unwrap();
661+
match engine_guard.as_ref() {
662+
Some(e) => e.clone(),
663+
None => {
664+
eprintln!("PARALLEL_LORA_ENGINE not initialized");
665+
return default_result;
666+
}
667+
}
668+
}; // Lock is released here immediately after clone
669+
670+
// Now perform inference without holding the lock (allows concurrent requests)
671+
let text_refs: Vec<&str> = text_vec.iter().map(|s| s.as_ref()).collect();
672+
match engine.parallel_classify(&text_refs) {
673+
Ok(parallel_result) => {
674+
let _processing_time_ms = start_time.elapsed().as_millis() as f32;
675+
676+
// Allocate C arrays for LoRA results
677+
let intent_results_ptr =
678+
unsafe { allocate_lora_intent_array(&parallel_result.intent_results) };
679+
let pii_results_ptr = unsafe { allocate_lora_pii_array(&parallel_result.pii_results) };
680+
let security_results_ptr =
681+
unsafe { allocate_lora_security_array(&parallel_result.security_results) };
682+
683+
LoRABatchResult {
684+
intent_results: intent_results_ptr,
685+
pii_results: pii_results_ptr,
686+
security_results: security_results_ptr,
687+
batch_size: texts_count as i32,
688+
avg_confidence: {
689+
let mut total_confidence = 0.0f32;
690+
let mut count = 0;
691+
692+
// Sum intent confidences
693+
for intent in &parallel_result.intent_results {
694+
total_confidence += intent.confidence;
695+
count += 1;
706696
}
707-
}
708-
Err(e) => {
709-
println!("LoRA parallel classification failed: {}", e);
710-
LoRABatchResult {
711-
intent_results: std::ptr::null_mut(),
712-
pii_results: std::ptr::null_mut(),
713-
security_results: std::ptr::null_mut(),
714-
batch_size: 0,
715-
avg_confidence: 0.0,
697+
698+
// Sum PII confidences
699+
for pii in &parallel_result.pii_results {
700+
total_confidence += pii.confidence;
701+
count += 1;
716702
}
717-
}
703+
704+
// Sum security confidences
705+
for security in &parallel_result.security_results {
706+
total_confidence += security.confidence;
707+
count += 1;
708+
}
709+
710+
if count > 0 {
711+
total_confidence / count as f32
712+
} else {
713+
0.0
714+
}
715+
},
718716
}
719717
}
720-
None => {
721-
println!("ParallelLoRAEngine not initialized - call init function first");
718+
Err(e) => {
719+
println!("LoRA parallel classification failed: {}", e);
722720
LoRABatchResult {
723721
intent_results: std::ptr::null_mut(),
724722
pii_results: std::ptr::null_mut(),

candle-binding/src/ffi/init.rs

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,11 @@ lazy_static! {
1717
static ref BERT_CLASSIFIER: Arc<Mutex<Option<BertClassifier>>> = Arc::new(Mutex::new(None));
1818
static ref BERT_PII_CLASSIFIER: Arc<Mutex<Option<BertClassifier>>> = Arc::new(Mutex::new(None));
1919
static ref BERT_JAILBREAK_CLASSIFIER: Arc<Mutex<Option<BertClassifier>>> = Arc::new(Mutex::new(None));
20-
// Unified classifier for dual-path architecture
21-
static ref UNIFIED_CLASSIFIER: Arc<Mutex<Option<crate::classifiers::unified::DualPathUnifiedClassifier>>> = Arc::new(Mutex::new(None));
22-
// Parallel LoRA engine for high-performance classification
23-
pub static ref PARALLEL_LORA_ENGINE: Arc<Mutex<Option<crate::classifiers::lora::parallel_engine::ParallelLoRAEngine>>> = Arc::new(Mutex::new(None));
20+
// Unified classifier for dual-path architecture (exported for use in classify.rs)
21+
pub static ref UNIFIED_CLASSIFIER: Arc<Mutex<Option<crate::classifiers::unified::DualPathUnifiedClassifier>>> = Arc::new(Mutex::new(None));
22+
// Parallel LoRA engine for high-performance classification (primary path for LoRA models)
23+
// Wrapped in Arc for cheap cloning and concurrent access
24+
pub static ref PARALLEL_LORA_ENGINE: Arc<Mutex<Option<Arc<crate::classifiers::lora::parallel_engine::ParallelLoRAEngine>>>> = Arc::new(Mutex::new(None));
2425
// LoRA token classifier for token-level classification
2526
pub static ref LORA_TOKEN_CLASSIFIER: Arc<Mutex<Option<crate::classifiers::lora::token_lora::LoRATokenClassifier>>> = Arc::new(Mutex::new(None));
2627
}
@@ -719,9 +720,9 @@ pub extern "C" fn init_lora_unified_classifier(
719720
use_cpu,
720721
) {
721722
Ok(engine) => {
722-
// Store in global static variable
723+
// Store in global static variable (wrapped in Arc for efficient cloning)
723724
let mut engine_guard = PARALLEL_LORA_ENGINE.lock().unwrap();
724-
*engine_guard = Some(engine);
725+
*engine_guard = Some(Arc::new(engine));
725726
true
726727
}
727728
Err(e) => {

0 commit comments

Comments
 (0)