Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
141 changes: 137 additions & 4 deletions src/semantic-router/pkg/services/classification.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ func GetGlobalClassificationService() *ClassificationService {

// HasClassifier returns true if the service has a real classifier (not placeholder)
func (s *ClassificationService) HasClassifier() bool {
return s.classifier != nil
return s.unifiedClassifier != nil || s.classifier != nil
}

// NewPlaceholderClassificationService creates a placeholder service for API-only mode
Expand Down Expand Up @@ -118,7 +118,12 @@ func (s *ClassificationService) ClassifyIntent(req IntentRequest) (*IntentRespon
return nil, fmt.Errorf("text cannot be empty")
}

// Check if classifier is available
// Prioritize unified classifier if available
if s.unifiedClassifier != nil {
return s.ClassifyIntentUnified(req)
}

// Check if legacy classifier is available
if s.classifier == nil {
// Return placeholder response
processingTime := time.Since(start).Milliseconds()
Expand Down Expand Up @@ -210,7 +215,12 @@ func (s *ClassificationService) DetectPII(req PIIRequest) (*PIIResponse, error)
return nil, fmt.Errorf("text cannot be empty")
}

// Check if classifier is available
// Prioritize unified classifier if available
if s.unifiedClassifier != nil {
return s.DetectPIIUnified(req)
}

// Check if legacy classifier is available
if s.classifier == nil {
// Return placeholder response
processingTime := time.Since(start).Milliseconds()
Expand Down Expand Up @@ -290,7 +300,12 @@ func (s *ClassificationService) CheckSecurity(req SecurityRequest) (*SecurityRes
return nil, fmt.Errorf("text cannot be empty")
}

// Check if classifier is available
// Prioritize unified classifier if available
if s.unifiedClassifier != nil {
return s.CheckSecurityUnified(req)
}

// Check if legacy classifier is available
if s.classifier == nil {
// Return placeholder response
processingTime := time.Since(start).Milliseconds()
Expand Down Expand Up @@ -454,6 +469,59 @@ func (s *ClassificationService) ClassifyPIIUnified(texts []string) ([]classifica
return results.PIIResults, nil
}

// DetectPIIUnified performs PII detection using unified classifier and returns PIIResponse format
func (s *ClassificationService) DetectPIIUnified(req PIIRequest) (*PIIResponse, error) {
start := time.Now()

if req.Text == "" {
return nil, fmt.Errorf("text cannot be empty")
}

// Use unified classifier for PII detection
piiResults, err := s.ClassifyPIIUnified([]string{req.Text})
if err != nil {
return nil, fmt.Errorf("PII detection failed: %w", err)
}

processingTime := time.Since(start).Milliseconds()

// Convert PIIResult to PIIResponse format
if len(piiResults) == 0 {
return &PIIResponse{
HasPII: false,
Entities: []PIIEntity{},
SecurityRecommendation: "allow",
ProcessingTimeMs: processingTime,
}, nil
}

piiResult := piiResults[0]
response := &PIIResponse{
HasPII: piiResult.HasPII,
Entities: []PIIEntity{},
ProcessingTimeMs: processingTime,
}

// Convert PII types to entities
for _, piiType := range piiResult.PIITypes {
entity := PIIEntity{
Type: piiType,
Value: "[DETECTED]", // Placeholder - unified classifier doesn't provide exact positions yet
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The token sequence classifier should have the location index. Can you follow up in another PR? Thanks

Confidence: float64(piiResult.Confidence),
}
response.Entities = append(response.Entities, entity)
}

// Set security recommendation
if response.HasPII {
response.SecurityRecommendation = "block"
} else {
response.SecurityRecommendation = "allow"
}

return response, nil
}

// ClassifySecurityUnified performs security detection using unified classifier
func (s *ClassificationService) ClassifySecurityUnified(texts []string) ([]classification.SecurityResult, error) {
if s.unifiedClassifier == nil {
Expand All @@ -468,6 +536,71 @@ func (s *ClassificationService) ClassifySecurityUnified(texts []string) ([]class
return results.SecurityResults, nil
}

// CheckSecurityUnified performs security detection using unified classifier and returns SecurityResponse format
func (s *ClassificationService) CheckSecurityUnified(req SecurityRequest) (*SecurityResponse, error) {
start := time.Now()

if req.Text == "" {
return nil, fmt.Errorf("text cannot be empty")
}

// Use unified classifier for security detection
securityResults, err := s.ClassifySecurityUnified([]string{req.Text})
if err != nil {
return nil, fmt.Errorf("security detection failed: %w", err)
}

processingTime := time.Since(start).Milliseconds()

// Convert SecurityResult to SecurityResponse format
if len(securityResults) == 0 {
return &SecurityResponse{
IsJailbreak: false,
Copy link
Collaborator

@rootfs rootfs Sep 24, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why using hardcoded values? no security result can have a special output but a hardcoded response like this is a bit confusing.

RiskScore: 0.1,
DetectionTypes: []string{},
Confidence: 0.9,
Recommendation: "allow",
PatternsDetected: []string{},
ProcessingTimeMs: processingTime,
}, nil
}

securityResult := securityResults[0]
response := &SecurityResponse{
IsJailbreak: securityResult.IsJailbreak,
RiskScore: float64(securityResult.Confidence),
Confidence: float64(securityResult.Confidence),
ProcessingTimeMs: processingTime,
}

// Set detection types based on threat type
if securityResult.ThreatType != "" {
response.DetectionTypes = []string{securityResult.ThreatType}
response.PatternsDetected = []string{securityResult.ThreatType}
} else {
response.DetectionTypes = []string{}
response.PatternsDetected = []string{}
}

// Set recommendation based on jailbreak detection
if response.IsJailbreak {
response.Recommendation = "block"
} else {
response.Recommendation = "allow"
}

// Add reasoning if requested
if req.Options != nil && req.Options.IncludeReasoning {
if response.IsJailbreak {
response.Reasoning = fmt.Sprintf("Detected %s with confidence %.2f", securityResult.ThreatType, securityResult.Confidence)
} else {
response.Reasoning = "No security threats detected"
}
}

return response, nil
}

// HasUnifiedClassifier returns true if the service has a unified classifier
func (s *ClassificationService) HasUnifiedClassifier() bool {
return s.unifiedClassifier != nil && s.unifiedClassifier.IsInitialized()
Expand Down
Loading