@@ -11,6 +11,7 @@ import (
1111 candle_binding "github.com/vllm-project/semantic-router/candle-binding"
1212 "github.com/vllm-project/semantic-router/src/semantic-router/pkg/config"
1313 mcpclient "github.com/vllm-project/semantic-router/src/semantic-router/pkg/connectivity/mcp"
14+ "github.com/vllm-project/semantic-router/src/semantic-router/pkg/connectivity/mcp/api"
1415 "github.com/vllm-project/semantic-router/src/semantic-router/pkg/metrics"
1516 "github.com/vllm-project/semantic-router/src/semantic-router/pkg/observability"
1617 "github.com/vllm-project/semantic-router/src/semantic-router/pkg/utils/entropy"
@@ -45,7 +46,21 @@ type MCPCategoryInference interface {
4546 ListCategories (ctx context.Context ) (* CategoryMapping , error )
4647}
4748
48- // MCPCategoryClassifier implements both MCPCategoryInitializer and MCPCategoryInference
49+ // MCPCategoryClassifier implements both MCPCategoryInitializer and MCPCategoryInference.
50+ //
51+ // Protocol Contract:
52+ // This client relies on the MCP server to respect the protocol defined in the
53+ // github.com/vllm-project/semantic-router/src/semantic-router/pkg/connectivity/mcp/api package.
54+ //
55+ // The MCP server must implement these tools:
56+ // 1. list_categories - Returns api.ListCategoriesResponse
57+ // 2. classify_text - Returns api.ClassifyResponse or api.ClassifyWithProbabilitiesResponse
58+ //
59+ // The MCP server controls both classification AND routing decisions. When the server returns
60+ // "model" and "use_reasoning" in the classification response, the router uses those values.
61+ // If the server omits them, the router falls back to the default_model configuration.
62+ //
63+ // For detailed type definitions and examples, see the api package documentation.
4964type MCPCategoryClassifier struct {
5065 client mcpclient.MCPClient
5166 toolName string
@@ -202,13 +217,8 @@ func (m *MCPCategoryClassifier) Classify(ctx context.Context, text string) (cand
202217 return candle_binding.ClassResult {}, fmt .Errorf ("MCP tool returned non-text content" )
203218 }
204219
205- // Parse JSON response: {"class": int, "confidence": float, "model": str, "use_reasoning": bool}
206- var response struct {
207- Class int `json:"class"`
208- Confidence float32 `json:"confidence"`
209- Model string `json:"model,omitempty"`
210- UseReasoning * bool `json:"use_reasoning,omitempty"`
211- }
220+ // Parse JSON response using the API type
221+ var response api.ClassifyResponse
212222 if err := json .Unmarshal ([]byte (responseText ), & response ); err != nil {
213223 return candle_binding.ClassResult {}, fmt .Errorf ("failed to parse MCP response: %w" , err )
214224 }
@@ -256,14 +266,8 @@ func (m *MCPCategoryClassifier) ClassifyWithProbabilities(ctx context.Context, t
256266 return candle_binding.ClassResultWithProbs {}, fmt .Errorf ("MCP tool returned non-text content" )
257267 }
258268
259- // Parse JSON response: {"class": int, "confidence": float, "probabilities": []float, "model": str, "use_reasoning": bool}
260- var response struct {
261- Class int `json:"class"`
262- Confidence float32 `json:"confidence"`
263- Probabilities []float32 `json:"probabilities"`
264- Model string `json:"model,omitempty"`
265- UseReasoning * bool `json:"use_reasoning,omitempty"`
266- }
269+ // Parse JSON response using the API type
270+ var response api.ClassifyWithProbabilitiesResponse
267271 if err := json .Unmarshal ([]byte (responseText ), & response ); err != nil {
268272 return candle_binding.ClassResultWithProbs {}, fmt .Errorf ("failed to parse MCP response: %w" , err )
269273 }
@@ -306,10 +310,8 @@ func (m *MCPCategoryClassifier) ListCategories(ctx context.Context) (*CategoryMa
306310 return nil , fmt .Errorf ("MCP tool returned non-text content" )
307311 }
308312
309- // Parse JSON response: {"categories": ["cat1", "cat2", ...]}
310- var response struct {
311- Categories []string `json:"categories"`
312- }
313+ // Parse JSON response using the API type
314+ var response api.ListCategoriesResponse
313315 if err := json .Unmarshal ([]byte (responseText ), & response ); err != nil {
314316 return nil , fmt .Errorf ("failed to parse MCP categories response: %w" , err )
315317 }
@@ -342,10 +344,10 @@ func createMCPCategoryInference(initializer MCPCategoryInitializer) MCPCategoryI
342344 return nil
343345}
344346
345- // IsMCPCategoryEnabled checks if MCP-based category classification is properly configured
347+ // IsMCPCategoryEnabled checks if MCP-based category classification is properly configured.
348+ // Note: tool_name is optional and will be auto-discovered during initialization if not specified.
346349func (c * Classifier ) IsMCPCategoryEnabled () bool {
347- return c .Config .Classifier .MCPCategoryModel .Enabled &&
348- c .Config .Classifier .MCPCategoryModel .ToolName != ""
350+ return c .Config .Classifier .MCPCategoryModel .Enabled
349351}
350352
351353// initializeMCPCategoryClassifier initializes the MCP category classification model
@@ -455,13 +457,8 @@ func (c *Classifier) classifyCategoryMCPWithRouting(text string) (*MCPClassifica
455457 return nil , fmt .Errorf ("MCP tool returned non-text content" )
456458 }
457459
458- // Parse JSON response with routing information
459- var response struct {
460- Class int `json:"class"`
461- Confidence float32 `json:"confidence"`
462- Model string `json:"model,omitempty"`
463- UseReasoning * bool `json:"use_reasoning,omitempty"`
464- }
460+ // Parse JSON response with routing information using the API type
461+ var response api.ClassifyResponse
465462 if err := json .Unmarshal ([]byte (responseText ), & response ); err != nil {
466463 return nil , fmt .Errorf ("failed to parse MCP response: %w" , err )
467464 }
0 commit comments