Skip to content

Commit 2eeeda8

Browse files
authored
refactor(config): restructure config to use nested model objects (#577)
* refactor(config): restructure config to use nested model objects This commit refactors the RouterConfig structure to organize related configuration fields into nested objects for better modularity and clarity: - BertModel: Groups BERT similarity model configuration (ModelID, Threshold, UseCPU) - MCPCategoryModel: Groups MCP category classifier configuration (Enabled, TransportType, Command, Args, etc.) - SemanticCache: Groups semantic cache configuration (Enabled, BackendType, MaxEntries, TTLSeconds, etc.) Updated all references across the codebase: - pkg/apiserver/route_model_info.go: Access BERT config via BertModel nested object - pkg/classification/classifier.go: Check MCP enabled via MCPCategoryModel.Enabled - pkg/classification/mcp_classifier.go: Access all MCP config via MCPCategoryModel - pkg/extproc/request_handler.go: Access cache config via SemanticCache.Enabled - pkg/extproc/router.go: Access BERT and cache config via nested objects - pkg/services/classification.go: Check MCP enabled via MCPCategoryModel.Enabled This improves code organization and makes the configuration structure more maintainable and easier to understand. Signed-off-by: bitliu <[email protected]> * test Signed-off-by: bitliu <[email protected]> --------- Signed-off-by: bitliu <[email protected]>
1 parent 6be225c commit 2eeeda8

File tree

7 files changed

+66
-67
lines changed

7 files changed

+66
-67
lines changed

src/semantic-router/pkg/apiserver/route_model_info.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -130,15 +130,15 @@ func (s *ClassificationAPIServer) getLoadedModelsInfo() []ModelInfo {
130130
}
131131

132132
// BERT similarity model
133-
if s.config.ModelID != "" {
133+
if s.config.BertModel.ModelID != "" {
134134
models = append(models, ModelInfo{
135135
Name: "bert_similarity_model",
136136
Type: "similarity",
137137
Loaded: true,
138-
ModelPath: s.config.ModelID,
138+
ModelPath: s.config.BertModel.ModelID,
139139
Metadata: map[string]string{
140140
"model_type": "sentence_transformer",
141-
"threshold": fmt.Sprintf("%.2f", s.config.Threshold),
141+
"threshold": fmt.Sprintf("%.2f", s.config.BertModel.Threshold),
142142
"use_cpu": fmt.Sprintf("%t", s.config.BertModel.UseCPU),
143143
},
144144
})

src/semantic-router/pkg/classification/classifier.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -329,7 +329,7 @@ func NewClassifier(cfg *config.RouterConfig, categoryMapping *CategoryMapping, p
329329
// Note: Both in-tree and MCP classifiers can be configured simultaneously.
330330
// At runtime, in-tree classifier will be tried first, with MCP as a fallback.
331331
// This allows flexible deployment scenarios (e.g., gradual migration, A/B testing).
332-
if cfg.Enabled {
332+
if cfg.MCPCategoryModel.Enabled {
333333
mcpInit := createMCPCategoryInitializer()
334334
mcpInf := createMCPCategoryInference(mcpInit)
335335
options = append(options, withMCPCategory(mcpInit, mcpInf))

src/semantic-router/pkg/classification/classifier_test.go

Lines changed: 34 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -2335,10 +2335,38 @@ var _ = Describe("Classifier MCP Methods", func() {
23352335
BeforeEach(func() {
23362336
mockClient = &MockMCPClient{}
23372337
cfg := &config.RouterConfig{}
2338-
cfg.Enabled = true
2339-
cfg.ToolName = "classify_text"
2338+
cfg.MCPCategoryModel.Enabled = true
2339+
cfg.MCPCategoryModel.ToolName = "classify_text"
23402340
cfg.MCPCategoryModel.Threshold = 0.5
2341-
cfg.TimeoutSeconds = 30
2341+
cfg.MCPCategoryModel.TimeoutSeconds = 30
2342+
2343+
// Add Categories configuration for entropy-based tests
2344+
cfg.Categories = []config.Category{
2345+
{
2346+
CategoryMetadata: config.CategoryMetadata{Name: "tech"},
2347+
ModelScores: []config.ModelScore{{
2348+
Model: "phi4",
2349+
Score: 0.8,
2350+
ModelReasoningControl: config.ModelReasoningControl{UseReasoning: lo.ToPtr(false)},
2351+
}},
2352+
},
2353+
{
2354+
CategoryMetadata: config.CategoryMetadata{Name: "sports"},
2355+
ModelScores: []config.ModelScore{{
2356+
Model: "phi4",
2357+
Score: 0.8,
2358+
ModelReasoningControl: config.ModelReasoningControl{UseReasoning: lo.ToPtr(false)},
2359+
}},
2360+
},
2361+
{
2362+
CategoryMetadata: config.CategoryMetadata{Name: "politics"},
2363+
ModelScores: []config.ModelScore{{
2364+
Model: "deepseek-v31",
2365+
Score: 0.9,
2366+
ModelReasoningControl: config.ModelReasoningControl{UseReasoning: lo.ToPtr(true)},
2367+
}},
2368+
},
2369+
}
23422370

23432371
// Create MCP classifier manually and inject mock client
23442372
mcpClassifier := &MCPCategoryClassifier{
@@ -2374,7 +2402,7 @@ var _ = Describe("Classifier MCP Methods", func() {
23742402
})
23752403

23762404
It("should return false when not enabled", func() {
2377-
classifier.Config.Enabled = false
2405+
classifier.Config.MCPCategoryModel.Enabled = false
23782406
Expect(classifier.IsMCPCategoryEnabled()).To(BeFalse())
23792407
})
23802408

@@ -2387,7 +2415,7 @@ var _ = Describe("Classifier MCP Methods", func() {
23872415
Describe("classifyCategoryMCP", func() {
23882416
Context("when MCP is not enabled", func() {
23892417
It("should return error", func() {
2390-
classifier.Config.Enabled = false
2418+
classifier.Config.MCPCategoryModel.Enabled = false
23912419
_, _, err := classifier.classifyCategoryMCP("test text")
23922420
Expect(err).To(HaveOccurred())
23932421
Expect(err.Error()).To(ContainSubstring("not properly configured"))
@@ -2463,35 +2491,6 @@ var _ = Describe("Classifier MCP Methods", func() {
24632491
})
24642492

24652493
Describe("classifyCategoryWithEntropyMCP", func() {
2466-
BeforeEach(func() {
2467-
classifier.Config.Categories = []config.Category{
2468-
{
2469-
CategoryMetadata: config.CategoryMetadata{Name: "tech"},
2470-
ModelScores: []config.ModelScore{{
2471-
Model: "phi4",
2472-
Score: 0.8,
2473-
ModelReasoningControl: config.ModelReasoningControl{UseReasoning: lo.ToPtr(false)},
2474-
}},
2475-
},
2476-
{
2477-
CategoryMetadata: config.CategoryMetadata{Name: "sports"},
2478-
ModelScores: []config.ModelScore{{
2479-
Model: "phi4",
2480-
Score: 0.8,
2481-
ModelReasoningControl: config.ModelReasoningControl{UseReasoning: lo.ToPtr(false)},
2482-
}},
2483-
},
2484-
{
2485-
CategoryMetadata: config.CategoryMetadata{Name: "politics"},
2486-
ModelScores: []config.ModelScore{{
2487-
Model: "deepseek-v31",
2488-
Score: 0.9,
2489-
ModelReasoningControl: config.ModelReasoningControl{UseReasoning: lo.ToPtr(true)},
2490-
}},
2491-
},
2492-
}
2493-
})
2494-
24952494
Context("when MCP returns probabilities", func() {
24962495
It("should return category with entropy decision", func() {
24972496
mockClient.callToolResult = &mcp.CallToolResult{
@@ -2536,7 +2535,7 @@ var _ = Describe("Classifier MCP Methods", func() {
25362535
Describe("initializeMCPCategoryClassifier", func() {
25372536
Context("when MCP is not enabled", func() {
25382537
It("should return error", func() {
2539-
classifier.Config.Enabled = false
2538+
classifier.Config.MCPCategoryModel.Enabled = false
25402539
err := classifier.initializeMCPCategoryClassifier()
25412540
Expect(err).To(HaveOccurred())
25422541
Expect(err.Error()).To(ContainSubstring("not properly configured"))

src/semantic-router/pkg/classification/mcp_classifier.go

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ func (m *MCPCategoryClassifier) Init(cfg *config.RouterConfig) error {
7575
}
7676

7777
// Validate MCP configuration
78-
if !cfg.Enabled {
78+
if !cfg.MCPCategoryModel.Enabled {
7979
return fmt.Errorf("MCP category classifier is not enabled")
8080
}
8181

@@ -84,19 +84,19 @@ func (m *MCPCategoryClassifier) Init(cfg *config.RouterConfig) error {
8484

8585
// Create MCP client configuration
8686
mcpConfig := mcpclient.ClientConfig{
87-
TransportType: cfg.TransportType,
88-
Command: cfg.Command,
89-
Args: cfg.Args,
90-
Env: cfg.Env,
91-
URL: cfg.URL,
87+
TransportType: cfg.MCPCategoryModel.TransportType,
88+
Command: cfg.MCPCategoryModel.Command,
89+
Args: cfg.MCPCategoryModel.Args,
90+
Env: cfg.MCPCategoryModel.Env,
91+
URL: cfg.MCPCategoryModel.URL,
9292
Options: mcpclient.ClientOptions{
9393
LogEnabled: true,
9494
},
9595
}
9696

9797
// Set timeout if specified
98-
if cfg.TimeoutSeconds > 0 {
99-
mcpConfig.Timeout = time.Duration(cfg.TimeoutSeconds) * time.Second
98+
if cfg.MCPCategoryModel.TimeoutSeconds > 0 {
99+
mcpConfig.Timeout = time.Duration(cfg.MCPCategoryModel.TimeoutSeconds) * time.Second
100100
}
101101

102102
// Create MCP client
@@ -125,8 +125,8 @@ func (m *MCPCategoryClassifier) Init(cfg *config.RouterConfig) error {
125125
// discoverClassificationTool finds the appropriate classification tool from available MCP tools
126126
func (m *MCPCategoryClassifier) discoverClassificationTool() error {
127127
// If tool name is explicitly specified, use it
128-
if m.config.ToolName != "" {
129-
m.toolName = m.config.ToolName
128+
if m.config.MCPCategoryModel.ToolName != "" {
129+
m.toolName = m.config.MCPCategoryModel.ToolName
130130
logging.Infof("Using explicitly configured tool: %s", m.toolName)
131131
return nil
132132
}
@@ -356,7 +356,7 @@ func createMCPCategoryInference(initializer MCPCategoryInitializer) MCPCategoryI
356356
// IsMCPCategoryEnabled checks if MCP-based category classification is properly configured.
357357
// Note: tool_name is optional and will be auto-discovered during initialization if not specified.
358358
func (c *Classifier) IsMCPCategoryEnabled() bool {
359-
return c.Config.Enabled
359+
return c.Config.MCPCategoryModel.Enabled
360360
}
361361

362362
// initializeMCPCategoryClassifier initializes the MCP category classification model
@@ -380,9 +380,9 @@ func (c *Classifier) initializeMCPCategoryClassifier() error {
380380

381381
// Create a context with timeout for the list_categories call
382382
ctx := context.Background()
383-
if c.Config.TimeoutSeconds > 0 {
383+
if c.Config.MCPCategoryModel.TimeoutSeconds > 0 {
384384
var cancel context.CancelFunc
385-
ctx, cancel = context.WithTimeout(ctx, time.Duration(c.Config.TimeoutSeconds)*time.Second)
385+
ctx, cancel = context.WithTimeout(ctx, time.Duration(c.Config.MCPCategoryModel.TimeoutSeconds)*time.Second)
386386
defer cancel()
387387
}
388388

@@ -422,9 +422,9 @@ func (c *Classifier) classifyCategoryMCPWithRouting(text string) (*MCPClassifica
422422

423423
// Create context with timeout
424424
ctx := context.Background()
425-
if c.Config.TimeoutSeconds > 0 {
425+
if c.Config.MCPCategoryModel.TimeoutSeconds > 0 {
426426
var cancel context.CancelFunc
427-
ctx, cancel = context.WithTimeout(ctx, time.Duration(c.Config.TimeoutSeconds)*time.Second)
427+
ctx, cancel = context.WithTimeout(ctx, time.Duration(c.Config.MCPCategoryModel.TimeoutSeconds)*time.Second)
428428
defer cancel()
429429
}
430430

@@ -530,9 +530,9 @@ func (c *Classifier) classifyCategoryWithEntropyMCP(text string) (string, float6
530530

531531
// Create context with timeout
532532
ctx := context.Background()
533-
if c.Config.TimeoutSeconds > 0 {
533+
if c.Config.MCPCategoryModel.TimeoutSeconds > 0 {
534534
var cancel context.CancelFunc
535-
ctx, cancel = context.WithTimeout(ctx, time.Duration(c.Config.TimeoutSeconds)*time.Second)
535+
ctx, cancel = context.WithTimeout(ctx, time.Duration(c.Config.MCPCategoryModel.TimeoutSeconds)*time.Second)
536536
defer cancel()
537537
}
538538

src/semantic-router/pkg/extproc/request_handler.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -520,7 +520,7 @@ func (r *OpenAIRouter) handleCaching(ctx *RequestContext, categoryName string) (
520520
ctx.RequestQuery = requestQuery
521521

522522
// Check if caching is enabled for this category
523-
cacheEnabled := r.Config.Enabled
523+
cacheEnabled := r.Config.SemanticCache.Enabled
524524
if categoryName != "" {
525525
cacheEnabled = r.Config.IsCacheEnabledForCategory(categoryName)
526526
}

src/semantic-router/pkg/extproc/router.go

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ func NewOpenAIRouter(configPath string) (*OpenAIRouter, error) {
6969
}
7070

7171
// Initialize the BERT model for similarity search
72-
if initErr := candle_binding.InitModel(cfg.ModelID, cfg.BertModel.UseCPU); initErr != nil {
72+
if initErr := candle_binding.InitModel(cfg.BertModel.ModelID, cfg.BertModel.UseCPU); initErr != nil {
7373
return nil, fmt.Errorf("failed to initialize BERT model: %w", initErr)
7474
}
7575

@@ -78,14 +78,14 @@ func NewOpenAIRouter(configPath string) (*OpenAIRouter, error) {
7878

7979
// Create semantic cache with config options
8080
cacheConfig := cache.CacheConfig{
81-
BackendType: cache.CacheBackendType(cfg.BackendType),
82-
Enabled: cfg.Enabled,
81+
BackendType: cache.CacheBackendType(cfg.SemanticCache.BackendType),
82+
Enabled: cfg.SemanticCache.Enabled,
8383
SimilarityThreshold: cfg.GetCacheSimilarityThreshold(),
84-
MaxEntries: cfg.MaxEntries,
85-
TTLSeconds: cfg.TTLSeconds,
86-
EvictionPolicy: cache.EvictionPolicyType(cfg.EvictionPolicy),
87-
BackendConfigPath: cfg.BackendConfigPath,
88-
EmbeddingModel: cfg.EmbeddingModel,
84+
MaxEntries: cfg.SemanticCache.MaxEntries,
85+
TTLSeconds: cfg.SemanticCache.TTLSeconds,
86+
EvictionPolicy: cache.EvictionPolicyType(cfg.SemanticCache.EvictionPolicy),
87+
BackendConfigPath: cfg.SemanticCache.BackendConfigPath,
88+
EmbeddingModel: cfg.SemanticCache.EmbeddingModel,
8989
}
9090

9191
// Use default backend type if not specified
@@ -109,7 +109,7 @@ func NewOpenAIRouter(configPath string) (*OpenAIRouter, error) {
109109
}
110110

111111
// Create tools database with config options
112-
toolsThreshold := cfg.Threshold // Default to BERT threshold
112+
toolsThreshold := cfg.BertModel.Threshold // Default to BERT threshold
113113
if cfg.Tools.SimilarityThreshold != nil {
114114
toolsThreshold = *cfg.Tools.SimilarityThreshold
115115
}

src/semantic-router/pkg/services/classification.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ func createLegacyClassifier(config *config.RouterConfig) (*classification.Classi
8787
// Check if we should load categories from MCP server
8888
// Note: tool_name is optional and will be auto-discovered if not specified
8989
useMCPCategories := config.CategoryModel.ModelID == "" &&
90-
config.Enabled
90+
config.MCPCategoryModel.Enabled
9191

9292
if useMCPCategories {
9393
// Categories will be loaded from MCP server during initialization

0 commit comments

Comments
 (0)