diff --git a/src/semantic-router/pkg/services/classification.go b/src/semantic-router/pkg/services/classification.go index 1240e1e5..fbf99ee4 100644 --- a/src/semantic-router/pkg/services/classification.go +++ b/src/semantic-router/pkg/services/classification.go @@ -71,7 +71,7 @@ func GetGlobalClassificationService() *ClassificationService { // HasClassifier returns true if the service has a real classifier (not placeholder) func (s *ClassificationService) HasClassifier() bool { - return s.classifier != nil + return s.unifiedClassifier != nil || s.classifier != nil } // NewPlaceholderClassificationService creates a placeholder service for API-only mode @@ -118,7 +118,12 @@ func (s *ClassificationService) ClassifyIntent(req IntentRequest) (*IntentRespon return nil, fmt.Errorf("text cannot be empty") } - // Check if classifier is available + // Prioritize unified classifier if available + if s.unifiedClassifier != nil { + return s.ClassifyIntentUnified(req) + } + + // Check if legacy classifier is available if s.classifier == nil { // Return placeholder response processingTime := time.Since(start).Milliseconds() @@ -210,7 +215,12 @@ func (s *ClassificationService) DetectPII(req PIIRequest) (*PIIResponse, error) return nil, fmt.Errorf("text cannot be empty") } - // Check if classifier is available + // Prioritize unified classifier if available + if s.unifiedClassifier != nil { + return s.DetectPIIUnified(req) + } + + // Check if legacy classifier is available if s.classifier == nil { // Return placeholder response processingTime := time.Since(start).Milliseconds() @@ -290,7 +300,12 @@ func (s *ClassificationService) CheckSecurity(req SecurityRequest) (*SecurityRes return nil, fmt.Errorf("text cannot be empty") } - // Check if classifier is available + // Prioritize unified classifier if available + if s.unifiedClassifier != nil { + return s.CheckSecurityUnified(req) + } + + // Check if legacy classifier is available if s.classifier == nil { // Return placeholder response processingTime := time.Since(start).Milliseconds() @@ -454,6 +469,59 @@ func (s *ClassificationService) ClassifyPIIUnified(texts []string) ([]classifica return results.PIIResults, nil } +// DetectPIIUnified performs PII detection using unified classifier and returns PIIResponse format +func (s *ClassificationService) DetectPIIUnified(req PIIRequest) (*PIIResponse, error) { + start := time.Now() + + if req.Text == "" { + return nil, fmt.Errorf("text cannot be empty") + } + + // Use unified classifier for PII detection + piiResults, err := s.ClassifyPIIUnified([]string{req.Text}) + if err != nil { + return nil, fmt.Errorf("PII detection failed: %w", err) + } + + processingTime := time.Since(start).Milliseconds() + + // Convert PIIResult to PIIResponse format + if len(piiResults) == 0 { + return &PIIResponse{ + HasPII: false, + Entities: []PIIEntity{}, + SecurityRecommendation: "allow", + ProcessingTimeMs: processingTime, + }, nil + } + + piiResult := piiResults[0] + response := &PIIResponse{ + HasPII: piiResult.HasPII, + Entities: []PIIEntity{}, + ProcessingTimeMs: processingTime, + } + + // Convert PII types to entities + for _, piiType := range piiResult.PIITypes { + entity := PIIEntity{ + Type: piiType, + Value: "[DETECTED]", // Placeholder - unified classifier doesn't provide exact positions yet + Confidence: float64(piiResult.Confidence), + } + response.Entities = append(response.Entities, entity) + } + + // Set security recommendation + if response.HasPII { + response.SecurityRecommendation = "block" + } else { + response.SecurityRecommendation = "allow" + } + + return response, nil +} + // ClassifySecurityUnified performs security detection using unified classifier func (s *ClassificationService) ClassifySecurityUnified(texts []string) ([]classification.SecurityResult, error) { if s.unifiedClassifier == nil { @@ -468,6 +536,71 @@ func (s *ClassificationService) ClassifySecurityUnified(texts []string) ([]class return results.SecurityResults, nil } +// CheckSecurityUnified performs security detection using unified classifier and returns SecurityResponse format +func (s *ClassificationService) CheckSecurityUnified(req SecurityRequest) (*SecurityResponse, error) { + start := time.Now() + + if req.Text == "" { + return nil, fmt.Errorf("text cannot be empty") + } + + // Use unified classifier for security detection + securityResults, err := s.ClassifySecurityUnified([]string{req.Text}) + if err != nil { + return nil, fmt.Errorf("security detection failed: %w", err) + } + + processingTime := time.Since(start).Milliseconds() + + // Convert SecurityResult to SecurityResponse format + if len(securityResults) == 0 { + return &SecurityResponse{ + IsJailbreak: false, + RiskScore: 0.1, + DetectionTypes: []string{}, + Confidence: 0.9, + Recommendation: "allow", + PatternsDetected: []string{}, + ProcessingTimeMs: processingTime, + }, nil + } + + securityResult := securityResults[0] + response := &SecurityResponse{ + IsJailbreak: securityResult.IsJailbreak, + RiskScore: float64(securityResult.Confidence), + Confidence: float64(securityResult.Confidence), + ProcessingTimeMs: processingTime, + } + + // Set detection types based on threat type + if securityResult.ThreatType != "" { + response.DetectionTypes = []string{securityResult.ThreatType} + response.PatternsDetected = []string{securityResult.ThreatType} + } else { + response.DetectionTypes = []string{} + response.PatternsDetected = []string{} + } + + // Set recommendation based on jailbreak detection + if response.IsJailbreak { + response.Recommendation = "block" + } else { + response.Recommendation = "allow" + } + + // Add reasoning if requested + if req.Options != nil && req.Options.IncludeReasoning { + if response.IsJailbreak { + response.Reasoning = fmt.Sprintf("Detected %s with confidence %.2f", securityResult.ThreatType, securityResult.Confidence) + } else { + response.Reasoning = "No security threats detected" + } + } + + return response, nil +} + // HasUnifiedClassifier returns true if the service has a unified classifier func (s *ClassificationService) HasUnifiedClassifier() bool { return s.unifiedClassifier != nil && s.unifiedClassifier.IsInitialized()