
Commit 292170b

Authored by yossiovadia, rootfs, and claude
feat(classifier): enable LoRA auto-detection for intent classification (#726)
* feat: enable LoRA auto-detection for intent classification (#724)

This commit implements automatic detection of LoRA (Low-Rank Adaptation) models based on the presence of lora_config.json in the model directory.

Changes:
- Add LoRA auto-detection logic in the Rust candle-binding layer
- Implement fallback to the BERT base model when the LoRA config is not found
- Add comprehensive test coverage for the auto-detection mechanism
- Update default Helm values to use the LoRA intent classification model
- Update AIBrix deployment values to use LoRA models

The auto-detection mechanism checks for lora_config.json during model initialization and automatically switches between LoRA and base BERT models without requiring explicit configuration changes.

Signed-off-by: Yossi Ovadia <[email protected]>

* fix: enable LoRA intent classification and optimize PII threshold

This commit fixes two critical issues affecting classification accuracy:

1. Fixed IsCategoryEnabled() to check the correct config field path:
   - Changed from c.Config.CategoryMappingPath (non-existent)
   - To c.Config.CategoryModel.CategoryMappingPath (correct)
   - This bug prevented LoRA classification from running in e2e tests

2. Optimized the PII detection threshold from 0.7 to 0.9:
   - Reduces false positives from the aggressive LoRA PII model (PR #709)
   - Improves domain classification accuracy from 40.71% to 52.50%
   - Beats the ModernBERT baseline of ~50%

Updated e2e test configurations to use LoRA models with optimized thresholds across the ai-gateway and dynamic-config profiles.

Signed-off-by: Yossi Ovadia <[email protected]>

* fix(ci): bump model cache version to pick up lora_config.json

Increment the cache version from v15 to v16 to ensure CI downloads the updated LoRA models that include the lora_config.json files needed for auto-detection.

Signed-off-by: Yossi Ovadia <[email protected]>

* chore: switch default config to use LoRA models with optimized thresholds

Update the default configuration to use LoRA-based classification:
- Intent classification: lora_intent_classifier_bert-base-uncased_model
- PII detection: lora_pii_detector_bert-base-uncased_model with threshold 0.9

This aligns the default config with the e2e test configurations for consistency across all environments.

Signed-off-by: Yossi Ovadia <[email protected]>

* fix(e2e): add decision routes for all 14 LoRA categories in test profiles

The production-stack and llm-d E2E test profiles were failing with 0-1% domain classification accuracy because they only configured decision routes for 1-2 categories while using LoRA intent classifiers that classify into 14 categories. When the classifier correctly identified categories like "biology", "health", or "math", no matching decision existed, causing "decision evaluation failed: no decision matched" errors.

Changes:
- production-stack: Added decision routes for all 14 categories (business, philosophy, biology, health, computer science, engineering, psychology, math, chemistry, physics, history, law, economics, other)
- llm-d: Added decision routes for all 14 categories with intelligent grouping (sciences, social sciences, humanities)

Results:
- production-stack domain classification: 1% → 53% accuracy (50x improvement)
- All 12 production-stack E2E tests now pass

This fix ensures LoRA auto-detection works properly by providing decision routes for all categories that the classifier can identify.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <[email protected]>
Signed-off-by: Yossi Ovadia <[email protected]>

---------

Signed-off-by: Yossi Ovadia <[email protected]>
Co-authored-by: Huamin Chen <[email protected]>
Co-authored-by: Claude <[email protected]>
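The auto-detection idea described in the commit message (presence of lora_config.json selects the LoRA path, otherwise fall back to base BERT) can be sketched as follows. This is a minimal illustration; the enum and function names mirror the diff below but are simplified stand-ins, not the crate's actual API.

```rust
use std::path::Path;

#[derive(Debug, PartialEq)]
enum ModelType {
    LoRA,
    Traditional,
}

// If lora_config.json exists in the model directory, route to the LoRA
// classifier; otherwise fall back to the traditional BERT classifier.
fn detect_model_type(model_path: &str) -> ModelType {
    if Path::new(model_path).join("lora_config.json").exists() {
        ModelType::LoRA
    } else {
        ModelType::Traditional
    }
}

fn main() {
    // A directory without lora_config.json falls back to the traditional path.
    assert_eq!(detect_model_type("/nonexistent/model"), ModelType::Traditional);
}
```

Because the check is a plain filesystem probe at initialization time, no configuration flag is needed to switch model families.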
1 parent 9061431 commit 292170b

File tree

15 files changed: +641 −95 lines changed

.github/workflows/test-and-build.yml

Lines changed: 2 additions & 2 deletions
@@ -85,9 +85,9 @@ jobs:
       with:
         path: |
           models/
-        key: ${{ runner.os }}-models-v1-${{ hashFiles('tools/make/models.mk') }}
+        key: ${{ runner.os }}-models-v2-${{ hashFiles('tools/make/models.mk') }}
         restore-keys: |
-          ${{ runner.os }}-models-v1-
+          ${{ runner.os }}-models-v2-
       continue-on-error: true # Don't fail the job if caching fails

     - name: Check go mod tidy

candle-binding/src/classifiers/lora/intent_lora.rs

Lines changed: 34 additions & 0 deletions
@@ -113,6 +113,40 @@ impl IntentLoRAClassifier {
         })
     }

+    /// Classify intent and return (class_index, confidence, intent_label) for FFI
+    pub fn classify_with_index(&self, text: &str) -> Result<(usize, f32, String)> {
+        // Use real BERT model for classification
+        let (predicted_class, confidence) =
+            self.bert_classifier.classify_text(text).map_err(|e| {
+                let unified_err = model_error!(
+                    ModelErrorType::LoRA,
+                    "intent classification",
+                    format!("Classification failed: {}", e),
+                    text
+                );
+                candle_core::Error::from(unified_err)
+            })?;
+
+        // Map class index to intent label - fail if class not found
+        let intent = if predicted_class < self.intent_labels.len() {
+            self.intent_labels[predicted_class].clone()
+        } else {
+            let unified_err = model_error!(
+                ModelErrorType::LoRA,
+                "intent classification",
+                format!(
+                    "Invalid class index {} not found in labels (max: {})",
+                    predicted_class,
+                    self.intent_labels.len()
+                ),
+                text
+            );
+            return Err(candle_core::Error::from(unified_err));
+        };
+
+        Ok((predicted_class, confidence, intent))
+    }
+
     /// Parallel classification for multiple texts using rayon
     ///
     /// # Performance
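The core of classify_with_index above is a bounds-checked mapping from a predicted class index to its label. A minimal standalone sketch (using a plain String error in place of the crate's model_error! machinery):

```rust
// Map a class index to its label, failing with a descriptive error when the
// model predicts an index outside the label table.
fn label_for_class(labels: &[String], predicted_class: usize) -> Result<String, String> {
    labels.get(predicted_class).cloned().ok_or_else(|| {
        format!(
            "Invalid class index {} not found in labels (max: {})",
            predicted_class,
            labels.len()
        )
    })
}

fn main() {
    let labels = vec!["math".to_string(), "law".to_string()];
    assert_eq!(label_for_class(&labels, 1).unwrap(), "law");
    // Out-of-range indices surface as errors instead of panicking.
    assert!(label_for_class(&labels, 5).is_err());
}
```

Failing loudly here (rather than returning a default label) is what lets the FFI caller detect a mismatch between the model head and the label mapping file.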

candle-binding/src/core/tokenization.rs

Lines changed: 13 additions & 1 deletion
@@ -387,7 +387,19 @@ impl DualPathTokenizer for UnifiedTokenizer {
         let encoding = tokenizer
             .encode(text, self.config.add_special_tokens)
             .map_err(E::msg)?;
-        Ok(self.encoding_to_result(&encoding))
+
+        // Explicitly enforce max_length truncation for LoRA models
+        // This is a safety check to ensure we never exceed the model's position embedding size
+        let mut result = self.encoding_to_result(&encoding);
+        let max_len = self.config.max_length;
+        if result.token_ids.len() > max_len {
+            result.token_ids.truncate(max_len);
+            result.token_ids_u32.truncate(max_len);
+            result.attention_mask.truncate(max_len);
+            result.tokens.truncate(max_len);
+        }
+
+        Ok(result)
     }

     fn tokenize_batch_smart(
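The safety check above must truncate every parallel per-token vector to the same length, or the token ids and attention mask would fall out of alignment. A self-contained sketch with a simplified stand-in for the crate's tokenization result type:

```rust
// Simplified stand-in for the tokenizer result: several vectors that must
// always stay the same length.
struct TokenResult {
    token_ids: Vec<i64>,
    attention_mask: Vec<u32>,
    tokens: Vec<String>,
}

// Cut all parallel vectors to max_len so the model never sees more tokens
// than its position embedding table supports.
fn enforce_max_length(result: &mut TokenResult, max_len: usize) {
    if result.token_ids.len() > max_len {
        result.token_ids.truncate(max_len);
        result.attention_mask.truncate(max_len);
        result.tokens.truncate(max_len);
    }
}

fn main() {
    let mut r = TokenResult {
        token_ids: (0..600).collect(),
        attention_mask: vec![1; 600],
        tokens: (0..600).map(|i| i.to_string()).collect(),
    };
    enforce_max_length(&mut r, 512);
    assert_eq!(r.token_ids.len(), 512);
    assert_eq!(r.attention_mask.len(), 512);
    assert_eq!(r.tokens.len(), 512);
}
```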

candle-binding/src/ffi/classify.rs

Lines changed: 28 additions & 3 deletions
@@ -22,7 +22,7 @@ use crate::BertClassifier;
 use std::ffi::{c_char, CStr};
 use std::sync::{Arc, OnceLock};

-use crate::ffi::init::{PARALLEL_LORA_ENGINE, UNIFIED_CLASSIFIER};
+use crate::ffi::init::{LORA_INTENT_CLASSIFIER, PARALLEL_LORA_ENGINE, UNIFIED_CLASSIFIER};
 // Import DeBERTa classifier for jailbreak detection
 use super::init::DEBERTA_JAILBREAK_CLASSIFIER;

@@ -734,7 +734,32 @@ pub extern "C" fn classify_candle_bert_text(text: *const c_char) -> Classificati
             Err(_) => return default_result,
         }
     };
-    // Use TraditionalBertClassifier for Candle BERT text classification
+
+    // Try LoRA intent classifier first (preferred for higher accuracy)
+    if let Some(classifier) = LORA_INTENT_CLASSIFIER.get() {
+        let classifier = classifier.clone();
+        match classifier.classify_with_index(text) {
+            Ok((class_idx, confidence, ref intent)) => {
+                // Allocate C string for intent label
+                let label_ptr = unsafe { allocate_c_string(intent) };
+
+                return ClassificationResult {
+                    predicted_class: class_idx as i32,
+                    confidence,
+                    label: label_ptr,
+                };
+            }
+            Err(e) => {
+                eprintln!(
+                    "LoRA intent classifier error: {}, falling back to Traditional BERT",
+                    e
+                );
+                // Don't return - fall through to Traditional BERT classifier
+            }
+        }
+    }
+
+    // Fallback to Traditional BERT classifier
     if let Some(classifier) = TRADITIONAL_BERT_CLASSIFIER.get() {
         let classifier = classifier.clone();
         match classifier.classify_text(text) {

@@ -758,7 +783,7 @@ pub extern "C" fn classify_candle_bert_text(text: *const c_char) -> Classificati
             }
         }
     } else {
-        println!("TraditionalBertClassifier not initialized - call init_bert_classifier first");
+        println!("No classifier initialized - call init_candle_bert_classifier first");
         ClassificationResult {
             predicted_class: -1,
             confidence: 0.0,
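The control flow in this diff is a try-preferred-then-fallback chain: attempt the LoRA classifier, and fall through to the traditional one when the LoRA path is absent or errors. A generic sketch of that pattern (the closures and names below are illustrative stand-ins, not the crate's API):

```rust
// Try the preferred classifier if it exists; on absence or error, fall back.
fn classify_with_fallback<F, G>(preferred: Option<F>, fallback: G, text: &str) -> (i32, f32)
where
    F: Fn(&str) -> Result<(i32, f32), String>,
    G: Fn(&str) -> (i32, f32),
{
    if let Some(p) = preferred {
        match p(text) {
            Ok(res) => return res,
            // Log and fall through instead of propagating the error.
            Err(e) => eprintln!("preferred classifier error: {}, falling back", e),
        }
    }
    fallback(text)
}

fn main() {
    // Stand-ins: the "LoRA" path errors, the "traditional" path answers.
    let lora = |_: &str| Err::<(i32, f32), _>("not loaded".to_string());
    let bert = |_: &str| (3, 0.80_f32);
    assert_eq!(classify_with_fallback(Some(lora), bert, "hi"), (3, 0.80));
}
```

Falling through on error (instead of returning it across the FFI boundary) keeps the C ABI of classify_candle_bert_text unchanged while still preferring the higher-accuracy model.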

candle-binding/src/ffi/init.rs

Lines changed: 42 additions & 13 deletions
@@ -37,6 +37,10 @@ pub static PARALLEL_LORA_ENGINE: OnceLock<
 pub static LORA_TOKEN_CLASSIFIER: OnceLock<
     Arc<crate::classifiers::lora::token_lora::LoRATokenClassifier>,
 > = OnceLock::new();
+// LoRA intent classifier for sequence classification
+pub static LORA_INTENT_CLASSIFIER: OnceLock<
+    Arc<crate::classifiers::lora::intent_lora::IntentLoRAClassifier>,
+> = OnceLock::new();

 /// Model type detection for intelligent routing
 #[derive(Debug, Clone, PartialEq)]

@@ -604,28 +608,53 @@ pub extern "C" fn init_candle_bert_classifier(
     num_classes: i32,
     use_cpu: bool,
 ) -> bool {
-    // Migrated from lib.rs:1555-1578
     let model_path = unsafe {
         match CStr::from_ptr(model_path).to_str() {
             Ok(s) => s,
             Err(_) => return false,
         }
     };

-    // Initialize TraditionalBertClassifier
-    match crate::model_architectures::traditional::bert::TraditionalBertClassifier::new(
-        model_path,
-        num_classes as usize,
-        use_cpu,
-    ) {
-        Ok(_classifier) => {
-            // Store in global static (would need to add this to the lazy_static block)
+    // Intelligent model type detection (same as token classifier)
+    let model_type = detect_model_type(model_path);

-            true
+    match model_type {
+        ModelType::LoRA => {
+            // Check if already initialized
+            if LORA_INTENT_CLASSIFIER.get().is_some() {
+                return true; // Already initialized, return success
+            }
+
+            // Route to LoRA intent classifier initialization
+            match crate::classifiers::lora::intent_lora::IntentLoRAClassifier::new(
+                model_path, use_cpu,
+            ) {
+                Ok(classifier) => LORA_INTENT_CLASSIFIER.set(Arc::new(classifier)).is_ok(),
+                Err(e) => {
+                    eprintln!(
+                        " ERROR: Failed to initialize LoRA intent classifier: {}",
+                        e
+                    );
+                    false
+                }
+            }
         }
-        Err(e) => {
-            eprintln!("Failed to initialize Candle BERT classifier: {}", e);
-            false
+        ModelType::Traditional => {
+            // Initialize TraditionalBertClassifier
+            match crate::model_architectures::traditional::bert::TraditionalBertClassifier::new(
+                model_path,
+                num_classes as usize,
+                use_cpu,
+            ) {
+                Ok(_classifier) => {
+                    // Store in global static (would need to add this to the lazy_static block)
+                    true
+                }
+                Err(e) => {
+                    eprintln!("Failed to initialize Candle BERT classifier: {}", e);
+                    false
+                }
+            }
         }
     }
 }

candle-binding/src/model_architectures/lora/bert_lora.rs

Lines changed: 20 additions & 2 deletions
@@ -499,9 +499,18 @@ impl HighPerformanceBertClassifier {

         // Load tokenizer
         let tokenizer_path = Path::new(model_path).join("tokenizer.json");
-        let tokenizer = Tokenizer::from_file(&tokenizer_path)
+        let mut tokenizer = Tokenizer::from_file(&tokenizer_path)
             .map_err(|e| E::msg(format!("Failed to load tokenizer: {}", e)))?;

+        // Configure truncation to max 512 tokens (BERT's position embedding limit)
+        use tokenizers::TruncationParams;
+        tokenizer
+            .with_truncation(Some(TruncationParams {
+                max_length: 512,
+                ..Default::default()
+            }))
+            .map_err(E::msg)?;
+
         // Load model weights
         let weights_path = if Path::new(model_path).join("model.safetensors").exists() {
             Path::new(model_path).join("model.safetensors")

@@ -690,9 +699,18 @@ impl HighPerformanceBertTokenClassifier {

         // Load tokenizer
         let tokenizer_path = Path::new(model_path).join("tokenizer.json");
-        let tokenizer = Tokenizer::from_file(&tokenizer_path)
+        let mut tokenizer = Tokenizer::from_file(&tokenizer_path)
             .map_err(|e| E::msg(format!("Failed to load tokenizer: {}", e)))?;

+        // Configure truncation to max 512 tokens (BERT's position embedding limit)
+        use tokenizers::TruncationParams;
+        tokenizer
+            .with_truncation(Some(TruncationParams {
+                max_length: 512,
+                ..Default::default()
+            }))
+            .map_err(E::msg)?;
+
         // Load model weights
         let weights_path = if Path::new(model_path).join("model.safetensors").exists() {
             Path::new(model_path).join("model.safetensors")

config/config.yaml

Lines changed: 5 additions & 6 deletions
@@ -58,15 +58,14 @@ model_config:
   # Classifier configuration
   classifier:
     category_model:
-      model_id: "models/category_classifier_modernbert-base_model"
-      use_modernbert: true
+      model_id: "models/lora_intent_classifier_bert-base-uncased_model"
       threshold: 0.6
       use_cpu: true
-      category_mapping_path: "models/category_classifier_modernbert-base_model/category_mapping.json"
+      category_mapping_path: "models/lora_intent_classifier_bert-base-uncased_model/category_mapping.json"
     pii_model:
-      model_id: "models/pii_classifier_modernbert-base_presidio_token_model"
-      use_modernbert: true
-      threshold: 0.7
+      model_id: "models/lora_pii_detector_bert-base-uncased_model"
+      use_modernbert: false
+      threshold: 0.9
       use_cpu: true
       pii_mapping_path: "models/pii_classifier_modernbert-base_presidio_token_model/pii_type_mapping.json"

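The threshold bump from 0.7 to 0.9 in this config is a straightforward confidence cutoff: only detections at or above the threshold survive, so a higher value suppresses the LoRA PII model's low-confidence false positives. A small illustration with made-up (label, confidence) pairs:

```rust
// Keep only detections whose confidence meets the threshold.
fn filter_detections<'a>(dets: &[(&'a str, f32)], threshold: f32) -> Vec<&'a str> {
    dets.iter()
        .filter(|(_, conf)| *conf >= threshold)
        .map(|(label, _)| *label)
        .collect()
}

fn main() {
    // Hypothetical detections from an over-eager PII model.
    let dets = [("EMAIL", 0.95_f32), ("PERSON", 0.75), ("PHONE", 0.71)];
    // At 0.7 all three pass; at 0.9 only the confident one survives.
    assert_eq!(filter_detections(&dets, 0.7).len(), 3);
    assert_eq!(filter_detections(&dets, 0.9), vec!["EMAIL"]);
}
```

The trade-off is recall: a stricter threshold can also drop genuine low-confidence detections, which is why the commit message justifies 0.9 with measured end-to-end accuracy rather than in principle.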

deploy/helm/semantic-router/values.yaml

Lines changed: 5 additions & 3 deletions
@@ -165,6 +165,8 @@ initContainer:
       repo: Qwen/Qwen3-Embedding-0.6B
     - name: all-MiniLM-L12-v2
       repo: sentence-transformers/all-MiniLM-L12-v2
+    - name: lora_intent_classifier_bert-base-uncased_model
+      repo: LLM-Semantic-Router/lora_intent_classifier_bert-base-uncased_model
     - name: category_classifier_modernbert-base_model
       repo: LLM-Semantic-Router/category_classifier_modernbert-base_model
     - name: pii_classifier_modernbert-base_model

@@ -278,11 +280,11 @@ config:
     # Classifier configuration
     classifier:
       category_model:
-        model_id: "models/category_classifier_modernbert-base_model"
-        use_modernbert: true
+        model_id: "models/lora_intent_classifier_bert-base-uncased_model"
+        use_modernbert: false # Use LoRA intent classifier with auto-detection
         threshold: 0.6
         use_cpu: true
-        category_mapping_path: "models/category_classifier_modernbert-base_model/category_mapping.json"
+        category_mapping_path: "models/lora_intent_classifier_bert-base-uncased_model/category_mapping.json"
       pii_model:
         model_id: "models/pii_classifier_modernbert-base_presidio_token_model"
         use_modernbert: true

deploy/kubernetes/aibrix/semantic-router-values/values.yaml

Lines changed: 3 additions & 3 deletions
@@ -433,11 +433,11 @@ config:
     # Classifier configuration
     classifier:
       category_model:
-        model_id: "models/category_classifier_modernbert-base_model"
-        use_modernbert: true
+        model_id: "models/lora_intent_classifier_bert-base-uncased_model"
+        use_modernbert: false # Use LoRA intent classifier with auto-detection
         threshold: 0.6
         use_cpu: true
-        category_mapping_path: "models/category_classifier_modernbert-base_model/category_mapping.json"
+        category_mapping_path: "models/lora_intent_classifier_bert-base-uncased_model/category_mapping.json"
       pii_model:
         # Support both traditional (modernbert) and LoRA-based PII detection
         # When model_type is "auto", the system will auto-detect LoRA configuration

e2e/profiles/ai-gateway/values.yaml

Lines changed: 4 additions & 4 deletions
@@ -509,17 +509,17 @@ config:
     # Classifier configuration
     classifier:
       category_model:
-        model_id: "models/category_classifier_modernbert-base_model"
-        use_modernbert: true
+        model_id: "models/lora_intent_classifier_bert-base-uncased_model"
+        use_modernbert: false # Use LoRA intent classifier with auto-detection
         threshold: 0.6
         use_cpu: true
-        category_mapping_path: "models/category_classifier_modernbert-base_model/category_mapping.json"
+        category_mapping_path: "models/lora_intent_classifier_bert-base-uncased_model/category_mapping.json"
       pii_model:
         # Support both traditional (modernbert) and LoRA-based PII detection
         # When model_type is "auto", the system will auto-detect LoRA configuration
         model_id: "models/lora_pii_detector_bert-base-uncased_model"
         use_modernbert: false # Use LoRA PII model with auto-detection
-        threshold: 0.7
+        threshold: 0.9
         use_cpu: true
         pii_mapping_path: "models/pii_classifier_modernbert-base_presidio_token_model/pii_type_mapping.json"
