Skip to content

Commit eca7b2c

Browse files
authored
refactor: move classifier model init to classifier.go and unify the classifier model init logic (vllm-project#113)
* refactor: unify model init for jailbreak Signed-off-by: Alex Wang <[email protected]> * refactor: unify model init for category/pii Signed-off-by: Alex Wang <[email protected]> --------- Signed-off-by: Alex Wang <[email protected]>
1 parent 11d3fc9 commit eca7b2c

File tree

5 files changed

+451
-381
lines changed

5 files changed

+451
-381
lines changed

src/semantic-router/pkg/extproc/router.go

Lines changed: 8 additions & 117 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,6 @@ import (
1616
"github.com/vllm-project/semantic-router/src/semantic-router/pkg/utils/pii"
1717
)
1818

19-
var (
20-
initialized bool
21-
initMutex sync.Mutex
22-
)
23-
2419
// OpenAIRouter is an Envoy ExtProc server that routes OpenAI API requests
2520
type OpenAIRouter struct {
2621
Config *config.RouterConfig
@@ -48,9 +43,6 @@ func NewOpenAIRouter(configPath string) (*OpenAIRouter, error) {
4843
// Update global config reference for packages that rely on config.GetConfig()
4944
config.ReplaceGlobalConfig(cfg)
5045

51-
initMutex.Lock()
52-
defer initMutex.Unlock()
53-
5446
// Load category mapping if classifier is enabled
5547
var categoryMapping *classification.CategoryMapping
5648
if cfg.Classifier.CategoryModel.CategoryMappingPath != "" {
@@ -81,11 +73,9 @@ func NewOpenAIRouter(configPath string) (*OpenAIRouter, error) {
8173
log.Printf("Loaded jailbreak mapping with %d jailbreak types", jailbreakMapping.GetJailbreakTypeCount())
8274
}
8375

84-
if !initialized {
85-
if err := initializeModels(cfg, categoryMapping, piiMapping, jailbreakMapping); err != nil {
86-
return nil, err
87-
}
88-
initialized = true
76+
// Initialize the BERT model for similarity search
77+
if err := candle_binding.InitModel(cfg.BertModel.ModelID, cfg.BertModel.UseCPU); err != nil {
78+
return nil, fmt.Errorf("failed to initialize BERT model: %w", err)
8979
}
9080

9181
categoryDescriptions := cfg.GetCategoryDescriptions()
@@ -145,19 +135,15 @@ func NewOpenAIRouter(configPath string) (*OpenAIRouter, error) {
145135

146136
// Create utility components
147137
piiChecker := pii.NewPolicyChecker(cfg, cfg.ModelConfig)
148-
classifier := classification.NewClassifier(cfg, categoryMapping, piiMapping, jailbreakMapping)
138+
139+
classifier, err := classification.NewClassifier(cfg, categoryMapping, piiMapping, jailbreakMapping)
140+
if err != nil {
141+
return nil, fmt.Errorf("failed to create classifier: %w", err)
142+
}
149143

150144
// Create global classification service for API access
151145
services.NewClassificationService(classifier, cfg)
152146

153-
// Initialize jailbreak classifier if enabled
154-
if jailbreakMapping != nil {
155-
err = classifier.InitializeJailbreakClassifier()
156-
if err != nil {
157-
return nil, fmt.Errorf("failed to initialize jailbreak classifier: %w", err)
158-
}
159-
}
160-
161147
router := &OpenAIRouter{
162148
Config: cfg,
163149
CategoryDescriptions: categoryDescriptions,
@@ -173,98 +159,3 @@ func NewOpenAIRouter(configPath string) (*OpenAIRouter, error) {
173159

174160
return router, nil
175161
}
176-
177-
// initializeModels initializes the BERT and classifier models
178-
func initializeModels(cfg *config.RouterConfig, categoryMapping *classification.CategoryMapping, piiMapping *classification.PIIMapping, jailbreakMapping *classification.JailbreakMapping) error {
179-
// Initialize the BERT model for similarity search
180-
err := candle_binding.InitModel(cfg.BertModel.ModelID, cfg.BertModel.UseCPU)
181-
if err != nil {
182-
return fmt.Errorf("failed to initialize BERT model: %w", err)
183-
}
184-
185-
// Initialize the classifier model if enabled
186-
if categoryMapping != nil {
187-
// Get the number of categories from the mapping
188-
numClasses := categoryMapping.GetCategoryCount()
189-
if numClasses < 2 {
190-
log.Printf("Warning: Not enough categories for classification, need at least 2, got %d", numClasses)
191-
} else {
192-
// Use the category classifier model
193-
classifierModelID := cfg.Classifier.CategoryModel.ModelID
194-
if classifierModelID == "" {
195-
classifierModelID = cfg.BertModel.ModelID
196-
}
197-
198-
if cfg.Classifier.CategoryModel.UseModernBERT {
199-
// Initialize ModernBERT classifier
200-
err = candle_binding.InitModernBertClassifier(classifierModelID, cfg.Classifier.CategoryModel.UseCPU)
201-
if err != nil {
202-
return fmt.Errorf("failed to initialize ModernBERT classifier model: %w", err)
203-
}
204-
log.Printf("Initialized ModernBERT category classifier (classes auto-detected from model)")
205-
} else {
206-
// Initialize linear classifier
207-
err = candle_binding.InitClassifier(classifierModelID, numClasses, cfg.Classifier.CategoryModel.UseCPU)
208-
if err != nil {
209-
return fmt.Errorf("failed to initialize classifier model: %w", err)
210-
}
211-
log.Printf("Initialized linear category classifier with %d categories", numClasses)
212-
}
213-
}
214-
}
215-
216-
// Initialize PII classifier if enabled
217-
if piiMapping != nil {
218-
// Get the number of PII types from the mapping
219-
numPIIClasses := piiMapping.GetPIITypeCount()
220-
if numPIIClasses < 2 {
221-
log.Printf("Warning: Not enough PII types for classification, need at least 2, got %d", numPIIClasses)
222-
} else {
223-
// Use the PII classifier model
224-
piiClassifierModelID := cfg.Classifier.PIIModel.ModelID
225-
if piiClassifierModelID == "" {
226-
piiClassifierModelID = cfg.BertModel.ModelID
227-
}
228-
229-
// Initialize ModernBERT PII token classifier for entity detection
230-
err = candle_binding.InitModernBertPIITokenClassifier(piiClassifierModelID, cfg.Classifier.PIIModel.UseCPU)
231-
if err != nil {
232-
return fmt.Errorf("failed to initialize ModernBERT PII token classifier model: %w", err)
233-
}
234-
log.Printf("Initialized ModernBERT PII token classifier for entity detection")
235-
}
236-
}
237-
238-
// Initialize jailbreak classifier if enabled
239-
if jailbreakMapping != nil {
240-
// Get the number of jailbreak types from the mapping
241-
numJailbreakClasses := jailbreakMapping.GetJailbreakTypeCount()
242-
if numJailbreakClasses < 2 {
243-
log.Printf("Warning: Not enough jailbreak types for classification, need at least 2, got %d", numJailbreakClasses)
244-
} else {
245-
// Use the jailbreak classifier model
246-
jailbreakClassifierModelID := cfg.PromptGuard.ModelID
247-
if jailbreakClassifierModelID == "" {
248-
jailbreakClassifierModelID = cfg.BertModel.ModelID
249-
}
250-
251-
if cfg.PromptGuard.UseModernBERT {
252-
// Initialize ModernBERT jailbreak classifier
253-
err = candle_binding.InitModernBertJailbreakClassifier(jailbreakClassifierModelID, cfg.PromptGuard.UseCPU)
254-
if err != nil {
255-
return fmt.Errorf("failed to initialize ModernBERT jailbreak classifier model: %w", err)
256-
}
257-
log.Printf("Initialized ModernBERT jailbreak classifier (classes auto-detected from model)")
258-
} else {
259-
// Initialize linear jailbreak classifier
260-
err = candle_binding.InitJailbreakClassifier(jailbreakClassifierModelID, numJailbreakClasses, cfg.PromptGuard.UseCPU)
261-
if err != nil {
262-
return fmt.Errorf("failed to initialize jailbreak classifier model: %w", err)
263-
}
264-
log.Printf("Initialized linear jailbreak classifier with %d jailbreak types", numJailbreakClasses)
265-
}
266-
}
267-
}
268-
269-
return nil
270-
}

src/semantic-router/pkg/extproc/security_test.go

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,9 @@ var _ = Describe("Security Checks", func() {
5252
},
5353
}
5454
router.PIIChecker = pii.NewPolicyChecker(cfg, cfg.ModelConfig)
55-
router.Classifier = classification.NewClassifier(cfg, router.Classifier.CategoryMapping, router.Classifier.PIIMapping, nil)
55+
var err error
56+
router.Classifier, err = classification.NewClassifier(cfg, router.Classifier.CategoryMapping, router.Classifier.PIIMapping, nil)
57+
Expect(err).NotTo(HaveOccurred())
5658
})
5759

5860
It("should allow requests with no PII", func() {
@@ -97,7 +99,8 @@ var _ = Describe("Security Checks", func() {
9799
piiMapping, err := classification.LoadPIIMapping(cfg.Classifier.PIIModel.PIIMappingPath)
98100
Expect(err).NotTo(HaveOccurred())
99101

100-
router.Classifier = classification.NewClassifier(cfg, router.Classifier.CategoryMapping, piiMapping, nil)
102+
router.Classifier, err = classification.NewClassifier(cfg, router.Classifier.CategoryMapping, piiMapping, nil)
103+
Expect(err).NotTo(HaveOccurred())
101104
})
102105

103106
Describe("ClassifyPII method", func() {
@@ -339,7 +342,8 @@ var _ = Describe("Security Checks", func() {
339342
piiMapping, err := classification.LoadPIIMapping(cfg.Classifier.PIIModel.PIIMappingPath)
340343
Expect(err).NotTo(HaveOccurred())
341344

342-
router.Classifier = classification.NewClassifier(cfg, router.Classifier.CategoryMapping, piiMapping, nil)
345+
router.Classifier, err = classification.NewClassifier(cfg, router.Classifier.CategoryMapping, piiMapping, nil)
346+
Expect(err).NotTo(HaveOccurred())
343347
})
344348

345349
Describe("Error handling and edge cases", func() {
@@ -516,15 +520,20 @@ var _ = Describe("Security Checks", func() {
516520
Context("with jailbreak detection enabled", func() {
517521
BeforeEach(func() {
518522
cfg.PromptGuard.Enabled = true
519-
cfg.PromptGuard.ModelID = "test-jailbreak-model"
523+
// TODO: Use a real model path here; this should be moved to an integration test later.
524+
cfg.PromptGuard.ModelID = "../../../../models/jailbreak_classifier_modernbert-base_model"
520525
cfg.PromptGuard.JailbreakMappingPath = "/path/to/jailbreak.json"
526+
cfg.PromptGuard.UseModernBERT = true
527+
cfg.PromptGuard.UseCPU = true
521528

522529
jailbreakMapping := &classification.JailbreakMapping{
523530
LabelToIdx: map[string]int{"benign": 0, "jailbreak": 1},
524531
IdxToLabel: map[string]string{"0": "benign", "1": "jailbreak"},
525532
}
526533

527-
router.Classifier = classification.NewClassifier(cfg, router.Classifier.CategoryMapping, router.Classifier.PIIMapping, jailbreakMapping)
534+
var err error
535+
router.Classifier, err = classification.NewClassifier(cfg, router.Classifier.CategoryMapping, router.Classifier.PIIMapping, jailbreakMapping)
536+
Expect(err).NotTo(HaveOccurred())
528537
})
529538

530539
It("should process potential jailbreak attempts", func() {

src/semantic-router/pkg/extproc/test_utils_test.go

Lines changed: 7 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ import (
44
"context"
55
"fmt"
66
"io"
7-
"log"
87

98
ext_proc "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3"
109
"google.golang.org/grpc/metadata"
@@ -203,10 +202,9 @@ func CreateTestRouter(cfg *config.RouterConfig) (*extproc.OpenAIRouter, error) {
203202
return nil, err
204203
}
205204

206-
// Initialize models using candle-binding
207-
err = initializeTestModels(cfg, categoryMapping, piiMapping)
208-
if err != nil {
209-
return nil, err
205+
// Initialize the BERT model for similarity search
206+
if err := candle_binding.InitModel(cfg.BertModel.ModelID, cfg.BertModel.UseCPU); err != nil {
207+
return nil, fmt.Errorf("failed to initialize BERT model: %w", err)
210208
}
211209

212210
// Create semantic cache
@@ -230,7 +228,10 @@ func CreateTestRouter(cfg *config.RouterConfig) (*extproc.OpenAIRouter, error) {
230228
toolsDatabase := tools.NewToolsDatabase(toolsOptions)
231229

232230
// Create classifier
233-
classifier := classification.NewClassifier(cfg, categoryMapping, piiMapping, nil)
231+
classifier, err := classification.NewClassifier(cfg, categoryMapping, piiMapping, nil)
232+
if err != nil {
233+
return nil, err
234+
}
234235

235236
// Create PII checker
236237
piiChecker := pii.NewPolicyChecker(cfg, cfg.ModelConfig)
@@ -250,67 +251,3 @@ func CreateTestRouter(cfg *config.RouterConfig) (*extproc.OpenAIRouter, error) {
250251

251252
return router, nil
252253
}
253-
254-
// initializeTestModels initializes the BERT and classifier models for testing
255-
func initializeTestModels(cfg *config.RouterConfig, categoryMapping *classification.CategoryMapping, piiMapping *classification.PIIMapping) error {
256-
// Initialize the BERT model for similarity search
257-
err := candle_binding.InitModel(cfg.BertModel.ModelID, cfg.BertModel.UseCPU)
258-
if err != nil {
259-
return fmt.Errorf("failed to initialize BERT model: %w", err)
260-
}
261-
262-
// Initialize the classifier model if enabled
263-
if categoryMapping != nil {
264-
// Get the number of categories from the mapping
265-
numClasses := categoryMapping.GetCategoryCount()
266-
if numClasses < 2 {
267-
log.Printf("Warning: Not enough categories for classification, need at least 2, got %d", numClasses)
268-
} else {
269-
// Use the category classifier model
270-
classifierModelID := cfg.Classifier.CategoryModel.ModelID
271-
if classifierModelID == "" {
272-
classifierModelID = cfg.BertModel.ModelID
273-
}
274-
275-
if cfg.Classifier.CategoryModel.UseModernBERT {
276-
// Initialize ModernBERT classifier
277-
err = candle_binding.InitModernBertClassifier(classifierModelID, cfg.Classifier.CategoryModel.UseCPU)
278-
if err != nil {
279-
return fmt.Errorf("failed to initialize ModernBERT classifier model: %w", err)
280-
}
281-
log.Printf("Initialized ModernBERT category classifier (classes auto-detected from model)")
282-
} else {
283-
// Initialize linear classifier
284-
err = candle_binding.InitClassifier(classifierModelID, numClasses, cfg.Classifier.CategoryModel.UseCPU)
285-
if err != nil {
286-
return fmt.Errorf("failed to initialize classifier model: %w", err)
287-
}
288-
log.Printf("Initialized linear category classifier with %d categories", numClasses)
289-
}
290-
}
291-
}
292-
293-
// Initialize PII classifier if enabled
294-
if piiMapping != nil {
295-
// Get the number of PII types from the mapping
296-
numPIIClasses := piiMapping.GetPIITypeCount()
297-
if numPIIClasses < 2 {
298-
log.Printf("Warning: Not enough PII types for classification, need at least 2, got %d", numPIIClasses)
299-
} else {
300-
// Use the PII classifier model
301-
piiClassifierModelID := cfg.Classifier.PIIModel.ModelID
302-
if piiClassifierModelID == "" {
303-
piiClassifierModelID = cfg.BertModel.ModelID
304-
}
305-
306-
// Initialize ModernBERT PII token classifier for entity detection
307-
err = candle_binding.InitModernBertPIITokenClassifier(piiClassifierModelID, cfg.Classifier.PIIModel.UseCPU)
308-
if err != nil {
309-
return fmt.Errorf("failed to initialize ModernBERT PII token classifier model: %w", err)
310-
}
311-
log.Printf("Initialized ModernBERT PII token classifier for entity detection")
312-
}
313-
}
314-
315-
return nil
316-
}

0 commit comments

Comments
 (0)