Skip to content

Commit 9ff34d0

Browse files
committed
feat: enable LoRA auto-detection for intent classification (#724)
This commit implements automatic detection of LoRA (Low-Rank Adaptation) models based on the presence of lora_config.json in the model directory. Changes: - Add LoRA auto-detection logic in Rust candle-binding layer - Implement fallback to BERT base model when LoRA config is not found - Add comprehensive test coverage for auto-detection mechanism - Update default Helm values to use LoRA intent classification model - Update ABrix deployment values to use LoRA models The auto-detection mechanism checks for lora_config.json during model initialization and automatically switches between LoRA and base BERT models without requiring explicit configuration changes. Signed-off-by: Yossi Ovadia <[email protected]>
1 parent 86b6ae2 commit 9ff34d0

File tree

8 files changed

+246
-25
lines changed

8 files changed

+246
-25
lines changed

candle-binding/src/classifiers/lora/intent_lora.rs

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,40 @@ impl IntentLoRAClassifier {
113113
})
114114
}
115115

116+
/// Classify intent and return (class_index, confidence, intent_label) for FFI
117+
pub fn classify_with_index(&self, text: &str) -> Result<(usize, f32, String)> {
118+
// Use real BERT model for classification
119+
let (predicted_class, confidence) =
120+
self.bert_classifier.classify_text(text).map_err(|e| {
121+
let unified_err = model_error!(
122+
ModelErrorType::LoRA,
123+
"intent classification",
124+
format!("Classification failed: {}", e),
125+
text
126+
);
127+
candle_core::Error::from(unified_err)
128+
})?;
129+
130+
// Map class index to intent label - fail if class not found
131+
let intent = if predicted_class < self.intent_labels.len() {
132+
self.intent_labels[predicted_class].clone()
133+
} else {
134+
let unified_err = model_error!(
135+
ModelErrorType::LoRA,
136+
"intent classification",
137+
format!(
138+
"Invalid class index {} not found in labels (max: {})",
139+
predicted_class,
140+
self.intent_labels.len()
141+
),
142+
text
143+
);
144+
return Err(candle_core::Error::from(unified_err));
145+
};
146+
147+
Ok((predicted_class, confidence, intent))
148+
}
149+
116150
/// Parallel classification for multiple texts using rayon
117151
///
118152
/// # Performance

candle-binding/src/core/tokenization.rs

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -387,7 +387,19 @@ impl DualPathTokenizer for UnifiedTokenizer {
387387
let encoding = tokenizer
388388
.encode(text, self.config.add_special_tokens)
389389
.map_err(E::msg)?;
390-
Ok(self.encoding_to_result(&encoding))
390+
391+
// Explicitly enforce max_length truncation for LoRA models
392+
// This is a safety check to ensure we never exceed the model's position embedding size
393+
let mut result = self.encoding_to_result(&encoding);
394+
let max_len = self.config.max_length;
395+
if result.token_ids.len() > max_len {
396+
result.token_ids.truncate(max_len);
397+
result.token_ids_u32.truncate(max_len);
398+
result.attention_mask.truncate(max_len);
399+
result.tokens.truncate(max_len);
400+
}
401+
402+
Ok(result)
391403
}
392404

