Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
441 changes: 0 additions & 441 deletions candle-binding/src/bert_official.rs

This file was deleted.

161 changes: 161 additions & 0 deletions candle-binding/src/classifiers/lora/intent_lora.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
//! Intent classification with LoRA adapters
//!
//! High-performance intent classification using real model inference

use crate::core::{processing_errors, ModelErrorType, UnifiedError};
use crate::model_architectures::lora::bert_lora::HighPerformanceBertClassifier;
use crate::model_error;
use candle_core::Result;
use std::time::Instant;

/// Intent classifier with real model inference (merged LoRA models)
pub struct IntentLoRAClassifier {
    /// High-performance BERT classifier for intent classification
    bert_classifier: HighPerformanceBertClassifier,
    /// Confidence threshold for predictions (loaded from global config, default 0.6)
    confidence_threshold: f32,
    /// Intent labels mapping; indexed by the model's predicted class id
    intent_labels: Vec<String>,
    /// Model path for reference
    model_path: String,
}

/// Intent classification result
#[derive(Debug, Clone)]
pub struct IntentResult {
    /// Predicted intent label (one of the labels loaded from the model config)
    pub intent: String,
    /// Model confidence score for the predicted label
    pub confidence: f32,
    /// Wall-clock processing time in milliseconds
    pub processing_time_ms: u64,
}

impl IntentLoRAClassifier {
/// Create new intent classifier using real model inference
pub fn new(model_path: &str, use_cpu: bool) -> Result<Self> {
// Load labels from model config
let intent_labels = Self::load_labels_from_config(model_path)?;
let num_classes = intent_labels.len();

// Load the high-performance BERT classifier for merged LoRA models
let classifier = HighPerformanceBertClassifier::new(model_path, num_classes, use_cpu)
.map_err(|e| {
let unified_err = model_error!(
ModelErrorType::LoRA,
"intent classifier creation",
format!("Failed to create BERT classifier: {}", e),
model_path
);
candle_core::Error::from(unified_err)
})?;

// Load threshold from global config instead of hardcoding
let confidence_threshold = {
use crate::core::config_loader::GlobalConfigLoader;
GlobalConfigLoader::load_intent_threshold().unwrap_or(0.6) // Default from config.yaml classifier.category_model.threshold
};

Ok(Self {
bert_classifier: classifier,
confidence_threshold,
intent_labels,
model_path: model_path.to_string(),
})
}

/// Load intent labels from model config.json using unified config loader
fn load_labels_from_config(model_path: &str) -> Result<Vec<String>> {
use crate::core::config_loader;

match config_loader::load_intent_labels(model_path) {
Ok(result) => Ok(result),
Err(unified_err) => Err(candle_core::Error::from(unified_err)),
}
}

/// Classify intent using real model inference
pub fn classify_intent(&self, text: &str) -> Result<IntentResult> {
let start_time = Instant::now();

// Use real BERT model for classification
let (predicted_class, confidence) =
self.bert_classifier.classify_text(text).map_err(|e| {
let unified_err = model_error!(
ModelErrorType::LoRA,
"intent classification",
format!("Classification failed: {}", e),
text
);
candle_core::Error::from(unified_err)
})?;

// Map class index to intent label - fail if class not found
let intent = if predicted_class < self.intent_labels.len() {
self.intent_labels[predicted_class].clone()
} else {
let unified_err = model_error!(
ModelErrorType::LoRA,
"intent classification",
format!(
"Invalid class index {} not found in labels (max: {})",
predicted_class,
self.intent_labels.len()
),
text
);
return Err(candle_core::Error::from(unified_err));
};

let processing_time = start_time.elapsed().as_millis() as u64;

Ok(IntentResult {
intent,
confidence,
processing_time_ms: processing_time,
})
}

/// Parallel classification for multiple texts
pub fn parallel_classify(&self, texts: &[&str]) -> Result<Vec<IntentResult>> {
// Process each text using real model inference
texts
.iter()
.map(|text| self.classify_intent(text))
.collect()
}

/// Batch classification for multiple texts (optimized)
pub fn batch_classify(&self, texts: &[&str]) -> Result<Vec<IntentResult>> {
let start_time = Instant::now();

// Use BERT's batch processing capability
let batch_results = self.bert_classifier.classify_batch(texts).map_err(|e| {
let unified_err = processing_errors::batch_processing(texts.len(), &e.to_string());
candle_core::Error::from(unified_err)
})?;

let processing_time = start_time.elapsed().as_millis() as u64;

let mut results = Vec::new();
for (i, (predicted_class, confidence)) in batch_results.iter().enumerate() {
let intent = if *predicted_class < self.intent_labels.len() {
self.intent_labels[*predicted_class].clone()
} else {
let unified_err = model_error!(
ModelErrorType::LoRA,
"batch intent classification",
format!("Invalid class index {} not found in labels (max: {}) for text at position {}",
predicted_class, self.intent_labels.len(), i),
&format!("batch[{}]", i)
);
return Err(candle_core::Error::from(unified_err));
};

results.push(IntentResult {
intent,
confidence: *confidence,
processing_time_ms: processing_time,
});
}

Ok(results)
}
}
16 changes: 16 additions & 0 deletions candle-binding/src/classifiers/lora/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
//! LoRA Classifiers - High-Performance Parallel Processing

