diff --git a/.github/workflows/test-and-build.yml b/.github/workflows/test-and-build.yml index 3eec580c4..7f40eea5e 100644 --- a/.github/workflows/test-and-build.yml +++ b/.github/workflows/test-and-build.yml @@ -62,9 +62,9 @@ jobs: with: path: | models/ - key: ${{ runner.os }}-models-v1-${{ hashFiles('tools/make/models.mk') }} + key: ${{ runner.os }}-models-v2-${{ hashFiles('tools/make/models.mk') }} restore-keys: | - ${{ runner.os }}-models-v1- + ${{ runner.os }}-models-v2- continue-on-error: true # Don't fail the job if caching fails - name: Check go mod tidy diff --git a/candle-binding/src/classifiers/lora/intent_lora.rs b/candle-binding/src/classifiers/lora/intent_lora.rs index 6da64a9a4..e6f653566 100644 --- a/candle-binding/src/classifiers/lora/intent_lora.rs +++ b/candle-binding/src/classifiers/lora/intent_lora.rs @@ -113,6 +113,40 @@ impl IntentLoRAClassifier { }) } + /// Classify intent and return (class_index, confidence, intent_label) for FFI + pub fn classify_with_index(&self, text: &str) -> Result<(usize, f32, String)> { + // Use real BERT model for classification + let (predicted_class, confidence) = + self.bert_classifier.classify_text(text).map_err(|e| { + let unified_err = model_error!( + ModelErrorType::LoRA, + "intent classification", + format!("Classification failed: {}", e), + text + ); + candle_core::Error::from(unified_err) + })?; + + // Map class index to intent label - fail if class not found + let intent = if predicted_class < self.intent_labels.len() { + self.intent_labels[predicted_class].clone() + } else { + let unified_err = model_error!( + ModelErrorType::LoRA, + "intent classification", + format!( + "Invalid class index {} not found in labels (max: {})", + predicted_class, + self.intent_labels.len() + ), + text + ); + return Err(candle_core::Error::from(unified_err)); + }; + + Ok((predicted_class, confidence, intent)) + } + /// Parallel classification for multiple texts using rayon /// /// # Performance diff --git a/candle-binding/src/core/tokenization.rs b/candle-binding/src/core/tokenization.rs index 6f5f7df26..5ac2dd7d5 100644 --- a/candle-binding/src/core/tokenization.rs +++ b/candle-binding/src/core/tokenization.rs @@ -387,7 +387,19 @@ impl DualPathTokenizer for UnifiedTokenizer { let encoding = tokenizer .encode(text, self.config.add_special_tokens) .map_err(E::msg)?; - Ok(self.encoding_to_result(&encoding)) + + // Explicitly enforce max_length truncation for LoRA models + // This is a safety check to ensure we never exceed the model's position embedding size + let mut result = self.encoding_to_result(&encoding); + let max_len = self.config.max_length; + if result.token_ids.len() > max_len { + result.token_ids.truncate(max_len); + result.token_ids_u32.truncate(max_len); + result.attention_mask.truncate(max_len); + result.tokens.truncate(max_len); + } + + Ok(result) } fn tokenize_batch_smart( diff --git a/candle-binding/src/ffi/classify.rs b/candle-binding/src/ffi/classify.rs index 91e38baee..51a043328 100644 --- a/candle-binding/src/ffi/classify.rs +++ b/candle-binding/src/ffi/classify.rs @@ -21,7 +21,7 @@ use crate::BertClassifier; use std::ffi::{c_char, CStr}; use std::sync::{Arc, OnceLock}; -use crate::ffi::init::{PARALLEL_LORA_ENGINE, UNIFIED_CLASSIFIER}; +use crate::ffi::init::{LORA_INTENT_CLASSIFIER, PARALLEL_LORA_ENGINE, UNIFIED_CLASSIFIER}; // Import DeBERTa classifier for jailbreak detection use super::init::DEBERTA_JAILBREAK_CLASSIFIER; @@ -693,7 +693,32 @@ pub extern "C" fn classify_candle_bert_text(text: *const c_char) -> Classificati Err(_) => return default_result, } }; - // Use TraditionalBertClassifier for Candle BERT text classification + + // Try LoRA intent classifier first (preferred for higher accuracy) + if let Some(classifier) = LORA_INTENT_CLASSIFIER.get() { + let classifier = classifier.clone(); + match classifier.classify_with_index(text) { + Ok((class_idx, confidence, ref intent)) => { + // Allocate C string for intent label + let label_ptr = unsafe { allocate_c_string(intent) }; + + return ClassificationResult { + predicted_class: class_idx as i32, + confidence, + label: label_ptr, + }; + } + Err(e) => { + eprintln!( + "LoRA intent classifier error: {}, falling back to Traditional BERT", + e + ); + // Don't return - fall through to Traditional BERT classifier + } + } + } + + // Fallback to Traditional BERT classifier if let Some(classifier) = TRADITIONAL_BERT_CLASSIFIER.get() { let classifier = classifier.clone(); match classifier.classify_text(text) { @@ -717,7 +742,7 @@ pub extern "C" fn classify_candle_bert_text(text: *const c_char) -> Classificati } } } else { - println!("TraditionalBertClassifier not initialized - call init_bert_classifier first"); + println!("No classifier initialized - call init_candle_bert_classifier first"); ClassificationResult { predicted_class: -1, confidence: 0.0, diff --git a/candle-binding/src/ffi/init.rs b/candle-binding/src/ffi/init.rs index 5d557b42a..86a96ce21 100644 --- a/candle-binding/src/ffi/init.rs +++ b/candle-binding/src/ffi/init.rs @@ -37,6 +37,10 @@ pub static PARALLEL_LORA_ENGINE: OnceLock< pub static LORA_TOKEN_CLASSIFIER: OnceLock< Arc, > = OnceLock::new(); +// LoRA intent classifier for sequence classification +pub static LORA_INTENT_CLASSIFIER: OnceLock< + Arc, +> = OnceLock::new(); /// Model type detection for intelligent routing #[derive(Debug, Clone, PartialEq)] @@ -604,7 +608,6 @@ pub extern "C" fn init_candle_bert_classifier( num_classes: i32, use_cpu: bool, ) -> bool { - // Migrated from lib.rs:1555-1578 let model_path = unsafe { match CStr::from_ptr(model_path).to_str() { Ok(s) => s, @@ -612,20 +615,46 @@ pub extern "C" fn init_candle_bert_classifier( } }; - // Initialize TraditionalBertClassifier - match crate::model_architectures::traditional::bert::TraditionalBertClassifier::new( - model_path, - num_classes as usize, - use_cpu, - ) { - Ok(_classifier) => { - // Store in global static (would need to add this to the lazy_static block) + // Intelligent model type detection (same as token classifier) + let model_type = detect_model_type(model_path); - true + match model_type { + ModelType::LoRA => { + // Check if already initialized + if LORA_INTENT_CLASSIFIER.get().is_some() { + return true; // Already initialized, return success + } + + // Route to LoRA intent classifier initialization + match crate::classifiers::lora::intent_lora::IntentLoRAClassifier::new( + model_path, use_cpu, + ) { + Ok(classifier) => LORA_INTENT_CLASSIFIER.set(Arc::new(classifier)).is_ok(), + Err(e) => { + eprintln!( + " ERROR: Failed to initialize LoRA intent classifier: {}", + e + ); + false + } + } } - Err(e) => { - eprintln!("Failed to initialize Candle BERT classifier: {}", e); - false + ModelType::Traditional => { + // Initialize TraditionalBertClassifier + match crate::model_architectures::traditional::bert::TraditionalBertClassifier::new( + model_path, + num_classes as usize, + use_cpu, + ) { + Ok(_classifier) => { + // Store in global static (would need to add this to the lazy_static block) + true + } + Err(e) => { + eprintln!("Failed to initialize Candle BERT classifier: {}", e); + false + } + } } } } diff --git a/candle-binding/src/model_architectures/lora/bert_lora.rs b/candle-binding/src/model_architectures/lora/bert_lora.rs index dd3df187c..d5e9c3499 100644 --- a/candle-binding/src/model_architectures/lora/bert_lora.rs +++ b/candle-binding/src/model_architectures/lora/bert_lora.rs @@ -499,9 +499,18 @@ impl HighPerformanceBertClassifier { // Load tokenizer let tokenizer_path = Path::new(model_path).join("tokenizer.json"); - let tokenizer = Tokenizer::from_file(&tokenizer_path) + let mut tokenizer = Tokenizer::from_file(&tokenizer_path) .map_err(|e| E::msg(format!("Failed to load tokenizer: {}", e)))?; + // Configure truncation to max 512 tokens (BERT's position embedding limit) + use tokenizers::TruncationParams; + tokenizer + .with_truncation(Some(TruncationParams { + max_length: 512, + ..Default::default() + })) + .map_err(E::msg)?; + // Load model weights let weights_path = if Path::new(model_path).join("model.safetensors").exists() { Path::new(model_path).join("model.safetensors") @@ -690,9 +699,18 @@ impl HighPerformanceBertTokenClassifier { // Load tokenizer let tokenizer_path = Path::new(model_path).join("tokenizer.json"); - let tokenizer = Tokenizer::from_file(&tokenizer_path) + let mut tokenizer = Tokenizer::from_file(&tokenizer_path) .map_err(|e| E::msg(format!("Failed to load tokenizer: {}", e)))?; + // Configure truncation to max 512 tokens (BERT's position embedding limit) + use tokenizers::TruncationParams; + tokenizer + .with_truncation(Some(TruncationParams { + max_length: 512, + ..Default::default() + })) + .map_err(E::msg)?; + // Load model weights let weights_path = if Path::new(model_path).join("model.safetensors").exists() { Path::new(model_path).join("model.safetensors") diff --git a/config/config.yaml b/config/config.yaml index 3454e0d13..1acb405b7 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -58,15 +58,14 @@ model_config: # Classifier configuration classifier: category_model: - model_id: "models/category_classifier_modernbert-base_model" - use_modernbert: true + model_id: "models/lora_intent_classifier_bert-base-uncased_model" threshold: 0.6 use_cpu: true - category_mapping_path: "models/category_classifier_modernbert-base_model/category_mapping.json" + category_mapping_path: "models/lora_intent_classifier_bert-base-uncased_model/category_mapping.json" pii_model: - model_id: "models/pii_classifier_modernbert-base_presidio_token_model" - use_modernbert: true - threshold: 0.7 + model_id: "models/lora_pii_detector_bert-base-uncased_model" + use_modernbert: false + threshold: 0.9 use_cpu: true pii_mapping_path: "models/pii_classifier_modernbert-base_presidio_token_model/pii_type_mapping.json" diff --git a/deploy/helm/semantic-router/values.yaml b/deploy/helm/semantic-router/values.yaml index 1ca81118b..b5062227f 100644 --- a/deploy/helm/semantic-router/values.yaml +++ b/deploy/helm/semantic-router/values.yaml @@ -159,6 +159,8 @@ initContainer: repo: Qwen/Qwen3-Embedding-0.6B - name: all-MiniLM-L12-v2 repo: sentence-transformers/all-MiniLM-L12-v2 + - name: lora_intent_classifier_bert-base-uncased_model + repo: LLM-Semantic-Router/lora_intent_classifier_bert-base-uncased_model - name: category_classifier_modernbert-base_model repo: LLM-Semantic-Router/category_classifier_modernbert-base_model - name: pii_classifier_modernbert-base_model @@ -272,11 +274,11 @@ config: # Classifier configuration classifier: category_model: - model_id: "models/category_classifier_modernbert-base_model" - use_modernbert: true + model_id: "models/lora_intent_classifier_bert-base-uncased_model" + use_modernbert: false # Use LoRA intent classifier with auto-detection threshold: 0.6 use_cpu: true - category_mapping_path: "models/category_classifier_modernbert-base_model/category_mapping.json" + category_mapping_path: "models/lora_intent_classifier_bert-base-uncased_model/category_mapping.json" pii_model: model_id: "models/pii_classifier_modernbert-base_presidio_token_model" use_modernbert: true diff --git a/deploy/kubernetes/aibrix/semantic-router-values/values.yaml b/deploy/kubernetes/aibrix/semantic-router-values/values.yaml index ec1d43537..504ca9f87 100644 --- a/deploy/kubernetes/aibrix/semantic-router-values/values.yaml +++ b/deploy/kubernetes/aibrix/semantic-router-values/values.yaml @@ -431,11 +431,11 @@ config: # Classifier configuration classifier: category_model: - model_id: "models/category_classifier_modernbert-base_model" - use_modernbert: true + model_id: "models/lora_intent_classifier_bert-base-uncased_model" + use_modernbert: false # Use LoRA intent classifier with auto-detection threshold: 0.6 use_cpu: true - category_mapping_path: "models/category_classifier_modernbert-base_model/category_mapping.json" + category_mapping_path: "models/lora_intent_classifier_bert-base-uncased_model/category_mapping.json" pii_model: # Support both traditional (modernbert) and LoRA-based PII detection # When model_type is "auto", the system will auto-detect LoRA configuration diff --git a/e2e/profiles/ai-gateway/values.yaml b/e2e/profiles/ai-gateway/values.yaml index ed67b6eaf..62dcd9a2c 100644 --- a/e2e/profiles/ai-gateway/values.yaml +++ b/e2e/profiles/ai-gateway/values.yaml @@ -461,17 +461,17 @@ config: # Classifier configuration classifier: category_model: - model_id: "models/category_classifier_modernbert-base_model" - use_modernbert: true + model_id: "models/lora_intent_classifier_bert-base-uncased_model" + use_modernbert: false # Use LoRA intent classifier with auto-detection threshold: 0.6 use_cpu: true - category_mapping_path: "models/category_classifier_modernbert-base_model/category_mapping.json" + category_mapping_path: "models/lora_intent_classifier_bert-base-uncased_model/category_mapping.json" pii_model: # Support both traditional (modernbert) and LoRA-based PII detection # When model_type is "auto", the system will auto-detect LoRA configuration model_id: "models/lora_pii_detector_bert-base-uncased_model" use_modernbert: false # Use LoRA PII model with auto-detection - threshold: 0.7 + threshold: 0.9 use_cpu: true pii_mapping_path: "models/pii_classifier_modernbert-base_presidio_token_model/pii_type_mapping.json" diff --git a/e2e/profiles/dynamic-config/values.yaml b/e2e/profiles/dynamic-config/values.yaml index b14a2b92c..7dc988dfc 100644 --- a/e2e/profiles/dynamic-config/values.yaml +++ b/e2e/profiles/dynamic-config/values.yaml @@ -42,15 +42,15 @@ config: classifier: category_model: - model_id: "models/category_classifier_modernbert-base_model" - use_modernbert: true + model_id: "models/lora_intent_classifier_bert-base-uncased_model" + use_modernbert: false # Use LoRA intent classifier with auto-detection threshold: 0.6 use_cpu: true - category_mapping_path: "models/category_classifier_modernbert-base_model/category_mapping.json" + category_mapping_path: "models/lora_intent_classifier_bert-base-uncased_model/category_mapping.json" pii_model: model_id: "models/lora_pii_detector_bert-base-uncased_model" use_modernbert: false # Use LoRA PII model with auto-detection - threshold: 0.7 + threshold: 0.9 use_cpu: true pii_mapping_path: "models/pii_classifier_modernbert-base_presidio_token_model/pii_type_mapping.json" diff --git a/src/semantic-router/pkg/classification/classifier.go b/src/semantic-router/pkg/classification/classifier.go index cf5d934b3..7ff71cbdb 100644 --- a/src/semantic-router/pkg/classification/classifier.go +++ b/src/semantic-router/pkg/classification/classifier.go @@ -18,34 +18,35 @@ type CategoryInitializer interface { Init(modelID string, useCPU bool, numClasses ...int) error } -type LinearCategoryInitializer struct{} - -func (c *LinearCategoryInitializer) Init(modelID string, useCPU bool, numClasses ...int) error { - err := candle_binding.InitClassifier(modelID, numClasses[0], useCPU) - if err != nil { - return err - } - logging.Infof("Initialized linear category classifier with %d classes", numClasses[0]) - return nil +type CategoryInitializerImpl struct { + usedModernBERT bool // Track which init path succeeded for inference routing } -type ModernBertCategoryInitializer struct{} +func (c *CategoryInitializerImpl) Init(modelID string, useCPU bool, numClasses ...int) error { + // Try auto-detecting Candle BERT init first - checks for lora_config.json + // This enables LoRA Intent/Category models when available + success := candle_binding.InitCandleBertClassifier(modelID, numClasses[0], useCPU) + if success { + c.usedModernBERT = false + logging.Infof("Initialized category classifier with auto-detection (LoRA or Traditional BERT)") + return nil + } -func (c *ModernBertCategoryInitializer) Init(modelID string, useCPU bool, numClasses ...int) error { + // Fallback to ModernBERT-specific init for backward compatibility + // This handles models with incomplete configs (missing hidden_act, etc.) + logging.Infof("Auto-detection failed, falling back to ModernBERT category initializer") err := candle_binding.InitModernBertClassifier(modelID, useCPU) if err != nil { - return err + return fmt.Errorf("failed to initialize category classifier (both auto-detect and ModernBERT): %w", err) } - logging.Infof("Initialized ModernBERT category classifier (classes auto-detected from model)") + c.usedModernBERT = true + logging.Infof("Initialized ModernBERT category classifier (fallback mode)") return nil } -// createCategoryInitializer creates the appropriate category initializer based on configuration -func createCategoryInitializer(useModernBERT bool) CategoryInitializer { - if useModernBERT { - return &ModernBertCategoryInitializer{} - } - return &LinearCategoryInitializer{} +// createCategoryInitializer creates the category initializer (auto-detecting) +func createCategoryInitializer() CategoryInitializer { + return &CategoryInitializerImpl{} } type CategoryInference interface { @@ -53,32 +54,22 @@ type CategoryInference interface { ClassifyWithProbabilities(text string) (candle_binding.ClassResultWithProbs, error) } -type LinearCategoryInference struct{} - -func (c *LinearCategoryInference) Classify(text string) (candle_binding.ClassResult, error) { - return candle_binding.ClassifyText(text) -} - -func (c *LinearCategoryInference) ClassifyWithProbabilities(text string) (candle_binding.ClassResultWithProbs, error) { - return candle_binding.ClassifyTextWithProbabilities(text) -} - -type ModernBertCategoryInference struct{} +type CategoryInferenceImpl struct{} -func (c *ModernBertCategoryInference) Classify(text string) (candle_binding.ClassResult, error) { - return candle_binding.ClassifyModernBertText(text) +func (c *CategoryInferenceImpl) Classify(text string) (candle_binding.ClassResult, error) { + // Auto-detecting inference - uses whichever classifier was initialized (LoRA or Traditional) + return candle_binding.ClassifyCandleBertText(text) } -func (c *ModernBertCategoryInference) ClassifyWithProbabilities(text string) (candle_binding.ClassResultWithProbs, error) { +func (c *CategoryInferenceImpl) ClassifyWithProbabilities(text string) (candle_binding.ClassResultWithProbs, error) { + // Note: CandleBert doesn't have WithProbabilities yet, fall back to ModernBERT + // This will work correctly if ModernBERT was initialized as fallback return candle_binding.ClassifyModernBertTextWithProbabilities(text) } -// createCategoryInference creates the appropriate category inference based on configuration -func createCategoryInference(useModernBERT bool) CategoryInference { - if useModernBERT { - return &ModernBertCategoryInference{} - } - return &LinearCategoryInference{} +// createCategoryInference creates the category inference (auto-detecting) +func createCategoryInference() CategoryInference { + return &CategoryInferenceImpl{} } type JailbreakInitializer interface { @@ -368,7 +359,7 @@ func NewClassifier(cfg *config.RouterConfig, categoryMapping *CategoryMapping, p // Add in-tree classifier if configured if cfg.CategoryModel.ModelID != "" { - options = append(options, withCategory(categoryMapping, createCategoryInitializer(cfg.CategoryModel.UseModernBERT), createCategoryInference(cfg.CategoryModel.UseModernBERT))) + options = append(options, withCategory(categoryMapping, createCategoryInitializer(), createCategoryInference())) } // Add MCP classifier if configured @@ -386,7 +377,7 @@ func NewClassifier(cfg *config.RouterConfig, categoryMapping *CategoryMapping, p // IsCategoryEnabled checks if category classification is properly configured func (c *Classifier) IsCategoryEnabled() bool { - return c.Config.CategoryModel.ModelID != "" && c.Config.CategoryMappingPath != "" && c.CategoryMapping != nil + return c.Config.CategoryModel.ModelID != "" && c.Config.CategoryModel.CategoryMappingPath != "" && c.CategoryMapping != nil } // initializeCategoryClassifier initializes the category classification model diff --git a/src/semantic-router/pkg/classification/lora_auto_detection_test.go b/src/semantic-router/pkg/classification/lora_auto_detection_test.go new file mode 100644 index 000000000..9d9aad1d2 --- /dev/null +++ b/src/semantic-router/pkg/classification/lora_auto_detection_test.go @@ -0,0 +1,101 @@ +package classification + +import ( + "os" + "testing" + + candle_binding "github.com/vllm-project/semantic-router/candle-binding" + "github.com/vllm-project/semantic-router/src/semantic-router/pkg/config" +) + +// TestIntentClassificationLoRAAutoDetection demonstrates that current implementation +// doesn't auto-detect LoRA models for intent classification (unlike PII detection) +func TestIntentClassificationLoRAAutoDetection(t *testing.T) { + modelPath := "../../../../models/lora_intent_classifier_bert-base-uncased_model" + numClasses := 14 // From category_mapping.json + + // Check if LoRA model exists + if _, err := os.Stat(modelPath + "/lora_config.json"); os.IsNotExist(err) { + t.Skip("LoRA intent model not available, skipping test") + } + + t.Run("AutoDetection: CategoryInitializer Now Detects LoRA Models", func(t *testing.T) { + // After fix: CategoryInitializerImpl auto-detects LoRA models + // It tries InitCandleBertClassifier() first (checks for lora_config.json) + // Falls back to InitModernBertClassifier() if needed + + cfg := &config.CategoryModel{ + ModelID: modelPath, + UseCPU: true, + } + + // Create auto-detecting initializer + initializer := createCategoryInitializer() + + // Try to initialize - should SUCCESS with LoRA auto-detection + err := initializer.Init(cfg.ModelID, cfg.UseCPU, numClasses) + if err != nil { + t.Errorf("Auto-detection failed: %v", err) + return + } + + t.Log("✓ CategoryInitializer successfully auto-detected and initialized LoRA model") + + // Verify inference works + inference := createCategoryInference() + result, err := inference.Classify("What is the best business strategy?") + if err != nil { + t.Errorf("Classification failed: %v", err) + return + } + + if result.Class < 0 || result.Class >= numClasses { + t.Errorf("Invalid category: %d (expected 0-%d)", result.Class, numClasses-1) + return + } + + t.Logf("✓ Classification works: category=%d, confidence=%.3f", result.Class, result.Confidence) + }) + + t.Run("Proof: Auto-Detection Already Works in Rust Layer", func(t *testing.T) { + // This proves the Rust auto-detection ALREADY EXISTS and WORKS + // InitCandleBertClassifier has auto-detection built-in (checks for lora_config.json) + + success := candle_binding.InitCandleBertClassifier(modelPath, numClasses, true) + + if !success { + t.Error("InitCandleBertClassifier should auto-detect LoRA (it exists in Rust)") + return + } + + t.Log("✓ Proof: Rust layer successfully auto-detected LoRA model") + + // Try classification to prove it works + result, err := candle_binding.ClassifyCandleBertText("What is the best business strategy?") + if err != nil { + t.Errorf("Classification failed: %v", err) + return + } + + if result.Class < 0 || result.Class >= numClasses { + t.Errorf("Invalid category: %d (expected 0-%d)", result.Class, numClasses-1) + return + } + + t.Logf("✓ Classification works: category=%d, confidence=%.3f", result.Class, result.Confidence) + t.Logf(" Solution: Update CategoryInitializer to use InitCandleBertClassifier") + }) +} + +// TestPIIAlreadyHasAutoDetection shows PII detection already works with LoRA auto-detection +func TestPIIAlreadyHasAutoDetection(t *testing.T) { + modelPath := "models/lora_pii_detector_bert-base-uncased_model" + + // Check if LoRA model exists + if _, err := os.Stat(modelPath + "/lora_config.json"); os.IsNotExist(err) { + t.Skip("LoRA PII model not available, skipping test") + } + + t.Log("✓ PII detection already has auto-detection (implemented in PR #709)") + t.Log(" Goal: Make Intent & Jailbreak detection work the same way") +}