Skip to content

Commit 69e8e92

Browse files
committed
finish pii tests
Signed-off-by: Alex Wang <[email protected]>
1 parent db44594 commit 69e8e92

File tree

2 files changed

+211
-22
lines changed

2 files changed

+211
-22
lines changed

src/semantic-router/pkg/utils/classification/classifier.go

Lines changed: 25 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -22,9 +22,9 @@ func (c *LinearCategoryInference) Classify(text string) (candle_binding.ClassRes
2222
return candle_binding.ClassifyText(text)
2323
}
2424

25-
type ModernBERTCategoryInference struct{}
25+
type ModernBertCategoryInference struct{}
2626

27-
func (c *ModernBERTCategoryInference) Classify(text string) (candle_binding.ClassResult, error) {
27+
func (c *ModernBertCategoryInference) Classify(text string) (candle_binding.ClassResult, error) {
2828
return candle_binding.ClassifyModernBertText(text)
2929
}
3030

@@ -38,12 +38,22 @@ func (c *LinearJailbreakInference) Classify(text string) (candle_binding.ClassRe
3838
return candle_binding.ClassifyJailbreakText(text)
3939
}
4040

41-
type ModernBERTJailbreakInference struct{}
41+
type ModernBertJailbreakInference struct{}
4242

43-
func (c *ModernBERTJailbreakInference) Classify(text string) (candle_binding.ClassResult, error) {
43+
func (c *ModernBertJailbreakInference) Classify(text string) (candle_binding.ClassResult, error) {
4444
return candle_binding.ClassifyModernBertJailbreakText(text)
4545
}
4646

47+
type PIIInference interface {
48+
ClassifyTokens(text string, configPath string) (candle_binding.TokenClassificationResult, error)
49+
}
50+
51+
type ModernBertPIIInference struct{}
52+
53+
func (c *ModernBertPIIInference) ClassifyTokens(text string, configPath string) (candle_binding.TokenClassificationResult, error) {
54+
return candle_binding.ClassifyModernBertPIITokens(text, configPath)
55+
}
56+
4757
// JailbreakDetection represents the result of jailbreak analysis for a piece of content
4858
type JailbreakDetection struct {
4959
Content string `json:"content"`
@@ -75,6 +85,7 @@ type Classifier struct {
7585
// Dependencies
7686
categoryInference CategoryInference
7787
jailbreakInference JailbreakInference
88+
piiInference PIIInference
7889

7990
Config *config.RouterConfig
8091
CategoryMapping *CategoryMapping
@@ -92,21 +103,25 @@ type Classifier struct {
92103
func NewClassifier(cfg *config.RouterConfig, categoryMapping *CategoryMapping, piiMapping *PIIMapping, jailbreakMapping *JailbreakMapping, modelTTFT map[string]float64) *Classifier {
93104
var categoryInference CategoryInference
94105
if cfg.Classifier.CategoryModel.UseModernBERT {
95-
categoryInference = &ModernBERTCategoryInference{}
106+
categoryInference = &ModernBertCategoryInference{}
96107
} else {
97108
categoryInference = &LinearCategoryInference{}
98109
}
99110

100111
var jailbreakInference JailbreakInference
101112
if cfg.PromptGuard.UseModernBERT {
102-
jailbreakInference = &ModernBERTJailbreakInference{}
113+
jailbreakInference = &ModernBertJailbreakInference{}
103114
} else {
104115
jailbreakInference = &LinearJailbreakInference{}
105116
}
106117

118+
piiInference := &ModernBertPIIInference{}
119+
107120
return &Classifier{
108-
categoryInference: categoryInference,
109-
jailbreakInference: jailbreakInference,
121+
categoryInference: categoryInference,
122+
jailbreakInference: jailbreakInference,
123+
piiInference: piiInference,
124+
110125
Config: cfg,
111126
CategoryMapping: categoryMapping,
112127
PIIMapping: piiMapping,
@@ -289,7 +304,7 @@ func (c *Classifier) ClassifyPII(text string) ([]string, error) {
289304
// Use ModernBERT PII token classifier for entity detection
290305
configPath := fmt.Sprintf("%s/config.json", c.Config.Classifier.PIIModel.ModelID)
291306
start := time.Now()
292-
tokenResult, err := candle_binding.ClassifyModernBertPIITokens(text, configPath)
307+
tokenResult, err := c.piiInference.ClassifyTokens(text, configPath)
293308
metrics.RecordClassifierLatency("pii", time.Since(start).Seconds())
294309
if err != nil {
295310
return nil, fmt.Errorf("PII token classification error: %w", err)
@@ -371,7 +386,7 @@ func (c *Classifier) AnalyzeContentForPII(contentList []string) (bool, []PIIAnal
371386
// Use ModernBERT PII token classifier for detailed analysis
372387
configPath := fmt.Sprintf("%s/config.json", c.Config.Classifier.PIIModel.ModelID)
373388
start := time.Now()
374-
tokenResult, err := candle_binding.ClassifyModernBertPIITokens(content, configPath)
389+
tokenResult, err := c.piiInference.ClassifyTokens(content, configPath)
375390
metrics.RecordClassifierLatency("pii", time.Since(start).Seconds())
376391
if err != nil {
377392
log.Printf("Error analyzing content %d: %v", i, err)

src/semantic-router/pkg/utils/classification/classifier_test.go

Lines changed: 186 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -34,7 +34,7 @@ var _ = Describe("ClassifyCategory", func() {
3434
BeforeEach(func() {
3535
mockCategoryModel = &MockCategoryInference{}
3636
cfg := &config.RouterConfig{}
37-
cfg.Classifier.CategoryModel.Threshold = 0.5 // Set threshold for testing
37+
cfg.Classifier.CategoryModel.Threshold = 0.5
3838

3939
classifier = &Classifier{
4040
categoryInference: mockCategoryModel,
@@ -55,7 +55,7 @@ var _ = Describe("ClassifyCategory", func() {
5555

5656
category, score, err := classifier.ClassifyCategory("This is about politics")
5757

58-
Expect(err).To(BeNil())
58+
Expect(err).ToNot(HaveOccurred())
5959
Expect(category).To(Equal("politics"))
6060
Expect(score).To(BeNumerically("~", 0.95, 0.001))
6161
})
@@ -70,7 +70,7 @@ var _ = Describe("ClassifyCategory", func() {
7070

7171
category, score, err := classifier.ClassifyCategory("Ambiguous text")
7272

73-
Expect(err).To(BeNil())
73+
Expect(err).ToNot(HaveOccurred())
7474
Expect(category).To(Equal(""))
7575
Expect(score).To(BeNumerically("~", 0.3, 0.001))
7676
})
@@ -82,7 +82,8 @@ var _ = Describe("ClassifyCategory", func() {
8282

8383
category, score, err := classifier.ClassifyCategory("Some text")
8484

85-
Expect(err).ToNot(BeNil())
85+
Expect(err).To(HaveOccurred())
86+
Expect(err.Error()).To(ContainSubstring("classification error"))
8687
Expect(category).To(Equal(""))
8788
Expect(score).To(BeNumerically("~", 0.0, 0.001))
8889
})
@@ -97,8 +98,7 @@ var _ = Describe("ClassifyCategory", func() {
9798

9899
category, score, err := classifier.ClassifyCategory("")
99100

100-
// Should still attempt classification
101-
Expect(err).To(BeNil())
101+
Expect(err).ToNot(HaveOccurred())
102102
Expect(category).To(Equal("technology"))
103103
Expect(score).To(BeNumerically("~", 0.8, 0.001))
104104
})
@@ -113,7 +113,7 @@ var _ = Describe("ClassifyCategory", func() {
113113

114114
category, score, err := classifier.ClassifyCategory("Some text")
115115

116-
Expect(err).To(BeNil())
116+
Expect(err).ToNot(HaveOccurred())
117117
Expect(category).To(Equal(""))
118118
Expect(score).To(BeNumerically("~", 0.8, 0.001))
119119
})
@@ -163,7 +163,7 @@ var _ = Describe("CheckForJailbreak", func() {
163163

164164
isJailbreak, jailbreakType, confidence, err := classifier.CheckForJailbreak("This is a jailbreak attempt")
165165

166-
Expect(err).To(BeNil())
166+
Expect(err).ToNot(HaveOccurred())
167167
Expect(isJailbreak).To(BeTrue())
168168
Expect(jailbreakType).To(Equal("jailbreak"))
169169
Expect(confidence).To(BeNumerically("~", 0.9, 0.001))
@@ -179,7 +179,7 @@ var _ = Describe("CheckForJailbreak", func() {
179179

180180
isJailbreak, jailbreakType, confidence, err := classifier.CheckForJailbreak("This is a normal question")
181181

182-
Expect(err).To(BeNil())
182+
Expect(err).ToNot(HaveOccurred())
183183
Expect(isJailbreak).To(BeFalse())
184184
Expect(jailbreakType).To(Equal("benign"))
185185
Expect(confidence).To(BeNumerically("~", 0.9, 0.001))
@@ -195,7 +195,7 @@ var _ = Describe("CheckForJailbreak", func() {
195195

196196
isJailbreak, jailbreakType, confidence, err := classifier.CheckForJailbreak("Ambiguous text")
197197

198-
Expect(err).To(BeNil())
198+
Expect(err).ToNot(HaveOccurred())
199199
Expect(isJailbreak).To(BeFalse())
200200
Expect(jailbreakType).To(Equal("jailbreak"))
201201
Expect(confidence).To(BeNumerically("~", 0.5, 0.001))
@@ -208,7 +208,7 @@ var _ = Describe("CheckForJailbreak", func() {
208208

209209
isJailbreak, jailbreakType, confidence, err := classifier.CheckForJailbreak("Some text")
210210

211-
Expect(err).ToNot(BeNil())
211+
Expect(err).To(HaveOccurred())
212212
Expect(err.Error()).To(ContainSubstring("jailbreak classification failed"))
213213
Expect(isJailbreak).To(BeFalse())
214214
Expect(jailbreakType).To(Equal(""))
@@ -225,10 +225,184 @@ var _ = Describe("CheckForJailbreak", func() {
225225

226226
isJailbreak, jailbreakType, confidence, err := classifier.CheckForJailbreak("Some text")
227227

228-
Expect(err).ToNot(BeNil())
228+
Expect(err).To(HaveOccurred())
229+
Expect(err.Error()).To(ContainSubstring("unknown jailbreak class index"))
229230
Expect(isJailbreak).To(BeFalse())
230231
Expect(jailbreakType).To(Equal(""))
231232
Expect(confidence).To(BeNumerically("~", 0.0, 0.001))
232233
})
233234
})
234235
})
236+
237+
type PIIInferenceResponse struct {
238+
classifyTokensResult candle_binding.TokenClassificationResult
239+
classifyTokensError error
240+
}
241+
242+
type MockPIIInference struct {
243+
PIIInferenceResponse
244+
responseMap map[string]PIIInferenceResponse
245+
}
246+
247+
func (m *MockPIIInference) ClassifyTokens(text string, configPath string) (candle_binding.TokenClassificationResult, error) {
248+
if response, exists := m.responseMap[text]; exists {
249+
return response.classifyTokensResult, response.classifyTokensError
250+
}
251+
return m.classifyTokensResult, m.classifyTokensError
252+
}
253+
254+
var _ = Describe("PIIClassification", func() {
255+
var (
256+
classifier *Classifier
257+
mockPIIModel *MockPIIInference
258+
)
259+
260+
BeforeEach(func() {
261+
mockPIIModel = &MockPIIInference{}
262+
cfg := &config.RouterConfig{}
263+
cfg.Classifier.PIIModel.ModelID = "test-pii-model"
264+
cfg.Classifier.PIIModel.Threshold = 0.7
265+
266+
classifier = &Classifier{
267+
piiInference: mockPIIModel,
268+
Config: cfg,
269+
PIIMapping: &PIIMapping{
270+
LabelToIdx: map[string]int{"PERSON": 0, "EMAIL": 1},
271+
IdxToLabel: map[string]string{"0": "PERSON", "1": "EMAIL"},
272+
},
273+
}
274+
})
275+
276+
Describe("ClassifyPII", func() {
277+
Context("when PII entities are detected above threshold", func() {
278+
It("should return detected PII types", func() {
279+
mockPIIModel.classifyTokensResult = candle_binding.TokenClassificationResult{
280+
Entities: []candle_binding.TokenEntity{
281+
{
282+
EntityType: "PERSON",
283+
Text: "John Doe",
284+
Start: 0,
285+
End: 8,
286+
Confidence: 0.9,
287+
},
288+
{
289+
EntityType: "EMAIL",
290+
Text: "john@example.com",
Start: 9,
292+
End: 25,
293+
Confidence: 0.8,
294+
},
295+
},
296+
}
297+
298+
piiTypes, err := classifier.ClassifyPII("John Doe john@example.com")
299+
300+
Expect(err).ToNot(HaveOccurred())
301+
Expect(piiTypes).To(ConsistOf("PERSON", "EMAIL"))
302+
})
303+
})
304+
305+
Context("when PII entities are detected below threshold", func() {
306+
It("should filter out low confidence entities", func() {
307+
mockPIIModel.classifyTokensResult = candle_binding.TokenClassificationResult{
308+
Entities: []candle_binding.TokenEntity{
309+
{
310+
EntityType: "PERSON",
311+
Text: "John Doe",
312+
Start: 0,
313+
End: 8,
314+
Confidence: 0.9,
315+
},
316+
{
317+
EntityType: "EMAIL",
318+
Text: "john@example.com",
Start: 9,
320+
End: 25,
321+
Confidence: 0.5,
322+
},
323+
},
324+
}
325+
326+
piiTypes, err := classifier.ClassifyPII("John Doe john@example.com")
327+
328+
Expect(err).ToNot(HaveOccurred())
329+
Expect(piiTypes).To(ConsistOf("PERSON"))
330+
})
331+
})
332+
333+
Context("when no PII is detected", func() {
334+
It("should return empty list", func() {
335+
mockPIIModel.classifyTokensResult = candle_binding.TokenClassificationResult{
336+
Entities: []candle_binding.TokenEntity{},
337+
}
338+
339+
piiTypes, err := classifier.ClassifyPII("Some text")
340+
341+
Expect(err).ToNot(HaveOccurred())
342+
Expect(piiTypes).To(BeEmpty())
343+
})
344+
})
345+
346+
Context("when model inference fails", func() {
347+
It("should return error", func() {
348+
mockPIIModel.classifyTokensError = errors.New("PII model inference failed")
349+
350+
piiTypes, err := classifier.ClassifyPII("Some text")
351+
352+
Expect(err).To(HaveOccurred())
353+
Expect(err.Error()).To(ContainSubstring("PII token classification error"))
354+
Expect(piiTypes).To(BeNil())
355+
})
356+
})
357+
})
358+
359+
Describe("AnalyzeContentForPII", func() {
360+
Context("when some texts contain PII", func() {
361+
It("should return detailed analysis for each text", func() {
362+
mockPIIModel.responseMap = make(map[string]PIIInferenceResponse)
363+
mockPIIModel.responseMap["Alice Smith"] = PIIInferenceResponse{
364+
classifyTokensResult: candle_binding.TokenClassificationResult{
365+
Entities: []candle_binding.TokenEntity{
366+
{
367+
EntityType: "PERSON",
368+
Text: "Alice",
369+
Start: 0,
370+
End: 5,
371+
Confidence: 0.9,
372+
},
373+
},
374+
},
375+
classifyTokensError: nil,
376+
}
377+
378+
mockPIIModel.responseMap["No PII here"] = PIIInferenceResponse{}
379+
380+
contentList := []string{"Alice Smith", "No PII here"}
381+
hasPII, results, err := classifier.AnalyzeContentForPII(contentList)
382+
383+
Expect(err).ToNot(HaveOccurred())
384+
Expect(hasPII).To(BeTrue())
385+
Expect(results).To(HaveLen(2))
386+
Expect(results[0].HasPII).To(BeTrue())
387+
Expect(results[0].Entities).To(HaveLen(1))
388+
Expect(results[0].Entities[0].EntityType).To(Equal("PERSON"))
389+
Expect(results[0].Entities[0].Text).To(Equal("Alice"))
390+
Expect(results[1].HasPII).To(BeFalse())
391+
Expect(results[1].Entities).To(BeEmpty())
392+
})
393+
})
394+
395+
Context("when model inference fails", func() {
396+
It("should return false for hasPII and empty results", func() {
397+
mockPIIModel.classifyTokensError = errors.New("model failed")
398+
399+
contentList := []string{"Text 1", "Text 2"}
400+
hasPII, results, err := classifier.AnalyzeContentForPII(contentList)
401+
402+
Expect(err).ToNot(HaveOccurred())
403+
Expect(hasPII).To(BeFalse())
404+
Expect(results).To(BeEmpty())
405+
})
406+
})
407+
})
408+
})

0 commit comments

Comments
 (0)