#![allow(dead_code)]

// LoRA classifier modules
pub mod intent_lora;
pub mod parallel_engine;
pub mod pii_lora;
pub mod security_lora;
pub mod token_lora;

// Re-export LoRA classifier types
// NOTE(review): token_lora is not glob re-exported like its siblings —
// confirm whether that is intentional (e.g. to avoid name collisions).
pub use intent_lora::*;
pub use parallel_engine::*;
pub use pii_lora::*;
pub use security_lora::*;
181 changes: 181 additions & 0 deletions candle-binding/src/classifiers/lora/parallel_engine.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
//! Parallel LoRA processing engine
//!
//! Enables parallel execution of Intent||PII||Security classification tasks
//! Using thread-based parallelism instead of async/await

use crate::classifiers::lora::{
intent_lora::{IntentLoRAClassifier, IntentResult},
pii_lora::{PIILoRAClassifier, PIIResult},
security_lora::{SecurityLoRAClassifier, SecurityResult},
};
use crate::core::{concurrency_error, ModelErrorType, UnifiedError};
use crate::model_error;
use candle_core::{Device, Result};
use std::sync::{Arc, Mutex};
use std::thread;

/// Parallel LoRA processing engine
pub struct ParallelLoRAEngine {
    /// Intent classifier, shared with its worker thread via Arc
    intent_classifier: Arc<IntentLoRAClassifier>,
    /// PII classifier, shared with its worker thread via Arc
    pii_classifier: Arc<PIILoRAClassifier>,
    /// Security classifier, shared with its worker thread via Arc
    security_classifier: Arc<SecurityLoRAClassifier>,
    // NOTE(review): device appears unused within this impl — the classifiers
    // select CPU/GPU themselves via `use_cpu`; confirm it is needed.
    device: Device,
}

impl ParallelLoRAEngine {
    /// Build the engine by eagerly loading all three classifiers.
    ///
    /// # Errors
    /// Fails if any of the three model paths cannot be loaded.
    pub fn new(
        device: Device,
        intent_model_path: &str,
        pii_model_path: &str,
        security_model_path: &str,
        use_cpu: bool,
    ) -> Result<Self> {
        // Create intent classifier
        let intent_classifier = Arc::new(
            IntentLoRAClassifier::new(intent_model_path, use_cpu).map_err(|e| {
                let unified_err = model_error!(
                    ModelErrorType::LoRA,
                    "intent classifier creation",
                    format!("Failed to create intent classifier: {}", e),
                    intent_model_path
                );
                candle_core::Error::from(unified_err)
            })?,
        );

        // Create PII classifier
        let pii_classifier = Arc::new(
            PIILoRAClassifier::new(pii_model_path, use_cpu).map_err(|e| {
                let unified_err = model_error!(
                    ModelErrorType::LoRA,
                    "PII classifier creation",
                    format!("Failed to create PII classifier: {}", e),
                    pii_model_path
                );
                candle_core::Error::from(unified_err)
            })?,
        );

        // Create security classifier
        let security_classifier = Arc::new(
            SecurityLoRAClassifier::new(security_model_path, use_cpu).map_err(|e| {
                let unified_err = model_error!(
                    ModelErrorType::LoRA,
                    "security classifier creation",
                    format!("Failed to create security classifier: {}", e),
                    security_model_path
                );
                candle_core::Error::from(unified_err)
            })?,
        );

        Ok(Self {
            intent_classifier,
            pii_classifier,
            security_classifier,
            device,
        })
    }

