diff --git a/src/semantic-router/pkg/classification/classifier.go b/src/semantic-router/pkg/classification/classifier.go
index 608132600..c9248b491 100644
--- a/src/semantic-router/pkg/classification/classifier.go
+++ b/src/semantic-router/pkg/classification/classifier.go
@@ -106,8 +106,27 @@ func (c *ModernBertJailbreakInitializer) Init(modelID string, useCPU bool, numCl
 	return nil
 }
 
+// Qwen3GuardJailbreakInitializer loads the Qwen3Guard model used for
+// jailbreak detection.
+type Qwen3GuardJailbreakInitializer struct{}
+
+// Init loads the Qwen3Guard model identified by modelID. The useCPU and
+// numClasses arguments are accepted to match the other jailbreak
+// initializers but are not forwarded to the binding.
+func (c *Qwen3GuardJailbreakInitializer) Init(modelID string, useCPU bool, numClasses ...int) error {
+	if err := candle_binding.InitQwen3Guard(modelID); err != nil {
+		return fmt.Errorf("initializing qwen3guard model %q: %w", modelID, err)
+	}
+	logging.Infof("Initialized Qwen3Guard jailbreak classifier")
+	return nil
+}
+
 // createJailbreakInitializer creates the appropriate jailbreak initializer based on configuration
