|
1 | 1 | //! Parallel LoRA processing engine |
2 | 2 | //! |
3 | 3 | //! Enables parallel execution of Intent||PII||Security classification tasks |
4 | | -//! Using thread-based parallelism instead of async/await |
| 4 | +//! Using rayon for efficient data parallelism |
5 | 5 |
|
6 | 6 | use crate::classifiers::lora::{ |
7 | 7 | intent_lora::{IntentLoRAClassifier, IntentResult}, |
8 | 8 | pii_lora::{PIILoRAClassifier, PIIResult}, |
9 | 9 | security_lora::{SecurityLoRAClassifier, SecurityResult}, |
10 | 10 | }; |
11 | | -use crate::core::{concurrency_error, ModelErrorType, UnifiedError}; |
| 11 | +use crate::core::{ModelErrorType, UnifiedError}; |
12 | 12 | use crate::model_error; |
13 | 13 | use candle_core::{Device, Result}; |
14 | | -use std::sync::{Arc, Mutex}; |
15 | | -use std::thread; |
| 14 | +use std::sync::Arc; |
16 | 15 |
|
17 | 16 | /// Parallel LoRA processing engine |
18 | 17 | pub struct ParallelLoRAEngine { |
@@ -77,97 +76,30 @@ impl ParallelLoRAEngine { |
77 | 76 | }) |
78 | 77 | } |
79 | 78 |
|
80 | | - /// Parallel classification across all three tasks |
| 79 | + /// Parallel classification across all three tasks using rayon |
| 80 | + /// |
| 81 | + /// # Performance |
| 82 | + /// - Uses rayon::join for parallel execution (no Arc<Mutex> overhead) |
| 83 | + /// - Simplified code: ~70 lines reduced to ~20 lines |
| 84 | + /// - No lock contention or synchronization overhead |
81 | 85 | pub fn parallel_classify(&self, texts: &[&str]) -> Result<ParallelResult> { |
82 | | - let texts_owned: Vec<String> = texts.iter().map(|s| s.to_string()).collect(); |
83 | | - |
84 | | - // Create shared results |
85 | | - let intent_results = Arc::new(Mutex::new(Vec::new())); |
86 | | - let pii_results = Arc::new(Mutex::new(Vec::new())); |
87 | | - let security_results = Arc::new(Mutex::new(Vec::new())); |
88 | | - |
89 | | - let handles = vec![ |
90 | | - self.spawn_intent_task(texts_owned.clone(), Arc::clone(&intent_results)), |
91 | | - self.spawn_pii_task(texts_owned.clone(), Arc::clone(&pii_results)), |
92 | | - self.spawn_security_task(texts_owned, Arc::clone(&security_results)), |
93 | | - ]; |
94 | | - |
95 | | - // Wait for all threads to complete |
96 | | - for handle in handles { |
97 | | - handle.join().map_err(|_| { |
98 | | - let unified_err = concurrency_error( |
99 | | - "thread join", |
100 | | - "Failed to join parallel classification thread", |
101 | | - ); |
102 | | - candle_core::Error::from(unified_err) |
103 | | - })?; |
104 | | - } |
| 86 | + // Execute all three classifiers in parallel using rayon::join |
| 87 | + // Each task runs independently without shared mutable state |
| 88 | + let ((intent_results, pii_results), security_results) = rayon::join( |
| 89 | + || { |
| 90 | + rayon::join( |
| 91 | + || self.intent_classifier.batch_classify(texts), |
| 92 | + || self.pii_classifier.batch_detect(texts), |
| 93 | + ) |
| 94 | + }, |
| 95 | + || self.security_classifier.batch_detect(texts), |
| 96 | + ); |
105 | 97 |
|
| 98 | + // Propagate errors from any task |
106 | 99 | Ok(ParallelResult { |
107 | | - intent_results: Arc::try_unwrap(intent_results) |
108 | | - .unwrap() |
109 | | - .into_inner() |
110 | | - .unwrap(), |
111 | | - pii_results: Arc::try_unwrap(pii_results).unwrap().into_inner().unwrap(), |
112 | | - security_results: Arc::try_unwrap(security_results) |
113 | | - .unwrap() |
114 | | - .into_inner() |
115 | | - .unwrap(), |
116 | | - }) |
117 | | - } |
118 | | - |
119 | | - fn spawn_intent_task( |
120 | | - &self, |
121 | | - texts: Vec<String>, |
122 | | - results: Arc<Mutex<Vec<IntentResult>>>, |
123 | | - ) -> thread::JoinHandle<()> { |
124 | | - let classifier = Arc::clone(&self.intent_classifier); |
125 | | - thread::spawn(move || { |
126 | | - let text_refs: Vec<&str> = texts.iter().map(|s| s.as_str()).collect(); |
127 | | - match classifier.batch_classify(&text_refs) { |
128 | | - Ok(task_results) => { |
129 | | - let mut guard = results.lock().unwrap(); |
130 | | - *guard = task_results; |
131 | | - } |
132 | | - Err(e) => { |
133 | | - eprintln!("Intent classification failed: {}", e); |
134 | | - } |
135 | | - } |
136 | | - }) |
137 | | - } |
138 | | - |
139 | | - fn spawn_pii_task( |
140 | | - &self, |
141 | | - texts: Vec<String>, |
142 | | - results: Arc<Mutex<Vec<PIIResult>>>, |
143 | | - ) -> thread::JoinHandle<()> { |
144 | | - let classifier = Arc::clone(&self.pii_classifier); |
145 | | - thread::spawn(move || { |
146 | | - let text_refs: Vec<&str> = texts.iter().map(|s| s.as_str()).collect(); |
147 | | - if let Ok(task_results) = classifier.batch_detect(&text_refs) { |
148 | | - let mut guard = results.lock().unwrap(); |
149 | | - *guard = task_results; |
150 | | - } |
151 | | - }) |
152 | | - } |
153 | | - |
154 | | - fn spawn_security_task( |
155 | | - &self, |
156 | | - texts: Vec<String>, |
157 | | - results: Arc<Mutex<Vec<SecurityResult>>>, |
158 | | - ) -> thread::JoinHandle<()> { |
159 | | - let classifier = Arc::clone(&self.security_classifier); |
160 | | - thread::spawn(move || { |
161 | | - let text_refs: Vec<&str> = texts.iter().map(|s| s.as_str()).collect(); |
162 | | - match classifier.batch_detect(&text_refs) { |
163 | | - Ok(task_results) => { |
164 | | - let mut guard = results.lock().unwrap(); |
165 | | - *guard = task_results; |
166 | | - } |
167 | | - Err(e) => { |
168 | | - eprintln!("Security classification failed: {}", e); |
169 | | - } |
170 | | - } |
| 100 | + intent_results: intent_results?, |
| 101 | + pii_results: pii_results?, |
| 102 | + security_results: security_results?, |
171 | 103 | }) |
172 | 104 | } |
173 | 105 | } |
|
0 commit comments