Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions candle-binding/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ tracing = "0.1.37"
libc = "0.2.147"
lazy_static = "1.4.0"
rand = "0.8.5"
# Performance optimization: parallel processing and lock-free initialization
rayon = "1.8"
once_cell = "1.19"

[dev-dependencies]
rstest = "0.18"
Expand Down
13 changes: 10 additions & 3 deletions candle-binding/src/classifiers/lora/intent_lora.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,11 +113,18 @@ impl IntentLoRAClassifier {
})
}

/// Parallel classification for multiple texts
/// Parallel classification for multiple texts using rayon
///
/// # Performance
/// - Uses rayon for parallel processing across available CPU cores
/// - Efficient for batch sizes > 10
/// - No lock contention during inference
pub fn parallel_classify(&self, texts: &[&str]) -> Result<Vec<IntentResult>> {
// Process each text using real model inference
use rayon::prelude::*;

// Process each text using real model inference in parallel
texts
.iter()
.par_iter()
.map(|text| self.classify_intent(text))
.collect()
}
Expand Down
2 changes: 2 additions & 0 deletions candle-binding/src/classifiers/lora/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ pub use security_lora::*;
#[cfg(test)]
pub mod intent_lora_test;
#[cfg(test)]
pub mod parallel_engine_test;
#[cfg(test)]
pub mod pii_lora_test;
#[cfg(test)]
pub mod security_lora_test;
Expand Down
116 changes: 24 additions & 92 deletions candle-binding/src/classifiers/lora/parallel_engine.rs
Original file line number Diff line number Diff line change
@@ -1,18 +1,17 @@
//! Parallel LoRA processing engine
//!
//! Enables parallel execution of Intent||PII||Security classification tasks
//! Using thread-based parallelism instead of async/await
//! Using rayon for efficient data parallelism

use crate::classifiers::lora::{
intent_lora::{IntentLoRAClassifier, IntentResult},
pii_lora::{PIILoRAClassifier, PIIResult},
security_lora::{SecurityLoRAClassifier, SecurityResult},
};
use crate::core::{concurrency_error, ModelErrorType, UnifiedError};
use crate::core::{ModelErrorType, UnifiedError};
use crate::model_error;
use candle_core::{Device, Result};
use std::sync::{Arc, Mutex};
use std::thread;
use std::sync::Arc;

/// Parallel LoRA processing engine
pub struct ParallelLoRAEngine {
Expand Down Expand Up @@ -77,97 +76,30 @@ impl ParallelLoRAEngine {
})
}

/// Parallel classification across all three tasks
/// Parallel classification across all three tasks using rayon
///
/// # Performance
/// - Uses rayon::join for parallel execution (no Arc<Mutex> overhead)
/// - Simplified code: ~70 lines reduced to ~20 lines
/// - No lock contention or synchronization overhead
pub fn parallel_classify(&self, texts: &[&str]) -> Result<ParallelResult> {
let texts_owned: Vec<String> = texts.iter().map(|s| s.to_string()).collect();

// Create shared results
let intent_results = Arc::new(Mutex::new(Vec::new()));
let pii_results = Arc::new(Mutex::new(Vec::new()));
let security_results = Arc::new(Mutex::new(Vec::new()));

let handles = vec![
self.spawn_intent_task(texts_owned.clone(), Arc::clone(&intent_results)),
self.spawn_pii_task(texts_owned.clone(), Arc::clone(&pii_results)),
self.spawn_security_task(texts_owned, Arc::clone(&security_results)),
];

// Wait for all threads to complete
for handle in handles {
handle.join().map_err(|_| {
let unified_err = concurrency_error(
"thread join",
"Failed to join parallel classification thread",
);
candle_core::Error::from(unified_err)
})?;
}
// Execute all three classifiers in parallel using rayon::join
// Each task runs independently without shared mutable state
let ((intent_results, pii_results), security_results) = rayon::join(
|| {
rayon::join(
|| self.intent_classifier.batch_classify(texts),
|| self.pii_classifier.batch_detect(texts),
)
},
|| self.security_classifier.batch_detect(texts),
);

// Propagate errors from any task
Ok(ParallelResult {
intent_results: Arc::try_unwrap(intent_results)
.unwrap()
.into_inner()
.unwrap(),
pii_results: Arc::try_unwrap(pii_results).unwrap().into_inner().unwrap(),
security_results: Arc::try_unwrap(security_results)
.unwrap()
.into_inner()
.unwrap(),
})
}

fn spawn_intent_task(
&self,
texts: Vec<String>,
results: Arc<Mutex<Vec<IntentResult>>>,
) -> thread::JoinHandle<()> {
let classifier = Arc::clone(&self.intent_classifier);
thread::spawn(move || {
let text_refs: Vec<&str> = texts.iter().map(|s| s.as_str()).collect();
match classifier.batch_classify(&text_refs) {
Ok(task_results) => {
let mut guard = results.lock().unwrap();
*guard = task_results;
}
Err(e) => {
eprintln!("Intent classification failed: {}", e);
}
}
})
}

fn spawn_pii_task(
&self,
texts: Vec<String>,
results: Arc<Mutex<Vec<PIIResult>>>,
) -> thread::JoinHandle<()> {
let classifier = Arc::clone(&self.pii_classifier);
thread::spawn(move || {
let text_refs: Vec<&str> = texts.iter().map(|s| s.as_str()).collect();
if let Ok(task_results) = classifier.batch_detect(&text_refs) {
let mut guard = results.lock().unwrap();
*guard = task_results;
}
})
}

fn spawn_security_task(
&self,
texts: Vec<String>,
results: Arc<Mutex<Vec<SecurityResult>>>,
) -> thread::JoinHandle<()> {
let classifier = Arc::clone(&self.security_classifier);
thread::spawn(move || {
let text_refs: Vec<&str> = texts.iter().map(|s| s.as_str()).collect();
match classifier.batch_detect(&text_refs) {
Ok(task_results) => {
let mut guard = results.lock().unwrap();
*guard = task_results;
}
Err(e) => {
eprintln!("Security classification failed: {}", e);
}
}
intent_results: intent_results?,
pii_results: pii_results?,
security_results: security_results?,
})
}
}
Expand Down
Loading
Loading