Skip to content

Commit 3697826

Browse files
committed
refactor: Implement modular candle-binding architecture
- Restructure codebase into modular layers (core/, ffi/, model_architectures/, classifiers/)
- Add unified error handling and configuration loading systems
- Implement dual-path architecture for traditional and LoRA models
- Add comprehensive FFI layer with memory safety

Maintains backward compatibility while enabling future model integrations.

Signed-off-by: OneZero-Y <[email protected]>
1 parent e75fc0f commit 3697826

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

45 files changed

+14353
-5348
lines changed

candle-binding/src/bert_official.rs

Lines changed: 0 additions & 441 deletions
This file was deleted.
Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
//! Intent classification with LoRA adapters
2+
//!
3+
//! High-performance intent classification using real model inference
4+
5+
use crate::core::{processing_errors, ModelErrorType, UnifiedError};
6+
use crate::model_architectures::lora::bert_lora::HighPerformanceBertClassifier;
7+
use crate::model_error;
8+
use candle_core::Result;
9+
use std::time::Instant;
10+
11+
/// Intent classifier with real model inference (merged LoRA models)
///
/// Wraps a BERT classifier plus the label set loaded from the model's
/// config.json, so predicted class indices can be mapped back to
/// human-readable intent names.
pub struct IntentLoRAClassifier {
    /// High-performance BERT classifier for intent classification
    bert_classifier: HighPerformanceBertClassifier,
    /// Confidence threshold for predictions
    // NOTE(review): set to 0.7 in `new` but never consulted by the classify
    // paths visible in this file — confirm intended use or remove.
    confidence_threshold: f32,
    /// Intent labels mapping (index -> label, loaded from config.json)
    intent_labels: Vec<String>,
    /// Model path for reference
    model_path: String,
}
22+
23+
/// Intent classification result
#[derive(Debug, Clone)]
pub struct IntentResult {
    /// Predicted intent label; `UNKNOWN_<idx>` when the predicted class
    /// index falls outside the loaded label set
    pub intent: String,
    /// Model confidence for the predicted class
    pub confidence: f32,
    /// Wall-clock processing time in milliseconds (whole batch latency
    /// when produced by `batch_classify`)
    pub processing_time_ms: u64,
}
30+
31+
impl IntentLoRAClassifier {
32+
/// Create new intent classifier using real model inference
33+
pub fn new(model_path: &str, use_cpu: bool) -> Result<Self> {
34+
// Load labels from model config
35+
let intent_labels = Self::load_labels_from_config(model_path)?;
36+
let num_classes = intent_labels.len();
37+
38+
// Load the high-performance BERT classifier for merged LoRA models
39+
let classifier = HighPerformanceBertClassifier::new(model_path, num_classes, use_cpu)
40+
.map_err(|e| {
41+
let unified_err = model_error!(
42+
ModelErrorType::LoRA,
43+
"intent classifier creation",
44+
format!("Failed to create BERT classifier: {}", e),
45+
model_path
46+
);
47+
candle_core::Error::from(unified_err)
48+
})?;
49+
50+
Ok(Self {
51+
bert_classifier: classifier,
52+
confidence_threshold: 0.7,
53+
intent_labels,
54+
model_path: model_path.to_string(),
55+
})
56+
}
57+
58+
/// Load intent labels from model config.json using unified config loader
59+
fn load_labels_from_config(model_path: &str) -> Result<Vec<String>> {
60+
use crate::core::config_loader;
61+
62+
match config_loader::load_intent_labels(model_path) {
63+
Ok(result) => Ok(result),
64+
Err(unified_err) => Err(candle_core::Error::from(unified_err)),
65+
}
66+
}
67+
68+
/// Classify intent using real model inference
69+
pub fn classify_intent(&self, text: &str) -> Result<IntentResult> {
70+
let start_time = Instant::now();
71+
72+
// Use real BERT model for classification
73+
let (predicted_class, confidence) =
74+
self.bert_classifier.classify_text(text).map_err(|e| {
75+
let unified_err = model_error!(
76+
ModelErrorType::LoRA,
77+
"intent classification",
78+
format!("Classification failed: {}", e),
79+
text
80+
);
81+
candle_core::Error::from(unified_err)
82+
})?;
83+
84+
// Map class index to intent label
85+
let intent = if predicted_class < self.intent_labels.len() {
86+
self.intent_labels[predicted_class].clone()
87+
} else {
88+
format!("UNKNOWN_{}", predicted_class)
89+
};
90+
91+
let processing_time = start_time.elapsed().as_millis() as u64;
92+
93+
Ok(IntentResult {
94+
intent,
95+
confidence,
96+
processing_time_ms: processing_time,
97+
})
98+
}
99+
100+
/// Parallel classification for multiple texts
101+
pub fn parallel_classify(&self, texts: &[&str]) -> Result<Vec<IntentResult>> {
102+
// Process each text using real model inference
103+
texts
104+
.iter()
105+
.map(|text| self.classify_intent(text))
106+
.collect()
107+
}
108+
109+
/// Batch classification for multiple texts (optimized)
110+
pub fn batch_classify(&self, texts: &[&str]) -> Result<Vec<IntentResult>> {
111+
let start_time = Instant::now();
112+
113+
// Use BERT's batch processing capability
114+
let batch_results = self.bert_classifier.classify_batch(texts).map_err(|e| {
115+
let unified_err = processing_errors::batch_processing(texts.len(), &e.to_string());
116+
candle_core::Error::from(unified_err)
117+
})?;
118+
119+
let processing_time = start_time.elapsed().as_millis() as u64;
120+
121+
let mut results = Vec::new();
122+
for (predicted_class, confidence) in batch_results {
123+
let intent = if predicted_class < self.intent_labels.len() {
124+
self.intent_labels[predicted_class].clone()
125+
} else {
126+
format!("UNKNOWN_{}", predicted_class)
127+
};
128+
129+
results.push(IntentResult {
130+
intent,
131+
confidence,
132+
processing_time_ms: processing_time,
133+
});
134+
}
135+
136+
Ok(results)
137+
}
138+
}
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
//! LoRA Classifiers - High-Performance Parallel Processing
//!
//! Module tree for the LoRA-based classifiers (intent, PII, security,
//! token-level) and the engine that runs them in parallel.

