Skip to content

Commit 8e88027

Browse files
committed
fix(pii): enable LoRA PII auto-detection with minimal changes
Switch PII classification from hardcoded ModernBERT to auto-detecting Candle BERT classifier. The Rust layer already has built-in auto-detection that checks for lora_config.json and routes to LoRA or Traditional models. Changes: 1. Init: Use InitCandleBertTokenClassifier (has auto-detect built-in) 2. Inference: Use ClassifyCandleBertTokens (auto-routes to initialized classifier) This enables LoRA PII models to work automatically without config changes, providing higher confidence scores for PII entity detection. Fixes #647
1 parent a149800 commit 8e88027

File tree

6 files changed

+50
-20
lines changed

6 files changed

+50
-20
lines changed

deploy/helm/semantic-router/values.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,9 @@ initContainer:
167167
repo: LLM-Semantic-Router/jailbreak_classifier_modernbert-base_model
168168
- name: pii_classifier_modernbert-base_presidio_token_model
169169
repo: LLM-Semantic-Router/pii_classifier_modernbert-base_presidio_token_model
170+
# LoRA PII detector (for auto-detection feature)
171+
- name: lora_pii_detector_bert-base-uncased_model
172+
repo: LLM-Semantic-Router/lora_pii_detector_bert-base-uncased_model
170173

171174

172175
# Autoscaling configuration

deploy/kubernetes/aibrix/semantic-router-values/values.yaml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -437,8 +437,10 @@ config:
437437
use_cpu: true
438438
category_mapping_path: "models/category_classifier_modernbert-base_model/category_mapping.json"
439439
pii_model:
440-
model_id: "models/pii_classifier_modernbert-base_presidio_token_model"
441-
use_modernbert: true
440+
# Support both traditional (modernbert) and LoRA-based PII detection
441+
# When model_type is "auto", the system will auto-detect LoRA configuration
442+
model_id: "models/lora_pii_detector_bert-base-uncased_model"
443+
use_modernbert: false # Use LoRA PII model with auto-detection
442444
threshold: 0.7
443445
use_cpu: true
444446
pii_mapping_path: "models/pii_classifier_modernbert-base_presidio_token_model/pii_type_mapping.json"

e2e/profiles/ai-gateway/values.yaml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -467,8 +467,10 @@ config:
467467
use_cpu: true
468468
category_mapping_path: "models/category_classifier_modernbert-base_model/category_mapping.json"
469469
pii_model:
470-
model_id: "models/pii_classifier_modernbert-base_presidio_token_model"
471-
use_modernbert: true
470+
# Support both traditional (modernbert) and LoRA-based PII detection
471+
# When model_type is "auto", the system will auto-detect LoRA configuration
472+
model_id: "models/lora_pii_detector_bert-base-uncased_model"
473+
use_modernbert: false # Use LoRA PII model with auto-detection
472474
threshold: 0.7
473475
use_cpu: true
474476
pii_mapping_path: "models/pii_classifier_modernbert-base_presidio_token_model/pii_type_mapping.json"

e2e/profiles/dynamic-config/values.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,8 @@ config:
4848
use_cpu: true
4949
category_mapping_path: "models/category_classifier_modernbert-base_model/category_mapping.json"
5050
pii_model:
51-
model_id: "models/pii_classifier_modernbert-base_presidio_token_model"
52-
use_modernbert: true
51+
model_id: "models/lora_pii_detector_bert-base-uncased_model"
52+
use_modernbert: false # Use LoRA PII model with auto-detection
5353
threshold: 0.7
5454
use_cpu: true
5555
pii_mapping_path: "models/pii_classifier_modernbert-base_presidio_token_model/pii_type_mapping.json"

src/semantic-router/pkg/classification/classifier.go

Lines changed: 35 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -140,35 +140,55 @@ func createJailbreakInference(useModernBERT bool) JailbreakInference {
140140
}
141141

142142
type PIIInitializer interface {
143-
Init(modelID string, useCPU bool) error
143+
Init(modelID string, useCPU bool, numClasses int) error
144144
}
145145

