Skip to content

Commit 75defeb

Browse files
committed
Add Qwen3Guard category extraction support
Signed-off-by: Yue Zhu <[email protected]>
1 parent 30801fa commit 75defeb

File tree

2 files changed

+87
-35
lines changed

2 files changed

+87
-35
lines changed

candle-binding/semantic-router.go

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -310,8 +310,9 @@ type SimResult struct {
310310

311311
// ClassResult represents the result of a text classification
312312
type ClassResult struct {
313-
Class int // Class index
314-
Confidence float32 // Confidence score
313+
Class int // Class index
314+
Confidence float32 // Confidence score
315+
Categories []string // Violation categories (e.g., "Violent", "Jailbreak") - only populated when unsafe/controversial
315316
}
316317

317318
// ClassResultWithProbs represents the result of a text classification with full probability distribution

src/semantic-router/pkg/classification/vllm_classifier.go

Lines changed: 84 additions & 33 deletions
Original file line number | Diff line number | Diff line change
@@ -76,8 +76,9 @@ func (v *VLLMJailbreakInference) Classify(text string) (candle_binding.ClassResu
7676
// Parse model output - flexible to support multiple formats
7777
output := resp.Choices[0].Message.Content
7878
logging.Debugf("vLLM jailbreak detection response: %s", output)
79-
isJailbreak, confidence := v.parseSafetyOutput(output)
80-
logging.Debugf("Parsed result: isJailbreak=%v, confidence=%.3f", isJailbreak, confidence)
79+
isJailbreak, confidence, categories := v.parseSafetyOutput(output)
80+
logging.Debugf("Parsed result: isJailbreak=%v, confidence=%.3f, categories=%v",
81+
isJailbreak, confidence, categories)
8182

8283
// Map to ClassResult format
8384
// Class: 0 = safe, 1 = jailbreak/unsafe
@@ -86,43 +87,55 @@ func (v *VLLMJailbreakInference) Classify(text string) (candle_binding.ClassResu
8687
class = 1
8788
}
8889

89-
return candle_binding.ClassResult{
90+
result := candle_binding.ClassResult{
9091
Class: class,
9192
Confidence: confidence,
92-
}, nil
93+
}
94+
95+
// Only populate categories when content is unsafe or controversial
96+
// (empty slice for safe content or when categories not available)
97+
if isJailbreak && len(categories) > 0 {
98+
result.Categories = categories
99+
}
100+
101+
return result, nil
93102
}
94103

95104
// parseSafetyOutput parses safety model output - uses parser type or auto-detection
96-
func (v *VLLMJailbreakInference) parseSafetyOutput(output string) (bool, float32) {
105+
func (v *VLLMJailbreakInference) parseSafetyOutput(output string) (bool, float32, []string) {
97106
// Determine parser type based on configuration or model name
98107
parserType := v.determineParserType()
99108

100109
switch parserType {
101110
case "qwen3guard":
102111
return v.parseQwen3GuardFormat(output)
103112
case "json":
104-
return v.parseJSONFormat(output)
113+
isJailbreak, conf := v.parseJSONFormat(output)
114+
return isJailbreak, conf, nil // JSON parser doesn't support categories yet
105115
case "simple":
106-
return v.parseSimpleFormat(output)
116+
isJailbreak, conf := v.parseSimpleFormat(output)
117+
return isJailbreak, conf, nil // Simple parser doesn't support categories yet
107118
case "auto":
108119
// Try all parsers (OR logic) until one succeeds
109-
if result, conf := v.parseQwen3GuardFormat(output); conf > 0.1 {
110-
return result, conf
120+
if result, conf, cats := v.parseQwen3GuardFormat(output); conf > 0.1 {
121+
return result, conf, cats
111122
}
112123
if result, conf := v.parseJSONFormat(output); conf > 0.1 {
113-
return result, conf
124+
return result, conf, nil
114125
}
115-
return v.parseSimpleFormat(output)
126+
isJailbreak, conf := v.parseSimpleFormat(output)
127+
return isJailbreak, conf, nil
116128
default:
117129
logging.Warnf("Unknown parser type: %s, using auto", parserType)
118130
// Fallback to auto mode
119-
if result, conf := v.parseQwen3GuardFormat(output); conf > 0.1 {
120-
return result, conf
131+
if result, conf, cats := v.parseQwen3GuardFormat(output); conf > 0.1 {
132+
return result, conf, cats
121133
}
122134
if result, conf := v.parseJSONFormat(output); conf > 0.1 {
123-
return result, conf
135+
return result, conf, nil
124136
}
125-
return v.parseSimpleFormat(output)
137+
isJailbreak, conf := v.parseSimpleFormat(output)
138+
return isJailbreak, conf, nil
126139
}
127140
}
128141

@@ -147,7 +160,7 @@ func (v *VLLMJailbreakInference) determineParserType() string {
147160
}
148161

149162
// parseQwen3GuardFormat parses Qwen3Guard structured output
150-
func (v *VLLMJailbreakInference) parseQwen3GuardFormat(output string) (bool, float32) {
163+
func (v *VLLMJailbreakInference) parseQwen3GuardFormat(output string) (bool, float32, []string) {
151164
// Qwen3Guard actual format: "Safety: Unsafe\nCategories: ..."
152165
// Try "Safety:" format first (actual Qwen3Guard output)
153166
safetyRegex := regexp.MustCompile(`(?i)safety:\s*(safe|unsafe|controversial)`)
@@ -166,8 +179,12 @@ func (v *VLLMJailbreakInference) parseQwen3GuardFormat(output string) (bool, flo
166179
default:
167180
confidence = 0.5
168181
}
169-
logging.Debugf("Qwen3Guard parser: Safety=%s, isJailbreak=%v, confidence=%.3f", safety, isJailbreak, confidence)
170-
return isJailbreak, confidence
182+
183+
// Extract categories from output
184+
categories := v.extractCategories(output)
185+
logging.Debugf("Qwen3Guard parser: Safety=%s, isJailbreak=%v, confidence=%.3f, categories=%v",
186+
safety, isJailbreak, confidence, categories)
187+
return isJailbreak, confidence, categories
171188
}
172189

173190
// Try "Severity level:" format (alternative format)
@@ -187,28 +204,62 @@ func (v *VLLMJailbreakInference) parseQwen3GuardFormat(output string) (bool, flo
187204
default:
188205
confidence = 0.5
189206
}
190-
logging.Debugf("Qwen3Guard parser (severity): Severity=%s, isJailbreak=%v, confidence=%.3f", severity, isJailbreak, confidence)
191-
return isJailbreak, confidence
207+
208+
// Extract categories from output
209+
categories := v.extractCategories(output)
210+
logging.Debugf("Qwen3Guard parser (severity): Severity=%s, isJailbreak=%v, confidence=%.3f, categories=%v",
211+
severity, isJailbreak, confidence, categories)
212+
return isJailbreak, confidence, categories
192213
}
193214

194215
// Try category-based detection (Categories: or Category:)
195-
categoryRegex := regexp.MustCompile(`(?i)categories?:\s*([^\n]+)`)
196-
catMatches := categoryRegex.FindStringSubmatch(output)
197-
if len(catMatches) >= 2 {
198-
category := strings.ToLower(strings.TrimSpace(catMatches[1]))
199-
// Check for jailbreak-related categories
200-
if strings.Contains(category, "jailbreak") ||
201-
strings.Contains(category, "illegal") ||
202-
strings.Contains(category, "harmful") ||
203-
strings.Contains(category, "violence") ||
204-
strings.Contains(category, "hate") {
205-
logging.Debugf("Qwen3Guard parser (category): Category=%s, isJailbreak=true, confidence=0.9", category)
206-
return true, 0.9
216+
// Extract categories even if safety level wasn't found
217+
categories := v.extractCategories(output)
218+
if len(categories) > 0 {
219+
// Check if any category indicates unsafe content
220+
categoryStr := strings.ToLower(strings.Join(categories, ", "))
221+
if strings.Contains(categoryStr, "jailbreak") ||
222+
strings.Contains(categoryStr, "illegal") ||
223+
strings.Contains(categoryStr, "harmful") ||
224+
strings.Contains(categoryStr, "violence") ||
225+
strings.Contains(categoryStr, "hate") {
226+
logging.Debugf("Qwen3Guard parser (category): Categories=%v, isJailbreak=true, confidence=0.9", categories)
227+
return true, 0.9, categories
207228
}
208229
}
209230

210231
logging.Warnf("Qwen3Guard parser failed to parse output: %s", output)
211-
return false, 0.0 // Failed to parse
232+
return false, 0.0, nil // Failed to parse
233+
}
234+
235+
// extractCategories extracts violation categories from Qwen3Guard output
236+
// Returns empty slice if no categories found or if "None" is specified
237+
func (v *VLLMJailbreakInference) extractCategories(output string) []string {
238+
// Pattern matches: "Categories: Violent" or "Categories: Violent, Jailbreak"
239+
categoryRegex := regexp.MustCompile(`(?i)categories?:\s*([^\n]+)`)
240+
matches := categoryRegex.FindStringSubmatch(output)
241+
if len(matches) < 2 {
242+
return nil
243+
}
244+
245+
categoryLine := strings.TrimSpace(matches[1])
246+
247+
// Handle "None" case
248+
if strings.EqualFold(categoryLine, "None") {
249+
return nil
250+
}
251+
252+
// Split by comma and trim each category
253+
parts := strings.Split(categoryLine, ",")
254+
var categories []string
255+
for _, part := range parts {
256+
trimmed := strings.TrimSpace(part)
257+
if trimmed != "" && !strings.EqualFold(trimmed, "None") {
258+
categories = append(categories, trimmed)
259+
}
260+
}
261+
262+
return categories
212263
}
213264

214265
// parseJSONFormat parses JSON output

0 commit comments

Comments
 (0)