Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 98 additions & 6 deletions src/semantic-router/pkg/classification/classifier.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,22 @@ func (c *ModernBertJailbreakInitializer) Init(modelID string, useCPU bool, numCl
return nil
}

type Qwen3GuardJailbreakInitializer struct{}

func (c *Qwen3GuardJailbreakInitializer) Init(modelID string, useCPU bool, numClasses ...int) error {
err := candle_binding.InitQwen3Guard(modelID)
if err != nil {
return err
}
logging.Infof("Initialized Qwen3Guard jailbreak classifier")
return nil
}

// createJailbreakInitializer creates the appropriate jailbreak initializer based on configuration
func createJailbreakInitializer(useModernBERT bool) JailbreakInitializer {
func createJailbreakInitializer(useModernBERT bool, useQwen3Guard bool) JailbreakInitializer {
if useQwen3Guard {
return &Qwen3GuardJailbreakInitializer{}
}
if useModernBERT {
return &ModernBertJailbreakInitializer{}
}
Expand All @@ -130,8 +144,74 @@ func (c *ModernBertJailbreakInference) Classify(text string) (candle_binding.Cla
return candle_binding.ClassifyModernBertJailbreakText(text)
}

type Qwen3GuardJailbreakInference struct{}

func (c *Qwen3GuardJailbreakInference) Classify(text string) (candle_binding.ClassResult, error) {
// Use Qwen3Guard to classify the text
result, err := candle_binding.ClassifyPromptSafety(text)
if err != nil {
return candle_binding.ClassResult{}, fmt.Errorf("qwen3guard classification failed: %w", err)
}

// Convert SafetyClassificationResult to ClassResult
// Class 0 = safe/benign, Class 1 = jailbreak
// Check if "Jailbreak" is in categories or if SafetyLabel is "Unsafe" or "Controversial"
isJailbreak := false
confidence := float32(0.0)

// Check for jailbreak category
for _, cat := range result.Categories {
if cat == "Jailbreak" {
isJailbreak = true
confidence = 0.9 // High confidence if jailbreak category is detected
break
}
}

// If no jailbreak category but unsafe/controversial, still consider it risky
if !isJailbreak && (result.SafetyLabel == "Unsafe" || result.SafetyLabel == "Controversial") {
// Check if any unsafe categories are present
unsafeCategories := []string{"Violent", "Non-violent Illegal Acts", "Sexual Content or Sexual Acts",
"Suicide & Self-Harm", "Unethical Acts", "Politically Sensitive Topics", "Copyright Violation"}
for _, cat := range result.Categories {
for _, unsafeCat := range unsafeCategories {
if cat == unsafeCat {
isJailbreak = true
confidence = 0.7 // Medium confidence for other unsafe content
break
}
}
if isJailbreak {
break
}
}
}

// If safe, set confidence based on safety label
if !isJailbreak {
if result.SafetyLabel == "Safe" {
confidence = 0.95 // High confidence for safe content
} else {
confidence = 0.5 // Low confidence if label is unclear
}
}

class := 0 // safe/benign
if isJailbreak {
class = 1 // jailbreak
}

return candle_binding.ClassResult{
Class: class,
Confidence: confidence,
}, nil
}

// createJailbreakInference creates the appropriate jailbreak inference based on configuration
func createJailbreakInference(useModernBERT bool) JailbreakInference {
func createJailbreakInference(useModernBERT bool, useQwen3Guard bool) JailbreakInference {
if useQwen3Guard {
return &Qwen3GuardJailbreakInference{}
}
if useModernBERT {
return &ModernBertJailbreakInference{}
}
Expand Down Expand Up @@ -321,7 +401,7 @@ func newClassifierWithOptions(cfg *config.RouterConfig, options ...option) (*Cla
// allowing flexible deployment scenarios such as gradual migration.
func NewClassifier(cfg *config.RouterConfig, categoryMapping *CategoryMapping, piiMapping *PIIMapping, jailbreakMapping *JailbreakMapping) (*Classifier, error) {
options := []option{
withJailbreak(jailbreakMapping, createJailbreakInitializer(cfg.PromptGuard.UseModernBERT), createJailbreakInference(cfg.PromptGuard.UseModernBERT)),
withJailbreak(jailbreakMapping, createJailbreakInitializer(cfg.PromptGuard.UseModernBERT, cfg.PromptGuard.UseQwen3Guard), createJailbreakInference(cfg.PromptGuard.UseModernBERT, cfg.PromptGuard.UseQwen3Guard)),
withPII(piiMapping, createPIIInitializer(), createPIIInference()),
}

Expand Down Expand Up @@ -393,9 +473,21 @@ func (c *Classifier) initializeJailbreakClassifier() error {
return fmt.Errorf("jailbreak detection is not properly configured")
}

numClasses := c.JailbreakMapping.GetJailbreakTypeCount()
if numClasses < 2 {
return fmt.Errorf("not enough jailbreak types for classification, need at least 2, got %d", numClasses)
// Qwen3Guard doesn't require numClasses, but other models do
// For Qwen3Guard, we still need the mapping for type name lookup, but numClasses is optional
var numClasses int
if c.Config.PromptGuard.UseQwen3Guard {
// Qwen3Guard doesn't use numClasses, but we still need at least 2 for the mapping
numClasses = c.JailbreakMapping.GetJailbreakTypeCount()
if numClasses < 2 {
// For Qwen3Guard, we can work with just 2 classes (benign and jailbreak)
numClasses = 2
}
} else {
numClasses = c.JailbreakMapping.GetJailbreakTypeCount()
if numClasses < 2 {
return fmt.Errorf("not enough jailbreak types for classification, need at least 2, got %d", numClasses)
}
}

return c.jailbreakInitializer.Init(c.Config.PromptGuard.ModelID, c.Config.PromptGuard.UseCPU, numClasses)
Expand Down
3 changes: 3 additions & 0 deletions src/semantic-router/pkg/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,9 @@ type PromptGuardConfig struct {
// Use ModernBERT for jailbreak detection
UseModernBERT bool `yaml:"use_modernbert"`

// Use Qwen3Guard for jailbreak detection (generative model)
UseQwen3Guard bool `yaml:"use_qwen3guard"`

// Path to the jailbreak type mapping file
JailbreakMappingPath string `yaml:"jailbreak_mapping_path"`
}
Expand Down
Loading