393405
fn tokenize_batch_smart(

candle-binding/src/ffi/classify.rs

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ use crate::BertClassifier;
2121
use std::ffi::{c_char, CStr};
2222
use std::sync::{Arc, OnceLock};
2323

24-
use crate::ffi::init::{PARALLEL_LORA_ENGINE, UNIFIED_CLASSIFIER};
24+
use crate::ffi::init::{LORA_INTENT_CLASSIFIER, PARALLEL_LORA_ENGINE, UNIFIED_CLASSIFIER};
2525
// Import DeBERTa classifier for jailbreak detection
2626
use super::init::DEBERTA_JAILBREAK_CLASSIFIER;
2727

@@ -693,7 +693,32 @@ pub extern "C" fn classify_candle_bert_text(text: *const c_char) -> Classificati
693693
Err(_) => return default_result,
694694
}
695695
};
696-
// Use TraditionalBertClassifier for Candle BERT text classification
696+
697+
// Try LoRA intent classifier first (preferred for higher accuracy)
698+
if let Some(classifier) = LORA_INTENT_CLASSIFIER.get() {
699+
let classifier = classifier.clone();
700+
match classifier.classify_with_index(text) {
701+
Ok((class_idx, confidence, ref intent)) => {
702+
// Allocate C string for intent label
703+
let label_ptr = unsafe { allocate_c_string(intent) };
704+
705+
return ClassificationResult {
706+
predicted_class: class_idx as i32,
707+
confidence,
708+
label: label_ptr,
709+
};
710+
}
711+
Err(e) => {
712+
eprintln!(
713+
"LoRA intent classifier error: {}, falling back to Traditional BERT",
714+
e
715+
);
716+
// Don't return - fall through to Traditional BERT classifier
717+
}
718+
}
719+
}
720+
721+
// Fallback to Traditional BERT classifier
697722
if let Some(classifier) = TRADITIONAL_BERT_CLASSIFIER.get() {
698723
let classifier = classifier.clone();
699724
match classifier.classify_text(text) {
@@ -717,7 +742,7 @@ pub extern "C" fn classify_candle_bert_text(text: *const c_char) -> Classificati
717742
}
718743
}
719744
} else {
720-
println!("TraditionalBertClassifier not initialized - call init_bert_classifier first");
745+
println!("No classifier initialized - call init_candle_bert_classifier first");
721746
ClassificationResult {
722747
predicted_class: -1,
723748
confidence: 0.0,

candle-binding/src/ffi/init.rs

Lines changed: 42 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,10 @@ pub static PARALLEL_LORA_ENGINE: OnceLock<
3737
pub static LORA_TOKEN_CLASSIFIER: OnceLock<
3838
Arc<crate::classifiers::lora::token_lora::LoRATokenClassifier>,
3939
> = OnceLock::new();
40+
// LoRA intent classifier for sequence classification
41+
pub static LORA_INTENT_CLASSIFIER: OnceLock<
42+
Arc<crate::classifiers::lora::intent_lora::IntentLoRAClassifier>,
43+
> = OnceLock::new();
4044

4145
/// Model type detection for intelligent routing
4246
#[derive(Debug, Clone, PartialEq)]
@@ -604,28 +608,53 @@ pub extern "C" fn init_candle_bert_classifier(
604608
num_classes: i32,
605609
use_cpu: bool,
606610
) -> bool {
607-
// Migrated from lib.rs:1555-1578
608611
let model_path = unsafe {
609612
match CStr::from_ptr(model_path).to_str() {
610613
Ok(s) => s,
611614
Err(_) => return false,
612615
}
613616
};
614617

615-
// Initialize TraditionalBertClassifier
616-
match crate::model_architectures::traditional::bert::TraditionalBertClassifier::new(
617-
model_path,
618-
num_classes as usize,
619-
use_cpu,
620-
) {
621-
Ok(_classifier) => {
622-
// Store in global static (would need to add this to the lazy_static block)
618+
// Intelligent model type detection (same as token classifier)
619+
let model_type = detect_model_type(model_path);
623620

624-
true
621+
match model_type {
622+
ModelType::LoRA => {
623+
// Check if already initialized
624+
if LORA_INTENT_CLASSIFIER.get().is_some() {
625+
return true; // Already initialized, return success
626+
}
627+
628+
// Route to LoRA intent classifier initialization
629+
match crate::classifiers::lora::intent_lora::IntentLoRAClassifier::new(
630+
model_path, use_cpu,
631+
) {
632+
Ok(classifier) => LORA_INTENT_CLASSIFIER.set(Arc::new(classifier)).is_ok(),
633+
Err(e) => {
634+
eprintln!(
635+
" ERROR: Failed to initialize LoRA intent classifier: {}",
636+
e
637+
);
638+
false
639+
}
640+
}
625641
}
626-
Err(e) => {
627-
eprintln!("Failed to initialize Candle BERT classifier: {}", e);
628-
false
642+
ModelType::Traditional => {
643+
// Initialize TraditionalBertClassifier
644+
match crate::model_architectures::traditional::bert::TraditionalBertClassifier::new(
645+
model_path,
646+
num_classes as usize,
647+
use_cpu,
648+
) {
649+
Ok(_classifier) => {
650+
// Store in global static (would need to add this to the lazy_static block)
651+
true
652+
}
653+
Err(e) => {
654+
eprintln!("Failed to initialize Candle BERT classifier: {}", e);
655+
false
656+
}
657+
}
629658
}
630659
}
631660
}

candle-binding/src/model_architectures/lora/bert_lora.rs

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -499,9 +499,18 @@ impl HighPerformanceBertClassifier {
499499

500500
// Load tokenizer
501501
let tokenizer_path = Path::new(model_path).join("tokenizer.json");
502-
let tokenizer = Tokenizer::from_file(&tokenizer_path)
502+
let mut tokenizer = Tokenizer::from_file(&tokenizer_path)
503503
.map_err(|e| E::msg(format!("Failed to load tokenizer: {}", e)))?;
504504

505+
// Configure truncation to max 512 tokens (BERT's position embedding limit)
506+
use tokenizers::TruncationParams;
507+
tokenizer
508+
.with_truncation(Some(TruncationParams {
509+
max_length: 512,
510+
..Default::default()
511+
}))
512+
.map_err(E::msg)?;
513+
505514
// Load model weights
506515
let weights_path = if Path::new(model_path).join("model.safetensors").exists() {
507516
Path::new(model_path).join("model.safetensors")
@@ -690,9 +699,18 @@ impl HighPerformanceBertTokenClassifier {
690699

691700
// Load tokenizer
692701
let tokenizer_path = Path::new(model_path).join("tokenizer.json");
693-
let tokenizer = Tokenizer::from_file(&tokenizer_path)
702+
let mut tokenizer = Tokenizer::from_file(&tokenizer_path)
694703
.map_err(|e| E::msg(format!("Failed to load tokenizer: {}", e)))?;
695704

705+
// Configure truncation to max 512 tokens (BERT's position embedding limit)
706+
use tokenizers::TruncationParams;
707+
tokenizer
708+
.with_truncation(Some(TruncationParams {
709+
max_length: 512,
710+
..Default::default()
711+
}))
712+
.map_err(E::msg)?;
713+
696714
// Load model weights
697715
let weights_path = if Path::new(model_path).join("model.safetensors").exists() {
698716
Path::new(model_path).join("model.safetensors")

deploy/helm/semantic-router/values.yaml

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,8 @@ initContainer:
159159
repo: Qwen/Qwen3-Embedding-0.6B
160160
- name: all-MiniLM-L12-v2
161161
repo: sentence-transformers/all-MiniLM-L12-v2
162+
- name: lora_intent_classifier_bert-base-uncased_model
163+
repo: LLM-Semantic-Router/lora_intent_classifier_bert-base-uncased_model
162164
- name: category_classifier_modernbert-base_model
163165
repo: LLM-Semantic-Router/category_classifier_modernbert-base_model
164166
- name: pii_classifier_modernbert-base_model
@@ -272,11 +274,11 @@ config:
272274
# Classifier configuration
273275
classifier:
274276
category_model:
275-
model_id: "models/category_classifier_modernbert-base_model"
276-
use_modernbert: true
277+
model_id: "models/lora_intent_classifier_bert-base-uncased_model"
278+
use_modernbert: false # Use LoRA intent classifier with auto-detection
277279
threshold: 0.6
278280
use_cpu: true
279-
category_mapping_path: "models/category_classifier_modernbert-base_model/category_mapping.json"
281+
category_mapping_path: "models/lora_intent_classifier_bert-base-uncased_model/category_mapping.json"
280282
pii_model:
281283
model_id: "models/pii_classifier_modernbert-base_presidio_token_model"
282284
use_modernbert: true

deploy/kubernetes/aibrix/semantic-router-values/values.yaml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -431,11 +431,11 @@ config:
431431
# Classifier configuration
432432
classifier:
433433
category_model:
434-
model_id: "models/category_classifier_modernbert-base_model"
435-
use_modernbert: true
434+
model_id: "models/lora_intent_classifier_bert-base-uncased_model"
435+
use_modernbert: false # Use LoRA intent classifier with auto-detection
436436
threshold: 0.6
437437
use_cpu: true
438-
category_mapping_path: "models/category_classifier_modernbert-base_model/category_mapping.json"
438+
category_mapping_path: "models/lora_intent_classifier_bert-base-uncased_model/category_mapping.json"
439439
pii_model:
440440
# Support both traditional (modernbert) and LoRA-based PII detection
441441
# When model_type is "auto", the system will auto-detect LoRA configuration
Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
package classification
2+
3+
import (
4+
"os"
5+
"testing"
6+
7+
candle_binding "github.com/vllm-project/semantic-router/candle-binding"
8+
"github.com/vllm-project/semantic-router/src/semantic-router/pkg/config"
9+
)
10+
11+
// TestIntentClassificationLoRAAutoDetection demonstrates that current implementation
12+
// doesn't auto-detect LoRA models for intent classification (unlike PII detection)
13+
func TestIntentClassificationLoRAAutoDetection(t *testing.T) {
14+
modelPath := "../../../../models/lora_intent_classifier_bert-base-uncased_model"
15+
numClasses := 14 // From category_mapping.json
16+
17+
// Check if LoRA model exists
18+
if _, err := os.Stat(modelPath + "/lora_config.json"); os.IsNotExist(err) {
19+
t.Skip("LoRA intent model not available, skipping test")
20+
}
21+
22+
t.Run("AutoDetection: CategoryInitializer Now Detects LoRA Models", func(t *testing.T) {
23+
// After fix: CategoryInitializerImpl auto-detects LoRA models
24+
// It tries InitCandleBertClassifier() first (checks for lora_config.json)
25+
// Falls back to InitModernBertClassifier() if needed
26+
27+
cfg := &config.CategoryModel{
28+
ModelID: modelPath,
29+
UseCPU: true,
30+
}
31+
32+
// Create auto-detecting initializer
33+
initializer := createCategoryInitializer()
34+
35+
// Try to initialize - should SUCCESS with LoRA auto-detection
36+
err := initializer.Init(cfg.ModelID, cfg.UseCPU, numClasses)
37+
if err != nil {
38+
t.Errorf("Auto-detection failed: %v", err)
39+
return
40+
}
41+
42+
t.Log("✓ CategoryInitializer successfully auto-detected and initialized LoRA model")
43+
44+
// Verify inference works
45+
inference := createCategoryInference()
46+
result, err := inference.Classify("What is the best business strategy?")
47+
if err != nil {
48+
t.Errorf("Classification failed: %v", err)
49+
return
50+
}
51+
52+
if result.Class < 0 || result.Class >= numClasses {
53+
t.Errorf("Invalid category: %d (expected 0-%d)", result.Class, numClasses-1)
54+
return
55+
}
56+
57+
t.Logf("✓ Classification works: category=%d, confidence=%.3f", result.Class, result.Confidence)
58+
})
59+
60+
t.Run("Proof: Auto-Detection Already Works in Rust Layer", func(t *testing.T) {
61+
// This proves the Rust auto-detection ALREADY EXISTS and WORKS
62+
// InitCandleBertClassifier has auto-detection built-in (checks for lora_config.json)
63+
64+
success := candle_binding.InitCandleBertClassifier(modelPath, numClasses, true)
65+
66+
if !success {
67+
t.Error("InitCandleBertClassifier should auto-detect LoRA (it exists in Rust)")
68+
return
69+
}
70+
71+
t.Log("✓ Proof: Rust layer successfully auto-detected LoRA model")
72+
73+
// Try classification to prove it works
74+
result, err := candle_binding.ClassifyCandleBertText("What is the best business strategy?")
75+
if err != nil {
76+
t.Errorf("Classification failed: %v", err)
77+
return
78+
}
79+
80+
if result.Class < 0 || result.Class >= numClasses {
81+
t.Errorf("Invalid category: %d (expected 0-%d)", result.Class, numClasses-1)
82+
return
83+
}
84+
85+
t.Logf("✓ Classification works: category=%d, confidence=%.3f", result.Class, result.Confidence)
86+
t.Logf(" Solution: Update CategoryInitializer to use InitCandleBertClassifier")
87+
})
88+
}
89+
90+
// TestPIIAlreadyHasAutoDetection shows PII detection already works with LoRA auto-detection
91+
func TestPIIAlreadyHasAutoDetection(t *testing.T) {
92+
modelPath := "models/lora_pii_detector_bert-base-uncased_model"
93+
94+
// Check if LoRA model exists
95+
if _, err := os.Stat(modelPath + "/lora_config.json"); os.IsNotExist(err) {
96+
t.Skip("LoRA PII model not available, skipping test")
97+
}
98+
99+
t.Log("✓ PII detection already has auto-detection (implemented in PR #709)")
100+
t.Log(" Goal: Make Intent & Jailbreak detection work the same way")
101+
}

0 commit comments

Comments
 (0)