diff --git a/config/testing/config.e2e.yaml b/config/testing/config.e2e.yaml
index 2fb33d98b..17f06807b 100644
--- a/config/testing/config.e2e.yaml
+++ b/config/testing/config.e2e.yaml
@@ -69,11 +69,11 @@ classifier:
     use_cpu: true
     category_mapping_path: "models/lora_intent_classifier_bert-base-uncased_model/category_mapping.json"
   pii_model:
-    model_id: "models/pii_classifier_modernbert-base_presidio_token_model" # TODO: Use local model for now before the code can download the entire model from huggingface
-    use_modernbert: true
+    model_id: "models/lora_pii_detector_bert-base-uncased_model"
+    use_modernbert: false # BERT-based LoRA model (this field is ignored - always auto-detects)
     threshold: 0.7
     use_cpu: true
-    pii_mapping_path: "models/pii_classifier_modernbert-base_presidio_token_model/pii_type_mapping.json"
+    pii_mapping_path: "models/lora_pii_detector_bert-base-uncased_model/pii_type_mapping.json"
 categories:
   - name: business
     description: "Business and management related queries"
@@ -359,6 +359,24 @@ decisions:
           enabled: true
           pii_types_allowed: ["EMAIL_ADDRESS", "PERSON", "GPE", "PHONE_NUMBER", "US_SSN", "CREDIT_CARD"]
 
+  # Default catch-all decision for unmatched requests (E2E PII test fix)
+  # This ensures PII detection is always enabled, even when no specific decision matches
+  - name: "default_decision"
+    description: "Default catch-all decision - blocks all PII for safety"
+    priority: 1 # Lowest priority - only matches if nothing else does
+    rules:
+      operator: "OR"
+      conditions:
+        - type: "always" # Always matches as fallback
+    modelRefs:
+      - model: "Model-B"
+        use_reasoning: false
+    plugins:
+      - type: "pii"
+        configuration:
+          enabled: true
+          pii_types_allowed: [] # Block ALL PII - empty list means nothing allowed
+
 default_model: "Model-A"
 
 # API Configuration
diff --git a/src/semantic-router/pkg/classification/classifier.go b/src/semantic-router/pkg/classification/classifier.go
index 8cfb8ee41..cf5d934b3 100644
--- a/src/semantic-router/pkg/classification/classifier.go
+++ b/src/semantic-router/pkg/classification/classifier.go
@@ -884,6 +884,77 @@ func (c *Classifier) ClassifyPIIWithThreshold(text string, threshold float32) ([
 	return result, nil
 }
 
+// ClassifyPIIWithDetails performs PII token classification on text using the
+// configured PII model threshold and returns every entity that passes it,
+// including its type, start and end positions, matched text and confidence.
+func (c *Classifier) ClassifyPIIWithDetails(text string) ([]PIIDetection, error) {
+	return c.ClassifyPIIWithDetailsAndThreshold(text, c.Config.PIIModel.Threshold)
+}
+
+// ClassifyPIIWithDetailsAndThreshold performs PII token classification on text
+// and returns the entities whose confidence is at least threshold.
+//
+// It returns an error when PII detection is not configured or when the token
+// classifier fails. Empty text, or text with no entity at or above threshold,
+// yields an empty, non-nil slice and a nil error.
+func (c *Classifier) ClassifyPIIWithDetailsAndThreshold(text string, threshold float32) ([]PIIDetection, error) {
+	if !c.IsPIIEnabled() {
+		return nil, fmt.Errorf("PII detection is not properly configured")
+	}
+
+	if text == "" {
+		return []PIIDetection{}, nil
+	}
+
+	// Use PII token classifier for entity detection
+	configPath := fmt.Sprintf("%s/config.json", c.Config.PIIModel.ModelID)
+	start := time.Now()
+	tokenResult, err := c.piiInference.ClassifyTokens(text, configPath)
+	metrics.RecordClassifierLatency("pii", time.Since(start).Seconds())
+	if err != nil {
+		return nil, fmt.Errorf("PII token classification error: %w", err)
+	}
+
+	if len(tokenResult.Entities) > 0 {
+		logging.Infof("PII token classification found %d entities", len(tokenResult.Entities))
+	}
+
+	// Convert token entities to PII detections, filtering by threshold.
+	// Start non-nil so the "nothing detected" result matches the empty-text path.
+	detections := make([]PIIDetection, 0, len(tokenResult.Entities))
+	for _, entity := range tokenResult.Entities {
+		if entity.Confidence >= threshold {
+			detection := PIIDetection{
+				EntityType: entity.EntityType,
+				Start:      entity.Start,
+				End:        entity.End,
+				Text:       entity.Text,
+				Confidence: entity.Confidence,
+			}
+			detections = append(detections, detection)
+			// Log type, span and score only: the matched text is the PII itself
+			// and must not be written to the logs.
+			logging.Infof("Detected PII entity: %s at [%d-%d] with confidence %.3f",
+				entity.EntityType, entity.Start, entity.End, entity.Confidence)
+		}
+	}
+
+	if len(detections) > 0 {
+		// Log unique PII types for compatibility with existing logs
+		uniqueTypes := make(map[string]bool)
+		for _, d := range detections {
+			uniqueTypes[d.EntityType] = true
+		}
+		types := make([]string, 0, len(uniqueTypes))
+		for t := range uniqueTypes {
+			types = append(types, t)
+		}
+		logging.Infof("Detected PII types: %v", types)
+	}
+
+	return detections, nil
+}
+
 // DetectPIIInContent performs PII classification on all provided content
 func (c *Classifier) DetectPIIInContent(allContent []string) []string {
 	var detectedPII []string
diff --git a/src/semantic-router/pkg/services/classification.go b/src/semantic-router/pkg/services/classification.go
index f83ed9e59..ec76305f1 100644
--- a/src/semantic-router/pkg/services/classification.go
+++ b/src/semantic-router/pkg/services/classification.go
@@ -290,8 +290,8 @@ func (s *ClassificationService) DetectPII(req PIIRequest) (*PIIResponse, error)
 		}, nil
 	}
 
-	// Perform PII detection using the existing classifier
-	piiTypes, err := s.classifier.ClassifyPII(req.Text)
+	// Perform PII detection using the classifier with full details
+	detections, err := s.classifier.ClassifyPIIWithDetails(req.Text)
 	if err != nil {
 		return nil, fmt.Errorf("PII detection failed: %w", err)
 	}
@@ -300,17 +300,19 @@ func (s *ClassificationService) DetectPII(req PIIRequest) (*PIIResponse, error)
 
 	// Build response
 	response := &PIIResponse{
-		HasPII:           len(piiTypes) > 0,
+		HasPII:           len(detections) > 0,
 		Entities:         []PIIEntity{},
 		ProcessingTimeMs: processingTime,
 	}
 
-	// Convert PII types to entities (simplified for now)
-	for _, piiType := range piiTypes {
+	// Convert PII detections to API entities with actual confidence scores
+	for _, detection := range detections {
 		entity := PIIEntity{
-			Type:       piiType,
-			Value:      "[DETECTED]", // Placeholder - would need actual entity extraction
-			Confidence: 0.9,          // Placeholder - would need actual confidence
+			Type:       detection.EntityType,
+			Value:      "[DETECTED]",                  // Redacted for security
+			Confidence: float64(detection.Confidence), // Actual confidence from model
+			StartPos:   detection.Start,
+			EndPos:     detection.End,
 		}
 		response.Entities = append(response.Entities, entity)
 	}