-func createJailbreakInitializer(useModernBERT bool) JailbreakInitializer {
+func createJailbreakInitializer(useModernBERT bool, useQwen3Guard bool) JailbreakInitializer {
+	// Qwen3Guard takes precedence when both flags are enabled.
+	if useQwen3Guard {
+		return &Qwen3GuardJailbreakInitializer{}
+	}
 	if useModernBERT {
 		return &ModernBertJailbreakInitializer{}
 	}
@@ -130,8 +149,81 @@ func (c *ModernBertJailbreakInference) Classify(text string) (candle_binding.Cla
 	return candle_binding.ClassifyModernBertJailbreakText(text)
 }
 
+// Class indices and heuristic confidence scores reported by
+// Qwen3GuardJailbreakInference. The confidences are fixed per outcome because
+// Classify derives them from the safety label and categories alone.
+const (
+	qwen3GuardClassBenign    = 0
+	qwen3GuardClassJailbreak = 1
+
+	qwen3GuardConfJailbreak = 0.9  // "Jailbreak" category present
+	qwen3GuardConfUnsafe    = 0.7  // unsafe/controversial label with a known harm category
+	qwen3GuardConfSafe      = 0.95 // explicit "Safe" label
+	qwen3GuardConfUnclear   = 0.5  // anything else
+)
+
+// Qwen3GuardJailbreakInference adapts Qwen3Guard prompt-safety output to the
+// two-class (benign/jailbreak) ClassResult returned by jailbreak inference.
+type Qwen3GuardJailbreakInference struct{}
+
+// Classify runs Qwen3Guard on text and maps its verdict to a ClassResult:
+//   - a "Jailbreak" category yields class 1;
+//   - an "Unsafe" or "Controversial" label with a known harm category also
+//     yields class 1, at lower confidence;
+//   - everything else yields class 0, with high confidence only when the
+//     label is "Safe".
+func (c *Qwen3GuardJailbreakInference) Classify(text string) (candle_binding.ClassResult, error) {
+	result, err := candle_binding.ClassifyPromptSafety(text)
+	if err != nil {
+		return candle_binding.ClassResult{}, fmt.Errorf("qwen3guard classification failed: %w", err)
+	}
+
+	// An explicit jailbreak category is decisive regardless of the safety label.
+	for _, cat := range result.Categories {
+		if cat == "Jailbreak" {
+			return candle_binding.ClassResult{Class: qwen3GuardClassJailbreak, Confidence: qwen3GuardConfJailbreak}, nil
+		}
+	}
+
+	// Other harm categories are flagged only when the safety label is not "Safe".
+	if result.SafetyLabel == "Unsafe" || result.SafetyLabel == "Controversial" {
+		for _, cat := range result.Categories {
+			if isQwen3GuardUnsafeCategory(cat) {
+				return candle_binding.ClassResult{Class: qwen3GuardClassJailbreak, Confidence: qwen3GuardConfUnsafe}, nil
+			}
+		}
+	}
+
+	if result.SafetyLabel == "Safe" {
+		return candle_binding.ClassResult{Class: qwen3GuardClassBenign, Confidence: qwen3GuardConfSafe}, nil
+	}
+	// Unrecognized or uncategorized verdicts are reported as benign with low confidence.
+	return candle_binding.ClassResult{Class: qwen3GuardClassBenign, Confidence: qwen3GuardConfUnclear}, nil
+}
+
+// isQwen3GuardUnsafeCategory reports whether cat is one of the non-jailbreak
+// harm categories that Classify treats as risky.
+func isQwen3GuardUnsafeCategory(cat string) bool {
+	switch cat {
+	case "Violent",
+		"Non-violent Illegal Acts",
+		"Sexual Content or Sexual Acts",
+		"Suicide & Self-Harm",
+		"Unethical Acts",
+		"Politically Sensitive Topics",
+		"Copyright Violation":
+		return true
+	default:
+		return false
+	}
+}
+
 // createJailbreakInference creates the appropriate jailbreak inference based on configuration
-func createJailbreakInference(useModernBERT bool) JailbreakInference {
+func createJailbreakInference(useModernBERT bool, useQwen3Guard bool) JailbreakInference {
+	// Qwen3Guard takes precedence when both flags are enabled.
+	if useQwen3Guard {
+		return &Qwen3GuardJailbreakInference{}
+	}
 	if useModernBERT {
 		return &ModernBertJailbreakInference{}
 	}
@@ -321,7 +413,7 @@ func newClassifierWithOptions(cfg *config.RouterConfig, options ...option) (*Cla
 // allowing flexible deployment scenarios such as gradual migration.
 func NewClassifier(cfg *config.RouterConfig, categoryMapping *CategoryMapping, piiMapping *PIIMapping, jailbreakMapping *JailbreakMapping) (*Classifier, error) {
 	options := []option{
-		withJailbreak(jailbreakMapping, createJailbreakInitializer(cfg.PromptGuard.UseModernBERT), createJailbreakInference(cfg.PromptGuard.UseModernBERT)),
+		withJailbreak(jailbreakMapping, createJailbreakInitializer(cfg.PromptGuard.UseModernBERT, cfg.PromptGuard.UseQwen3Guard), createJailbreakInference(cfg.PromptGuard.UseModernBERT, cfg.PromptGuard.UseQwen3Guard)),
 		withPII(piiMapping, createPIIInitializer(), createPIIInference()),
 	}
 
@@ -393,9 +485,15 @@ func (c *Classifier) initializeJailbreakClassifier() error {
 		return fmt.Errorf("jailbreak detection is not properly configured")
 	}
 
 	numClasses := c.JailbreakMapping.GetJailbreakTypeCount()
 	if numClasses < 2 {
-		return fmt.Errorf("not enough jailbreak types for classification, need at least 2, got %d", numClasses)
+		// Other models require numClasses, so a mapping with fewer than two
+		// types is a configuration error for them.
+		if !c.Config.PromptGuard.UseQwen3Guard {
+			return fmt.Errorf("not enough jailbreak types for classification, need at least 2, got %d", numClasses)
+		}
+		// Qwen3Guard's initializer ignores numClasses; fall back to the minimal
+		// benign/jailbreak pair so the Init call stays uniform.
+		numClasses = 2
 	}
 
 	return c.jailbreakInitializer.Init(c.Config.PromptGuard.ModelID, c.Config.PromptGuard.UseCPU, numClasses)
diff --git a/src/semantic-router/pkg/config/config.go b/src/semantic-router/pkg/config/config.go
index cb1546edd..b3277a659 100644
--- a/src/semantic-router/pkg/config/config.go
+++ b/src/semantic-router/pkg/config/config.go
@@ -336,6 +336,10 @@ type PromptGuardConfig struct {
 	// Use ModernBERT for jailbreak detection
 	UseModernBERT bool `yaml:"use_modernbert"`
 
+	// Use Qwen3Guard for jailbreak detection (generative model).
+	// Takes precedence over UseModernBERT when both are enabled.
+	UseQwen3Guard bool `yaml:"use_qwen3guard"`
+
 	// Path to the jailbreak type mapping file
 	JailbreakMappingPath string `yaml:"jailbreak_mapping_path"`
 }