    /// Run Intent || PII || Security classification over `texts`, one worker
    /// thread per task.
    ///
    /// Best-effort semantics: a task whose batch call fails logs to stderr
    /// and leaves its result vector empty; only a panicked (unjoinable)
    /// thread produces an `Err`.
    pub fn parallel_classify(&self, texts: &[&str]) -> Result<ParallelResult> {
        // Owned copies so the worker threads can outlive the borrowed input.
        let texts_owned: Vec<String> = texts.iter().map(|s| s.to_string()).collect();

        // Shared result slots, one per task, written by the worker threads.
        let intent_results = Arc::new(Mutex::new(Vec::new()));
        let pii_results = Arc::new(Mutex::new(Vec::new()));
        let security_results = Arc::new(Mutex::new(Vec::new()));

        let handles = vec![
            self.spawn_intent_task(texts_owned.clone(), Arc::clone(&intent_results)),
            self.spawn_pii_task(texts_owned.clone(), Arc::clone(&pii_results)),
            self.spawn_security_task(texts_owned, Arc::clone(&security_results)),
        ];

        // Wait for all threads to complete
        for handle in handles {
            handle.join().map_err(|_| {
                let unified_err = concurrency_error(
                    "thread join",
                    "Failed to join parallel classification thread",
                );
                candle_core::Error::from(unified_err)
            })?;
        }

        // All worker clones of these Arcs were dropped when their threads
        // were joined above, so try_unwrap/into_inner cannot fail here.
        Ok(ParallelResult {
            intent_results: Arc::try_unwrap(intent_results)
                .unwrap()
                .into_inner()
                .unwrap(),
            pii_results: Arc::try_unwrap(pii_results).unwrap().into_inner().unwrap(),
            security_results: Arc::try_unwrap(security_results)
                .unwrap()
                .into_inner()
                .unwrap(),
        })
    }

    /// Spawn the intent worker; writes its batch results into `results`.
    fn spawn_intent_task(
        &self,
        texts: Vec<String>,
        results: Arc<Mutex<Vec<IntentResult>>>,
    ) -> thread::JoinHandle<()> {
        let classifier = Arc::clone(&self.intent_classifier);
        thread::spawn(move || {
            let text_refs: Vec<&str> = texts.iter().map(|s| s.as_str()).collect();
            match classifier.batch_classify(&text_refs) {
                Ok(task_results) => {
                    let mut guard = results.lock().unwrap();
                    *guard = task_results;
                }
                Err(e) => {
                    eprintln!("Intent classification failed: {}", e);
                }
            }
        })
    }

    /// Spawn the PII worker; writes its batch results into `results`.
    ///
    /// Fixed: previously used `if let Ok(...)` and silently discarded
    /// failures; now logs them like the intent/security tasks, keeping the
    /// best-effort contract consistent across all three workers.
    fn spawn_pii_task(
        &self,
        texts: Vec<String>,
        results: Arc<Mutex<Vec<PIIResult>>>,
    ) -> thread::JoinHandle<()> {
        let classifier = Arc::clone(&self.pii_classifier);
        thread::spawn(move || {
            let text_refs: Vec<&str> = texts.iter().map(|s| s.as_str()).collect();
            match classifier.batch_detect(&text_refs) {
                Ok(task_results) => {
                    let mut guard = results.lock().unwrap();
                    *guard = task_results;
                }
                Err(e) => {
                    eprintln!("PII classification failed: {}", e);
                }
            }
        })
    }

    /// Spawn the security worker; writes its batch results into `results`.
    fn spawn_security_task(
        &self,
        texts: Vec<String>,
        results: Arc<Mutex<Vec<SecurityResult>>>,
    ) -> thread::JoinHandle<()> {
        let classifier = Arc::clone(&self.security_classifier);
        thread::spawn(move || {
            let text_refs: Vec<&str> = texts.iter().map(|s| s.as_str()).collect();
            match classifier.batch_detect(&text_refs) {
                Ok(task_results) => {
                    let mut guard = results.lock().unwrap();
                    *guard = task_results;
                }
                Err(e) => {
                    eprintln!("Security classification failed: {}", e);
                }
            }
        })
    }
}

/// Results from parallel classification
#[derive(Debug, Clone)]
pub struct ParallelResult {
    /// Per-text intent results (left empty if the intent task failed)
    pub intent_results: Vec<IntentResult>,
    /// Per-text PII results (left empty if the PII task failed)
    pub pii_results: Vec<PIIResult>,
    /// Per-text security results (left empty if the security task failed)
    pub security_results: Vec<SecurityResult>,
}
Loading
Loading