@@ -12,6 +12,69 @@ import (
1212 "github.com/vllm-project/semantic-router/semantic-router/pkg/metrics"
1313)
1414
15+ type CategoryInference interface {
16+ Classify (text string ) (candle_binding.ClassResult , error )
17+ }
18+
19+ type LinearCategoryInference struct {}
20+
21+ func (c * LinearCategoryInference ) Classify (text string ) (candle_binding.ClassResult , error ) {
22+ return candle_binding .ClassifyText (text )
23+ }
24+
25+ type ModernBertCategoryInference struct {}
26+
27+ func (c * ModernBertCategoryInference ) Classify (text string ) (candle_binding.ClassResult , error ) {
28+ return candle_binding .ClassifyModernBertText (text )
29+ }
30+
31+ // createCategoryInference creates the appropriate category inference based on configuration
32+ func createCategoryInference (useModernBERT bool ) CategoryInference {
33+ if useModernBERT {
34+ return & ModernBertCategoryInference {}
35+ }
36+ return & LinearCategoryInference {}
37+ }
38+
39+ type JailbreakInference interface {
40+ Classify (text string ) (candle_binding.ClassResult , error )
41+ }
42+
43+ type LinearJailbreakInference struct {}
44+
45+ func (c * LinearJailbreakInference ) Classify (text string ) (candle_binding.ClassResult , error ) {
46+ return candle_binding .ClassifyJailbreakText (text )
47+ }
48+
49+ type ModernBertJailbreakInference struct {}
50+
51+ func (c * ModernBertJailbreakInference ) Classify (text string ) (candle_binding.ClassResult , error ) {
52+ return candle_binding .ClassifyModernBertJailbreakText (text )
53+ }
54+
55+ // createJailbreakInference creates the appropriate jailbreak inference based on configuration
56+ func createJailbreakInference (useModernBERT bool ) JailbreakInference {
57+ if useModernBERT {
58+ return & ModernBertJailbreakInference {}
59+ }
60+ return & LinearJailbreakInference {}
61+ }
62+
63+ type PIIInference interface {
64+ ClassifyTokens (text string , configPath string ) (candle_binding.TokenClassificationResult , error )
65+ }
66+
67+ type ModernBertPIIInference struct {}
68+
69+ func (c * ModernBertPIIInference ) ClassifyTokens (text string , configPath string ) (candle_binding.TokenClassificationResult , error ) {
70+ return candle_binding .ClassifyModernBertPIITokens (text , configPath )
71+ }
72+
73+ // createPIIInference creates the appropriate PII inference (currently only ModernBERT)
74+ func createPIIInference () PIIInference {
75+ return & ModernBertPIIInference {}
76+ }
77+
1578// JailbreakDetection represents the result of jailbreak analysis for a piece of content
1679type JailbreakDetection struct {
1780 Content string `json:"content"`
@@ -40,6 +103,11 @@ type PIIAnalysisResult struct {
40103
41104// Classifier handles text classification, model selection, and jailbreak detection functionality
42105type Classifier struct {
106+ // Dependencies
107+ categoryInference CategoryInference
108+ jailbreakInference JailbreakInference
109+ piiInference PIIInference
110+
43111 Config * config.RouterConfig
44112 CategoryMapping * CategoryMapping
45113 PIIMapping * PIIMapping
@@ -55,6 +123,10 @@ type Classifier struct {
55123// NewClassifier creates a new classifier with model selection and jailbreak detection capabilities
56124func NewClassifier (cfg * config.RouterConfig , categoryMapping * CategoryMapping , piiMapping * PIIMapping , jailbreakMapping * JailbreakMapping , modelTTFT map [string ]float64 ) * Classifier {
57125 return & Classifier {
126+ categoryInference : createCategoryInference (cfg .Classifier .CategoryModel .UseModernBERT ),
127+ jailbreakInference : createJailbreakInference (cfg .PromptGuard .UseModernBERT ),
128+ piiInference : createPIIInference (),
129+
58130 Config : cfg ,
59131 CategoryMapping : categoryMapping ,
60132 PIIMapping : piiMapping ,
@@ -117,13 +189,7 @@ func (c *Classifier) CheckForJailbreak(text string) (bool, string, float32, erro
117189 var err error
118190
119191 start := time .Now ()
120- if c .Config .PromptGuard .UseModernBERT {
121- // Use ModernBERT jailbreak classifier
122- result , err = candle_binding .ClassifyModernBertJailbreakText (text )
123- } else {
124- // Use linear jailbreak classifier
125- result , err = candle_binding .ClassifyJailbreakText (text )
126- }
192+ result , err = c .jailbreakInference .Classify (text )
127193 metrics .RecordClassifierLatency ("jailbreak" , time .Since (start ).Seconds ())
128194
129195 if err != nil {
@@ -200,13 +266,7 @@ func (c *Classifier) ClassifyCategory(text string) (string, float64, error) {
200266 var err error
201267
202268 start := time .Now ()
203- if c .Config .Classifier .CategoryModel .UseModernBERT {
204- // Use ModernBERT classifier
205- result , err = candle_binding .ClassifyModernBertText (text )
206- } else {
207- // Use linear classifier
208- result , err = candle_binding .ClassifyText (text )
209- }
269+ result , err = c .categoryInference .Classify (text )
210270 metrics .RecordClassifierLatency ("category" , time .Since (start ).Seconds ())
211271
212272 if err != nil {
@@ -249,7 +309,7 @@ func (c *Classifier) ClassifyPII(text string) ([]string, error) {
249309 // Use ModernBERT PII token classifier for entity detection
250310 configPath := fmt .Sprintf ("%s/config.json" , c .Config .Classifier .PIIModel .ModelID )
251311 start := time .Now ()
252- tokenResult , err := candle_binding . ClassifyModernBertPIITokens (text , configPath )
312+ tokenResult , err := c . piiInference . ClassifyTokens (text , configPath )
253313 metrics .RecordClassifierLatency ("pii" , time .Since (start ).Seconds ())
254314 if err != nil {
255315 return nil , fmt .Errorf ("PII token classification error: %w" , err )
@@ -331,7 +391,7 @@ func (c *Classifier) AnalyzeContentForPII(contentList []string) (bool, []PIIAnal
331391 // Use ModernBERT PII token classifier for detailed analysis
332392 configPath := fmt .Sprintf ("%s/config.json" , c .Config .Classifier .PIIModel .ModelID )
333393 start := time .Now ()
334- tokenResult , err := candle_binding . ClassifyModernBertPIITokens (content , configPath )
394+ tokenResult , err := c . piiInference . ClassifyTokens (content , configPath )
335395 metrics .RecordClassifierLatency ("pii" , time .Since (start ).Seconds ())
336396 if err != nil {
337397 log .Printf ("Error analyzing content %d: %v" , i , err )
0 commit comments