@@ -16,11 +16,6 @@ import (
1616 "github.com/vllm-project/semantic-router/src/semantic-router/pkg/utils/pii"
1717)
1818
19- var (
20- initialized bool
21- initMutex sync.Mutex
22- )
23-
2419// OpenAIRouter is an Envoy ExtProc server that routes OpenAI API requests
2520type OpenAIRouter struct {
2621 Config * config.RouterConfig
@@ -48,9 +43,6 @@ func NewOpenAIRouter(configPath string) (*OpenAIRouter, error) {
4843 // Update global config reference for packages that rely on config.GetConfig()
4944 config .ReplaceGlobalConfig (cfg )
5045
51- initMutex .Lock ()
52- defer initMutex .Unlock ()
53-
5446 // Load category mapping if classifier is enabled
5547 var categoryMapping * classification.CategoryMapping
5648 if cfg .Classifier .CategoryModel .CategoryMappingPath != "" {
@@ -81,11 +73,9 @@ func NewOpenAIRouter(configPath string) (*OpenAIRouter, error) {
8173 log .Printf ("Loaded jailbreak mapping with %d jailbreak types" , jailbreakMapping .GetJailbreakTypeCount ())
8274 }
8375
84- if ! initialized {
85- if err := initializeModels (cfg , categoryMapping , piiMapping , jailbreakMapping ); err != nil {
86- return nil , err
87- }
88- initialized = true
76+ // Initialize the BERT model for similarity search
77+ if err := candle_binding .InitModel (cfg .BertModel .ModelID , cfg .BertModel .UseCPU ); err != nil {
78+ return nil , fmt .Errorf ("failed to initialize BERT model: %w" , err )
8979 }
9080
9181 categoryDescriptions := cfg .GetCategoryDescriptions ()
@@ -145,19 +135,15 @@ func NewOpenAIRouter(configPath string) (*OpenAIRouter, error) {
145135
146136 // Create utility components
147137 piiChecker := pii .NewPolicyChecker (cfg , cfg .ModelConfig )
148- classifier := classification .NewClassifier (cfg , categoryMapping , piiMapping , jailbreakMapping )
138+
139+ classifier , err := classification .NewClassifier (cfg , categoryMapping , piiMapping , jailbreakMapping )
140+ if err != nil {
141+ return nil , fmt .Errorf ("failed to create classifier: %w" , err )
142+ }
149143
150144 // Create global classification service for API access
151145 services .NewClassificationService (classifier , cfg )
152146
153- // Initialize jailbreak classifier if enabled
154- if jailbreakMapping != nil {
155- err = classifier .InitializeJailbreakClassifier ()
156- if err != nil {
157- return nil , fmt .Errorf ("failed to initialize jailbreak classifier: %w" , err )
158- }
159- }
160-
161147 router := & OpenAIRouter {
162148 Config : cfg ,
163149 CategoryDescriptions : categoryDescriptions ,
@@ -173,98 +159,3 @@ func NewOpenAIRouter(configPath string) (*OpenAIRouter, error) {
173159
174160 return router , nil
175161}
176-
177- // initializeModels initializes the BERT and classifier models
178- func initializeModels (cfg * config.RouterConfig , categoryMapping * classification.CategoryMapping , piiMapping * classification.PIIMapping , jailbreakMapping * classification.JailbreakMapping ) error {
179- // Initialize the BERT model for similarity search
180- err := candle_binding .InitModel (cfg .BertModel .ModelID , cfg .BertModel .UseCPU )
181- if err != nil {
182- return fmt .Errorf ("failed to initialize BERT model: %w" , err )
183- }
184-
185- // Initialize the classifier model if enabled
186- if categoryMapping != nil {
187- // Get the number of categories from the mapping
188- numClasses := categoryMapping .GetCategoryCount ()
189- if numClasses < 2 {
190- log .Printf ("Warning: Not enough categories for classification, need at least 2, got %d" , numClasses )
191- } else {
192- // Use the category classifier model
193- classifierModelID := cfg .Classifier .CategoryModel .ModelID
194- if classifierModelID == "" {
195- classifierModelID = cfg .BertModel .ModelID
196- }
197-
198- if cfg .Classifier .CategoryModel .UseModernBERT {
199- // Initialize ModernBERT classifier
200- err = candle_binding .InitModernBertClassifier (classifierModelID , cfg .Classifier .CategoryModel .UseCPU )
201- if err != nil {
202- return fmt .Errorf ("failed to initialize ModernBERT classifier model: %w" , err )
203- }
204- log .Printf ("Initialized ModernBERT category classifier (classes auto-detected from model)" )
205- } else {
206- // Initialize linear classifier
207- err = candle_binding .InitClassifier (classifierModelID , numClasses , cfg .Classifier .CategoryModel .UseCPU )
208- if err != nil {
209- return fmt .Errorf ("failed to initialize classifier model: %w" , err )
210- }
211- log .Printf ("Initialized linear category classifier with %d categories" , numClasses )
212- }
213- }
214- }
215-
216- // Initialize PII classifier if enabled
217- if piiMapping != nil {
218- // Get the number of PII types from the mapping
219- numPIIClasses := piiMapping .GetPIITypeCount ()
220- if numPIIClasses < 2 {
221- log .Printf ("Warning: Not enough PII types for classification, need at least 2, got %d" , numPIIClasses )
222- } else {
223- // Use the PII classifier model
224- piiClassifierModelID := cfg .Classifier .PIIModel .ModelID
225- if piiClassifierModelID == "" {
226- piiClassifierModelID = cfg .BertModel .ModelID
227- }
228-
229- // Initialize ModernBERT PII token classifier for entity detection
230- err = candle_binding .InitModernBertPIITokenClassifier (piiClassifierModelID , cfg .Classifier .PIIModel .UseCPU )
231- if err != nil {
232- return fmt .Errorf ("failed to initialize ModernBERT PII token classifier model: %w" , err )
233- }
234- log .Printf ("Initialized ModernBERT PII token classifier for entity detection" )
235- }
236- }
237-
238- // Initialize jailbreak classifier if enabled
239- if jailbreakMapping != nil {
240- // Get the number of jailbreak types from the mapping
241- numJailbreakClasses := jailbreakMapping .GetJailbreakTypeCount ()
242- if numJailbreakClasses < 2 {
243- log .Printf ("Warning: Not enough jailbreak types for classification, need at least 2, got %d" , numJailbreakClasses )
244- } else {
245- // Use the jailbreak classifier model
246- jailbreakClassifierModelID := cfg .PromptGuard .ModelID
247- if jailbreakClassifierModelID == "" {
248- jailbreakClassifierModelID = cfg .BertModel .ModelID
249- }
250-
251- if cfg .PromptGuard .UseModernBERT {
252- // Initialize ModernBERT jailbreak classifier
253- err = candle_binding .InitModernBertJailbreakClassifier (jailbreakClassifierModelID , cfg .PromptGuard .UseCPU )
254- if err != nil {
255- return fmt .Errorf ("failed to initialize ModernBERT jailbreak classifier model: %w" , err )
256- }
257- log .Printf ("Initialized ModernBERT jailbreak classifier (classes auto-detected from model)" )
258- } else {
259- // Initialize linear jailbreak classifier
260- err = candle_binding .InitJailbreakClassifier (jailbreakClassifierModelID , numJailbreakClasses , cfg .PromptGuard .UseCPU )
261- if err != nil {
262- return fmt .Errorf ("failed to initialize jailbreak classifier model: %w" , err )
263- }
264- log .Printf ("Initialized linear jailbreak classifier with %d jailbreak types" , numJailbreakClasses )
265- }
266- }
267- }
268-
269- return nil
270- }
0 commit comments