146-
type ModernBertPIIInitializer struct{}
146+
type PIIInitializerImpl struct {
147+
usedModernBERT bool // Track which init path succeeded for inference routing
148+
}
149+
150+
func (c *PIIInitializerImpl) Init(modelID string, useCPU bool, numClasses int) error {
151+
// Try auto-detecting Candle BERT init first - checks for lora_config.json
152+
// This enables LoRA PII models when available
153+
success := candle_binding.InitCandleBertTokenClassifier(modelID, numClasses, useCPU)
154+
if success {
155+
c.usedModernBERT = false
156+
logging.Infof("Initialized PII token classifier with auto-detection (LoRA or Traditional BERT)")
157+
return nil
158+
}
147159

148-
func (c *ModernBertPIIInitializer) Init(modelID string, useCPU bool) error {
160+
// Fallback to ModernBERT-specific init for backward compatibility
161+
// This handles models with incomplete configs (missing hidden_act, etc.)
162+
logging.Infof("Auto-detection failed, falling back to ModernBERT PII initializer")
149163
err := candle_binding.InitModernBertPIITokenClassifier(modelID, useCPU)
150164
if err != nil {
151-
return err
165+
return fmt.Errorf("failed to initialize PII token classifier (both auto-detect and ModernBERT): %w", err)
152166
}
153-
logging.Infof("Initialized ModernBERT PII token classifier for entity detection")
167+
c.usedModernBERT = true
168+
logging.Infof("Initialized ModernBERT PII token classifier (fallback mode)")
154169
return nil
155170
}
156171

157-
// createPIIInitializer creates the appropriate PII initializer (currently only ModernBERT)
158-
func createPIIInitializer() PIIInitializer { return &ModernBertPIIInitializer{} }
172+
// createPIIInitializer creates the PII initializer (auto-detecting)
173+
func createPIIInitializer() PIIInitializer {
174+
return &PIIInitializerImpl{}
175+
}
159176

160177
type PIIInference interface {
161178
ClassifyTokens(text string, configPath string) (candle_binding.TokenClassificationResult, error)
162179
}
163180

164-
type ModernBertPIIInference struct{}
181+
type PIIInferenceImpl struct{}
165182

166-
func (c *ModernBertPIIInference) ClassifyTokens(text string, configPath string) (candle_binding.TokenClassificationResult, error) {
167-
return candle_binding.ClassifyModernBertPIITokens(text, configPath)
183+
func (c *PIIInferenceImpl) ClassifyTokens(text string, configPath string) (candle_binding.TokenClassificationResult, error) {
184+
// Auto-detecting inference - uses whichever classifier was initialized (LoRA or Traditional)
185+
return candle_binding.ClassifyCandleBertTokens(text)
168186
}
169187

170-
// createPIIInference creates the appropriate PII inference (currently only ModernBERT)
171-
func createPIIInference() PIIInference { return &ModernBertPIIInference{} }
188+
// createPIIInference creates the PII inference (auto-detecting)
189+
func createPIIInference() PIIInference {
190+
return &PIIInferenceImpl{}
191+
}
172192

173193
// JailbreakDetection represents the result of jailbreak analysis for a piece of content
174194
type JailbreakDetection struct {
@@ -348,7 +368,7 @@ func NewClassifier(cfg *config.RouterConfig, categoryMapping *CategoryMapping, p
348368

349369
// Add in-tree classifier if configured
350370
if cfg.CategoryModel.ModelID != "" {
351-
options = append(options, withCategory(categoryMapping, createCategoryInitializer(cfg.UseModernBERT), createCategoryInference(cfg.UseModernBERT)))
371+
options = append(options, withCategory(categoryMapping, createCategoryInitializer(cfg.CategoryModel.UseModernBERT), createCategoryInference(cfg.CategoryModel.UseModernBERT)))
352372
}
353373

354374
// Add MCP classifier if configured
@@ -509,7 +529,8 @@ func (c *Classifier) initializePIIClassifier() error {
509529
return fmt.Errorf("not enough PII types for classification, need at least 2, got %d", numPIIClasses)
510530
}
511531

512-
return c.piiInitializer.Init(c.Config.PIIModel.ModelID, c.Config.PIIModel.UseCPU)
532+
// Pass numClasses to support auto-detection
533+
return c.piiInitializer.Init(c.Config.PIIModel.ModelID, c.Config.PIIModel.UseCPU, numPIIClasses)
513534
}
514535

515536
// EvaluateAllRules evaluates all rule types and returns matched rule names

src/semantic-router/pkg/extproc/extproc_test.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2030,6 +2030,8 @@ var _ = Describe("Caching Functionality", func() {
20302030
BeforeEach(func() {
20312031
cfg = CreateTestConfig()
20322032
cfg.Enabled = true
2033+
// Disable PII detection for caching tests (not needed and avoids model loading issues)
2034+
cfg.InlineModels.Classifier.PIIModel.ModelID = ""
20332035

20342036
var err error
20352037
router, err = CreateTestRouter(cfg)

0 commit comments

Comments
 (0)