Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions candle-binding/semantic-router.go
Original file line number Diff line number Diff line change
Expand Up @@ -310,8 +310,9 @@ type SimResult struct {

// ClassResult represents the result of a text classification
type ClassResult struct {
Class int // Class index
Confidence float32 // Confidence score
Class int // Class index
Confidence float32 // Confidence score
Categories []string // Violation categories (e.g., "Violent", "Jailbreak") - only populated when unsafe/controversial
}

// ClassResultWithProbs represents the result of a text classification with full probability distribution
Expand Down
117 changes: 84 additions & 33 deletions src/semantic-router/pkg/classification/vllm_classifier.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,9 @@ func (v *VLLMJailbreakInference) Classify(text string) (candle_binding.ClassResu
// Parse model output - flexible to support multiple formats
output := resp.Choices[0].Message.Content
logging.Debugf("vLLM jailbreak detection response: %s", output)
isJailbreak, confidence := v.parseSafetyOutput(output)
logging.Debugf("Parsed result: isJailbreak=%v, confidence=%.3f", isJailbreak, confidence)
isJailbreak, confidence, categories := v.parseSafetyOutput(output)
logging.Debugf("Parsed result: isJailbreak=%v, confidence=%.3f, categories=%v",
isJailbreak, confidence, categories)

// Map to ClassResult format
// Class: 0 = safe, 1 = jailbreak/unsafe
Expand All @@ -86,43 +87,55 @@ func (v *VLLMJailbreakInference) Classify(text string) (candle_binding.ClassResu
class = 1
}

return candle_binding.ClassResult{
result := candle_binding.ClassResult{
Class: class,
Confidence: confidence,
}, nil
}

// Only populate categories when content is unsafe or controversial
// (empty slice for safe content or when categories not available)
if isJailbreak && len(categories) > 0 {
result.Categories = categories
}

return result, nil
}

// parseSafetyOutput parses safety model output - uses parser type or auto-detection
func (v *VLLMJailbreakInference) parseSafetyOutput(output string) (bool, float32) {
func (v *VLLMJailbreakInference) parseSafetyOutput(output string) (bool, float32, []string) {
// Determine parser type based on configuration or model name
parserType := v.determineParserType()

switch parserType {
case "qwen3guard":
return v.parseQwen3GuardFormat(output)
case "json":
return v.parseJSONFormat(output)
isJailbreak, conf := v.parseJSONFormat(output)
return isJailbreak, conf, nil // JSON parser doesn't support categories yet
case "simple":
return v.parseSimpleFormat(output)
isJailbreak, conf := v.parseSimpleFormat(output)
return isJailbreak, conf, nil // Simple parser doesn't support categories yet
case "auto":
// Try all parsers (OR logic) until one succeeds
if result, conf := v.parseQwen3GuardFormat(output); conf > 0.1 {
return result, conf
if result, conf, cats := v.parseQwen3GuardFormat(output); conf > 0.1 {
return result, conf, cats
}
if result, conf := v.parseJSONFormat(output); conf > 0.1 {
return result, conf
return result, conf, nil
}
return v.parseSimpleFormat(output)
isJailbreak, conf := v.parseSimpleFormat(output)
return isJailbreak, conf, nil
default:
logging.Warnf("Unknown parser type: %s, using auto", parserType)
// Fallback to auto mode
if result, conf := v.parseQwen3GuardFormat(output); conf > 0.1 {
return result, conf
if result, conf, cats := v.parseQwen3GuardFormat(output); conf > 0.1 {
return result, conf, cats
}
if result, conf := v.parseJSONFormat(output); conf > 0.1 {
return result, conf
return result, conf, nil
}
return v.parseSimpleFormat(output)
isJailbreak, conf := v.parseSimpleFormat(output)
return isJailbreak, conf, nil
}
}

Expand All @@ -147,7 +160,7 @@ func (v *VLLMJailbreakInference) determineParserType() string {
}

// parseQwen3GuardFormat parses Qwen3Guard structured output
func (v *VLLMJailbreakInference) parseQwen3GuardFormat(output string) (bool, float32) {
func (v *VLLMJailbreakInference) parseQwen3GuardFormat(output string) (bool, float32, []string) {
// Qwen3Guard actual format: "Safety: Unsafe\nCategories: ..."
// Try "Safety:" format first (actual Qwen3Guard output)
safetyRegex := regexp.MustCompile(`(?i)safety:\s*(safe|unsafe|controversial)`)
Expand All @@ -166,8 +179,12 @@ func (v *VLLMJailbreakInference) parseQwen3GuardFormat(output string) (bool, flo
default:
confidence = 0.5
}
logging.Debugf("Qwen3Guard parser: Safety=%s, isJailbreak=%v, confidence=%.3f", safety, isJailbreak, confidence)
return isJailbreak, confidence

// Extract categories from output
categories := v.extractCategories(output)
logging.Debugf("Qwen3Guard parser: Safety=%s, isJailbreak=%v, confidence=%.3f, categories=%v",
safety, isJailbreak, confidence, categories)
return isJailbreak, confidence, categories
}

// Try "Severity level:" format (alternative format)
Expand All @@ -187,28 +204,62 @@ func (v *VLLMJailbreakInference) parseQwen3GuardFormat(output string) (bool, flo
default:
confidence = 0.5
}
logging.Debugf("Qwen3Guard parser (severity): Severity=%s, isJailbreak=%v, confidence=%.3f", severity, isJailbreak, confidence)
return isJailbreak, confidence

// Extract categories from output
categories := v.extractCategories(output)
logging.Debugf("Qwen3Guard parser (severity): Severity=%s, isJailbreak=%v, confidence=%.3f, categories=%v",
severity, isJailbreak, confidence, categories)
return isJailbreak, confidence, categories
}

// Try category-based detection (Categories: or Category:)
categoryRegex := regexp.MustCompile(`(?i)categories?:\s*([^\n]+)`)
catMatches := categoryRegex.FindStringSubmatch(output)
if len(catMatches) >= 2 {
category := strings.ToLower(strings.TrimSpace(catMatches[1]))
// Check for jailbreak-related categories
if strings.Contains(category, "jailbreak") ||
strings.Contains(category, "illegal") ||
strings.Contains(category, "harmful") ||
strings.Contains(category, "violence") ||
strings.Contains(category, "hate") {
logging.Debugf("Qwen3Guard parser (category): Category=%s, isJailbreak=true, confidence=0.9", category)
return true, 0.9
// Extract categories even if safety level wasn't found
categories := v.extractCategories(output)
if len(categories) > 0 {
// Check if any category indicates unsafe content
categoryStr := strings.ToLower(strings.Join(categories, ", "))
if strings.Contains(categoryStr, "jailbreak") ||
strings.Contains(categoryStr, "illegal") ||
strings.Contains(categoryStr, "harmful") ||
strings.Contains(categoryStr, "violence") ||
strings.Contains(categoryStr, "hate") {
logging.Debugf("Qwen3Guard parser (category): Categories=%v, isJailbreak=true, confidence=0.9", categories)
return true, 0.9, categories
}
}

logging.Warnf("Qwen3Guard parser failed to parse output: %s", output)
return false, 0.0 // Failed to parse
return false, 0.0, nil // Failed to parse
}

// extractCategories extracts violation categories from Qwen3Guard output
// Returns empty slice if no categories found or if "None" is specified
func (v *VLLMJailbreakInference) extractCategories(output string) []string {
// Pattern matches: "Categories: Violent" or "Categories: Violent, Jailbreak"
categoryRegex := regexp.MustCompile(`(?i)categories?:\s*([^\n]+)`)
matches := categoryRegex.FindStringSubmatch(output)
if len(matches) < 2 {
return nil
}

categoryLine := strings.TrimSpace(matches[1])

// Handle "None" case
if strings.EqualFold(categoryLine, "None") {
return nil
}

// Split by comma and trim each category
parts := strings.Split(categoryLine, ",")
var categories []string
for _, part := range parts {
trimmed := strings.TrimSpace(part)
if trimmed != "" && !strings.EqualFold(trimmed, "None") {
categories = append(categories, trimmed)
}
}

return categories
}

// parseJSONFormat parses JSON output
Expand Down
Loading