#![allow(dead_code)]

// LoRA classifier modules
pub mod intent_lora;
pub mod parallel_engine;
pub mod pii_lora;
pub mod security_lora;
pub mod token_lora;

// Re-export LoRA classifier types
// NOTE(review): `token_lora` is declared above but, unlike all its siblings,
// not glob re-exported here — confirm whether this is intentional (e.g. to
// avoid name collisions) or an oversight.
pub use intent_lora::*;
pub use parallel_engine::*;
pub use pii_lora::*;
pub use security_lora::*;
Lines changed: 181 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,181 @@
1+
//! Parallel LoRA processing engine
2+
//!
3+
//! Enables parallel execution of Intent||PII||Security classification tasks
4+
//! Using thread-based parallelism instead of async/await
5+
6+
use crate::classifiers::lora::{
7+
intent_lora::{IntentLoRAClassifier, IntentResult},
8+
pii_lora::{PIILoRAClassifier, PIIResult},
9+
security_lora::{SecurityLoRAClassifier, SecurityResult},
10+
};
11+
use crate::core::{concurrency_error, ModelErrorType, UnifiedError};
12+
use crate::model_error;
13+
use candle_core::{Device, Result};
14+
use std::sync::{Arc, Mutex};
15+
use std::thread;
16+
17+
/// Parallel LoRA processing engine
///
/// Holds one classifier per task; each is `Arc`-shared with the short-lived
/// worker threads spawned by `parallel_classify`.
pub struct ParallelLoRAEngine {
    /// Intent classification task
    intent_classifier: Arc<IntentLoRAClassifier>,
    /// PII detection task
    pii_classifier: Arc<PIILoRAClassifier>,
    /// Security detection task
    security_classifier: Arc<SecurityLoRAClassifier>,
    /// Device the engine was constructed with
    // NOTE(review): stored but not read anywhere in this file — confirm it
    // is used elsewhere or by future code.
    device: Device,
}
24+
25+
impl ParallelLoRAEngine {
26+
pub fn new(
27+
device: Device,
28+
intent_model_path: &str,
29+
pii_model_path: &str,
30+
security_model_path: &str,
31+
use_cpu: bool,
32+
) -> Result<Self> {
33+
// Create intent classifier
34+
let intent_classifier = Arc::new(
35+
IntentLoRAClassifier::new(intent_model_path, use_cpu).map_err(|e| {
36+
let unified_err = model_error!(
37+
ModelErrorType::LoRA,
38+
"intent classifier creation",
39+
format!("Failed to create intent classifier: {}", e),
40+
intent_model_path
41+
);
42+
candle_core::Error::from(unified_err)
43+
})?,
44+
);
45+
46+
// Create PII classifier
47+
let pii_classifier = Arc::new(PIILoRAClassifier::new(pii_model_path, use_cpu).map_err(
48+
|e| {
49+
let unified_err = model_error!(
50+
ModelErrorType::LoRA,
51+
"PII classifier creation",
52+
format!("Failed to create PII classifier: {}", e),
53+
pii_model_path
54+
);
55+
candle_core::Error::from(unified_err)
56+
},
57+
)?);
58+
59+
// Create security classifier
60+
let security_classifier = Arc::new(
61+
SecurityLoRAClassifier::new(security_model_path, use_cpu).map_err(|e| {
62+
let unified_err = model_error!(
63+
ModelErrorType::LoRA,
64+
"security classifier creation",
65+
format!("Failed to create security classifier: {}", e),
66+
security_model_path
67+
);
68+
candle_core::Error::from(unified_err)
69+
})?,
70+
);
71+
72+
Ok(Self {
73+
intent_classifier,
74+
pii_classifier,
75+
security_classifier,
76+
device,
77+
})
78+
}
79+
80+
/// Parallel classification across all three tasks
81+
pub fn parallel_classify(&self, texts: &[&str]) -> Result<ParallelResult> {
82+
let texts_owned: Vec<String> = texts.iter().map(|s| s.to_string()).collect();
83+
84+
// Create shared results
85+
let intent_results = Arc::new(Mutex::new(Vec::new()));
86+
let pii_results = Arc::new(Mutex::new(Vec::new()));
87+
let security_results = Arc::new(Mutex::new(Vec::new()));
88+
89+
let handles = vec![
90+
self.spawn_intent_task(texts_owned.clone(), Arc::clone(&intent_results)),
91+
self.spawn_pii_task(texts_owned.clone(), Arc::clone(&pii_results)),
92+
self.spawn_security_task(texts_owned, Arc::clone(&security_results)),
93+
];
94+
95+
// Wait for all threads to complete
96+
for handle in handles {
97+
handle.join().map_err(|_| {
98+
let unified_err = concurrency_error(
99+
"thread join",
100+
"Failed to join parallel classification thread",
101+
);
102+
candle_core::Error::from(unified_err)
103+
})?;
104+
}
105+
106+
Ok(ParallelResult {
107+
intent_results: Arc::try_unwrap(intent_results)
108+
.unwrap()
109+
.into_inner()
110+
.unwrap(),
111+
pii_results: Arc::try_unwrap(pii_results).unwrap().into_inner().unwrap(),
112+
security_results: Arc::try_unwrap(security_results)
113+
.unwrap()
114+
.into_inner()
115+
.unwrap(),
116+
})
117+
}
118+
119+
fn spawn_intent_task(
120+
&self,
121+
texts: Vec<String>,
122+
results: Arc<Mutex<Vec<IntentResult>>>,
123+
) -> thread::JoinHandle<()> {
124+
let classifier = Arc::clone(&self.intent_classifier);
125+
thread::spawn(move || {
126+
let text_refs: Vec<&str> = texts.iter().map(|s| s.as_str()).collect();
127+
match classifier.batch_classify(&text_refs) {
128+
Ok(task_results) => {
129+
let mut guard = results.lock().unwrap();
130+
*guard = task_results;
131+
}
132+
Err(e) => {
133+
eprintln!("Intent classification failed: {}", e);
134+
}
135+
}
136+
})
137+
}
138+
139+
fn spawn_pii_task(
140+
&self,
141+
texts: Vec<String>,
142+
results: Arc<Mutex<Vec<PIIResult>>>,
143+
) -> thread::JoinHandle<()> {
144+
let classifier = Arc::clone(&self.pii_classifier);
145+
thread::spawn(move || {
146+
let text_refs: Vec<&str> = texts.iter().map(|s| s.as_str()).collect();
147+
if let Ok(task_results) = classifier.batch_detect(&text_refs) {
148+
let mut guard = results.lock().unwrap();
149+
*guard = task_results;
150+
}
151+
})
152+
}
153+
154+
fn spawn_security_task(
155+
&self,
156+
texts: Vec<String>,
157+
results: Arc<Mutex<Vec<SecurityResult>>>,
158+
) -> thread::JoinHandle<()> {
159+
let classifier = Arc::clone(&self.security_classifier);
160+
thread::spawn(move || {
161+
let text_refs: Vec<&str> = texts.iter().map(|s| s.as_str()).collect();
162+
match classifier.batch_detect(&text_refs) {
163+
Ok(task_results) => {
164+
let mut guard = results.lock().unwrap();
165+
*guard = task_results;
166+
}
167+
Err(e) => {
168+
eprintln!("Security classification failed: {}", e);
169+
}
170+
}
171+
})
172+
}
173+
}
174+
175+
/// Results from parallel classification
///
/// One result vector per task. A vector may be empty if its worker thread
/// failed (errors are logged, not propagated — see `parallel_classify`).
#[derive(Debug, Clone)]
pub struct ParallelResult {
    /// Intent classification results, in input order
    pub intent_results: Vec<IntentResult>,
    /// PII detection results, in input order
    pub pii_results: Vec<PIIResult>,
    /// Security detection results, in input order
    pub security_results: Vec<SecurityResult>,
}

0 commit comments

Comments
 (0)