diff --git a/config/config-mcp-classifier-example.yaml b/config/config-mcp-classifier-example.yaml new file mode 100644 index 00000000..8fc25e37 --- /dev/null +++ b/config/config-mcp-classifier-example.yaml @@ -0,0 +1,168 @@ +# Example Configuration for MCP-Based Category Classifier (HTTP Transport) +# +# This configuration demonstrates how to use an external MCP (Model Context Protocol) +# service via HTTP for category classification instead of the built-in Candle/ModernBERT models. +# +# Use cases: +# - Offload classification to a remote HTTP service +# - Use custom classification models not supported in-tree +# - Scale classification independently from the router +# - Integrate with existing ML infrastructure via REST API +# +# Note: This example uses HTTP transport. The MCP server should expose an HTTP endpoint +# that implements the MCP protocol (e.g., http://localhost:8090/mcp) + +# BERT model for semantic caching and tool selection +bert_model: + model_id: "sentence-transformers/all-MiniLM-L6-v2" + threshold: 0.85 + use_cpu: true + +# Classifier configuration +classifier: + # Disable in-tree category classifier (leave model_id empty) + category_model: + model_id: "" # Empty = disabled + + # Enable MCP-based category classifier (HTTP transport only) + mcp_category_model: + enabled: true # Enable MCP classifier + transport_type: "http" # HTTP transport + url: "http://localhost:8090/mcp" # MCP server endpoint + + # tool_name: Optional - auto-discovers classification tool if not specified + # Will search for tools like: classify_text, classify, categorize, etc. + # Uncomment to explicitly specify: + # tool_name: "classify_text" + + threshold: 0.6 # Confidence threshold + timeout_seconds: 30 # Request timeout + +# Categories for routing queries +# +# Categories are automatically loaded from MCP server via 'list_categories' tool. +# The MCP server controls BOTH classification AND routing decisions. +# +# How it works: +# 1. 
Router connects to MCP server at startup +# 2. Calls 'list_categories' tool: MCP returns {"categories": ["business", "law", ...]} +# 3. For each request, calls 'classify_text' tool which returns: +# { +# "class": 3, +# "confidence": 0.85, +# "model": "openai/gpt-oss-20b", # MCP decides which model to use +# "use_reasoning": true # MCP decides whether to use reasoning +# } +# 4. Router uses the model and reasoning settings from MCP response +# +# BENEFITS: +# - MCP server makes intelligent routing decisions per query +# - No hardcoded routing rules needed in config +# - MCP can adapt routing based on query complexity, content, etc. +# - Centralized routing logic in MCP server +# +# FALLBACK: +# - If MCP doesn't return model/use_reasoning, uses default_model below +# - Can also add category-specific overrides here if needed +# +categories: [] + +# Default model to use when category can't be determined +default_model: openai/gpt-oss-20b + +# vLLM endpoints configuration +vllm_endpoints: + - name: endpoint1 + address: 127.0.0.1 + port: 8000 + models: + - openai/gpt-oss-20b + weight: 1 + health_check_path: /health + +# Model-specific configuration +model_config: + openai/gpt-oss-20b: + reasoning_family: gpt-oss + preferred_endpoints: + - endpoint1 + pii_policy: + allow_by_default: true + +# Reasoning family configurations +reasoning_families: + deepseek: + type: chat_template_kwargs + parameter: thinking + qwen3: + type: chat_template_kwargs + parameter: enable_thinking + gpt-oss: + type: reasoning_effort + parameter: reasoning_effort + gpt: + type: reasoning_effort + parameter: reasoning_effort + +# Default reasoning effort level +default_reasoning_effort: high + +# Tools configuration (optional) +tools: + enabled: false + top_k: 5 + similarity_threshold: 0.7 + tools_db_path: "config/tools_db.json" + fallback_to_empty: true + +# API configuration +api: + batch_classification: + max_batch_size: 100 + concurrency_threshold: 5 + max_concurrency: 8 + metrics: + enabled: 
true + detailed_goroutine_tracking: true + high_resolution_timing: false + sample_rate: 1.0 + duration_buckets: + - 0.001 + - 0.005 + - 0.01 + - 0.025 + - 0.05 + - 0.1 + - 0.25 + - 0.5 + - 1 + - 2.5 + - 5 + - 10 + - 30 + size_buckets: + - 1 + - 2 + - 5 + - 10 + - 20 + - 50 + - 100 + - 200 + +# Observability configuration +observability: + tracing: + enabled: false + provider: "opentelemetry" + exporter: + type: "otlp" + endpoint: "localhost:4317" + insecure: true + sampling: + type: "always_on" + resource: + service_name: "semantic-router" + service_version: "1.0.0" + deployment_environment: "production" + diff --git a/examples/mcp-classifier-server/README.md b/examples/mcp-classifier-server/README.md new file mode 100644 index 00000000..f8aaf7f1 --- /dev/null +++ b/examples/mcp-classifier-server/README.md @@ -0,0 +1,134 @@ +# MCP Classification Server + +Example MCP server that provides text classification with intelligent routing for the semantic router. + +## Features + +- **Dynamic Categories**: Loaded from MCP server at runtime via `list_categories` +- **Intelligent Routing**: Returns `model` and `use_reasoning` in classification response +- **Regex-Based**: Simple pattern matching (replace with ML models for production) +- **Dual Transport**: Supports both HTTP and stdio + +## Categories + +| Index | Category | Example Keywords | +|-------|----------|------------------| +| 0 | math | calculate, equation, formula, integral | +| 1 | science | physics, chemistry, biology, atom, DNA | +| 2 | technology | computer, programming, AI, cloud | +| 3 | history | ancient, war, empire, civilization | +| 4 | general | Catch-all for other queries | + +## Quick Start + +```bash +# Install dependencies +pip install -r requirements.txt + +# HTTP mode (for semantic router) +python server.py --http --port 8090 + +# Stdio mode (for MCP clients) +python server.py +``` + +**Test the server:** + +```bash +curl http://localhost:8090/health +# → {"status": "ok", "categories": ["math", 
"science", "technology", "history", "general"]} +``` + +## Configuration + +**Router config (`config-mcp-classifier-example.yaml`):** + +```yaml +classifier: + category_model: + model_id: "" # Empty = use MCP + + mcp_category_model: + enabled: true + transport_type: "http" + url: "http://localhost:8090/mcp" + # tool_name: optional - auto-discovers classification tool if not specified + threshold: 0.6 + timeout_seconds: 30 + +categories: [] # Loaded dynamically from MCP +default_model: openai/gpt-oss-20b +``` + +**Tool Auto-Discovery:** +The router automatically discovers classification tools from the MCP server by: + +1. Listing available tools on connection +2. Looking for common names: `classify_text`, `classify`, `categorize`, `categorize_text` +3. Pattern matching for tools containing "classif" in name/description +4. Optionally specify `tool_name` to use a specific tool + +## Protocol API + +This server implements the MCP classification protocol defined in: + +``` +github.com/vllm-project/semantic-router/src/semantic-router/pkg/connectivity/mcp/api +``` + +**Required Tools:** + +1. **`list_categories`** - Returns `ListCategoriesResponse`: + + ```json + {"categories": ["math", "science", "technology", ...]} + ``` + +2. **`classify_text`** - Returns `ClassifyResponse`: + + ```json + { + "class": 1, + "confidence": 0.85, + "model": "openai/gpt-oss-20b", + "use_reasoning": true + } + ``` + +See the `api` package for full type definitions and documentation. 
+ +## How It Works + +**Intelligent Routing Rules:** + +- Long query (>20 words) + complex words (`why`, `how`, `explain`) → `use_reasoning: true` +- Math + short query → `use_reasoning: false` +- High confidence (>0.9) → `use_reasoning: false` +- Low confidence (<0.6) → `use_reasoning: true` +- Default → `use_reasoning: true` + +## Customization + +Edit `CATEGORIES` to add categories: + +```python +CATEGORIES = { + "your_category": { + "patterns": [r"\b(keyword1|keyword2)\b"], + "description": "Your description" + } +} +``` + +Edit `decide_routing()` for custom routing logic: + +```python +def decide_routing(text, category, confidence): + if category == "math": + return "deepseek/deepseek-math", False + return "openai/gpt-oss-20b", True +``` + +## License + +MIT diff --git a/examples/mcp-classifier-server/requirements.txt b/examples/mcp-classifier-server/requirements.txt new file mode 100644 index 00000000..ad2fa6d1 --- /dev/null +++ b/examples/mcp-classifier-server/requirements.txt @@ -0,0 +1,2 @@ +mcp>=1.0.0 +aiohttp>=3.9.0 diff --git a/examples/mcp-classifier-server/server.py b/examples/mcp-classifier-server/server.py new file mode 100755 index 00000000..b4bbc4fa --- /dev/null +++ b/examples/mcp-classifier-server/server.py @@ -0,0 +1,539 @@ +#!/usr/bin/env python3 +""" +Regex-Based MCP Classification Server with Intelligent Routing + +This is an example MCP server that demonstrates: +1. Text classification using regex patterns +2. Dynamic category discovery via list_categories +3. 
Intelligent routing decisions (model selection and reasoning control) + +The server implements two MCP tools: +- 'list_categories': Returns available categories for dynamic loading +- 'classify_text': Classifies text and returns routing recommendations + +Protocol: +- list_categories returns: {"categories": ["math", "science", "technology", ...]} +- classify_text returns: { + "class": 0, + "confidence": 0.85, + "model": "openai/gpt-oss-20b", + "use_reasoning": true, + "probabilities": [...] # optional + } + +Usage: + # Stdio mode (for testing with MCP clients) + python server.py + + # HTTP mode (for semantic router) + python server.py --http --port 8090 +""" + +import argparse +import json +import logging +import math +import re +from typing import Any + +from mcp.server import Server +from mcp.server.stdio import stdio_server +from mcp.types import TextContent, Tool + +# Configure logging +logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" +) +logger = logging.getLogger(__name__) + +# Define classification categories and their regex patterns +CATEGORIES = { + "math": { + "patterns": [ + r"\b(calculate|compute|solve|equation|formula|algebra|geometry|trigonometry)\b", + r"\b(integral|derivative|differential|matrix|vector|probability)\b", + r"\b(\d+\s*[+\-*/]\s*\d+|x\s*[=<>]\s*\d+)\b", + r"\b(sin|cos|tan|log|sqrt|sum|average|mean)\b", + ], + "description": "Mathematical and computational queries", + }, + "science": { + "patterns": [ + r"\b(physics|chemistry|biology|astronomy|geology|ecology)\b", + r"\b(atom|molecule|cell|DNA|RNA|evolution|gravity|energy)\b", + r"\b(experiment|hypothesis|theory|scientific|research)\b", + r"\b(planet|star|galaxy|universe|ecosystem|organism)\b", + ], + "description": "Scientific concepts and queries", + }, + "technology": { + "patterns": [ + r"\b(computer|software|hardware|programming|code|algorithm)\b", + r"\b(internet|network|database|server|cloud|API)\b", + r"\b(machine 
learning|AI|artificial intelligence|neural network)\b", + r"\b(python|java|javascript|C\+\+|golang|rust)\b", + ], + "description": "Technology and computing topics", + }, + "history": { + "patterns": [ + r"\b(ancient|medieval|renaissance|revolution|war|empire)\b", + r"\b(century|era|period|historical|historian|archaeology)\b", + r"\b(civilization|dynasty|monarchy|republic|democracy)\b", + r"\b(BCE|CE|AD|BC|\d{4})\b.*\b(year|century|ago)\b", + ], + "description": "Historical events and topics", + }, + "general": { + "patterns": [r".*"], # Catch-all pattern + "description": "General questions and topics", + }, +} + +# Build index from category names to indices +CATEGORY_NAMES = list(CATEGORIES.keys()) +CATEGORY_TO_INDEX = {name: idx for idx, name in enumerate(CATEGORY_NAMES)} +NUM_CATEGORIES = len(CATEGORY_NAMES) + + +def calculate_confidence(text: str, category_name: str) -> float: + """ + Calculate confidence score for a category based on pattern matches. + + Args: + text: Input text to classify + category_name: Category name to check + + Returns: + Confidence score between 0 and 1 + """ + patterns = CATEGORIES[category_name]["patterns"] + matches = 0 + total_patterns = len(patterns) + + text_lower = text.lower() + + for pattern in patterns: + if re.search(pattern, text_lower, re.IGNORECASE): + matches += 1 + + # Calculate base confidence from match ratio + if total_patterns == 0: + return 0.0 + + match_ratio = matches / total_patterns + + # Boost confidence for multiple matches + confidence = min(0.5 + (match_ratio * 0.5), 1.0) + + return confidence + + +def calculate_entropy(probabilities: list[float]) -> float: + """ + Calculate Shannon entropy of the probability distribution. 
+ + Args: + probabilities: List of probability values + + Returns: + Entropy value + """ + entropy = 0.0 + for p in probabilities: + if p > 0: + entropy -= p * math.log2(p) + return entropy + + +def decide_routing( + text: str, category_name: str, confidence: float +) -> tuple[str, bool]: + """ + Decide which model to use and whether to enable reasoning based on query analysis. + + This is a simple example that demonstrates intelligent routing. In production, + this could use ML models, query complexity analysis, etc. + + Args: + text: Input text being classified + category_name: Predicted category + confidence: Classification confidence + + Returns: + Tuple of (model_name, use_reasoning) + """ + # Analyze query complexity + text_lower = text.lower() + word_count = len(text.split()) + + # Check for complexity indicators + complex_words = [ + "why", + "how", + "explain", + "analyze", + "compare", + "evaluate", + "describe", + ] + has_complex_words = any(word in text_lower for word in complex_words) + + # Simple routing logic (you can make this much more sophisticated) + + # Long queries with complex words → use reasoning + if word_count > 20 and has_complex_words: + return "openai/gpt-oss-20b", True + + # Math category with simple queries → no reasoning needed + if category_name == "math" and word_count < 15: + return "openai/gpt-oss-20b", False + + # High confidence → can use simpler model + if confidence > 0.9: + return "openai/gpt-oss-20b", False + + # Low confidence → use reasoning to be safe + if confidence < 0.6: + return "openai/gpt-oss-20b", True + + # Default: use reasoning for better quality + return "openai/gpt-oss-20b", True + + +def classify_text(text: str, with_probabilities: bool = False) -> dict[str, Any]: + """ + Classify text using regex patterns and return routing recommendations. 
+ + Args: + text: Input text to classify + with_probabilities: Whether to return full probability distribution + + Returns: + Dictionary with classification results including model and use_reasoning + """ + logger.info( + f"Classifying text: '{text[:100]}...' (with_probabilities={with_probabilities})" + ) + + # Calculate confidence for each category + confidences = {} + for category_name in CATEGORY_NAMES: + confidences[category_name] = calculate_confidence(text, category_name) + + # Find the category with highest confidence + best_category = max(confidences.items(), key=lambda x: x[1]) + best_category_name = best_category[0] + best_confidence = best_category[1] + + # Get the class index + class_index = CATEGORY_TO_INDEX[best_category_name] + + # Decide routing (model and reasoning) + model, use_reasoning = decide_routing(text, best_category_name, best_confidence) + + result = { + "class": class_index, + "confidence": best_confidence, + "model": model, + "use_reasoning": use_reasoning, + } + + if with_probabilities: + # Normalize confidences to probabilities + total = sum(confidences.values()) + if total > 0: + probabilities = [confidences[cat] / total for cat in CATEGORY_NAMES] + else: + # Uniform distribution if no matches + probabilities = [1.0 / NUM_CATEGORIES] * NUM_CATEGORIES + + result["probabilities"] = probabilities + + # Calculate and add entropy + entropy_value = calculate_entropy(probabilities) + result["entropy"] = entropy_value + + logger.info( + f"Classification result: class={class_index} ({best_category_name}), " + f"confidence={best_confidence:.3f}, entropy={entropy_value:.3f}, " + f"model={model}, use_reasoning={use_reasoning}" + ) + else: + logger.info( + f"Classification result: class={class_index} ({best_category_name}), " + f"confidence={best_confidence:.3f}, model={model}, use_reasoning={use_reasoning}" + ) + + return result + + +# Initialize MCP server +app = Server("regex-classifier") + + +@app.list_tools() +async def list_tools() -> 
list[Tool]: + """List available tools.""" + return [ + Tool( + name="classify_text", + description=( + "Classify text into categories and provide intelligent routing recommendations. " + f"Categories: {', '.join(CATEGORY_NAMES)}. " + "Returns: class index, confidence, recommended model, and reasoning flag. " + "Optionally returns full probability distribution for entropy analysis." + ), + inputSchema={ + "type": "object", + "properties": { + "text": {"type": "string", "description": "The text to classify"}, + "with_probabilities": { + "type": "boolean", + "description": "Whether to return full probability distribution for entropy analysis", + "default": False, + }, + }, + "required": ["text"], + }, + ), + Tool( + name="list_categories", + description=( + "List all available classification categories. " + "Returns a simple array of category names that the router will use for dynamic category loading." + ), + inputSchema={"type": "object", "properties": {}}, + ), + ] + + +@app.call_tool() +async def call_tool(name: str, arguments: Any) -> list[TextContent]: + """Handle tool calls.""" + if name == "classify_text": + text = arguments.get("text", "") + with_probabilities = arguments.get("with_probabilities", False) + + if not text: + return [ + TextContent(type="text", text=json.dumps({"error": "No text provided"})) + ] + + try: + result = classify_text(text, with_probabilities) + return [TextContent(type="text", text=json.dumps(result))] + except Exception as e: + logger.error(f"Error classifying text: {e}", exc_info=True) + return [TextContent(type="text", text=json.dumps({"error": str(e)}))] + + elif name == "list_categories": + # Return simple list of category names as expected by semantic router + categories_response = {"categories": CATEGORY_NAMES} + logger.info(f"Returning {len(CATEGORY_NAMES)} categories: {CATEGORY_NAMES}") + return [TextContent(type="text", text=json.dumps(categories_response))] + + else: + return [ + TextContent( + type="text", 
text=json.dumps({"error": f"Unknown tool: {name}"}) + ) + ] + + +async def main_stdio(): + """Run the MCP server in stdio mode.""" + logger.info("Starting Regex-Based MCP Classification Server (stdio mode)") + logger.info(f"Available categories: {', '.join(CATEGORY_NAMES)}") + + async with stdio_server() as (read_stream, write_stream): + await app.run(read_stream, write_stream, app.create_initialization_options()) + + +async def main_http(port: int = 8080): + """Run the MCP server in HTTP mode.""" + try: + from aiohttp import web + except ImportError: + logger.error( + "aiohttp is required for HTTP mode. Install it with: pip install aiohttp" + ) + return + + logger.info(f"Starting Regex-Based MCP Classification Server (HTTP mode)") + logger.info(f"Available categories: {', '.join(CATEGORY_NAMES)}") + logger.info(f"Listening on http://0.0.0.0:{port}/mcp") + + async def handle_mcp_request(request): + """Handle MCP requests over HTTP.""" + try: + # Parse JSON-RPC request or extract method from URL path + data = await request.json() + + # Support both styles: + # 1. JSON-RPC style: {"method": "tools/list", "params": {...}} + # 2. 
REST style: POST to /mcp/tools/list with direct params + method = data.get("method", "") + + # If no method in JSON, extract from URL path + if not method: + path = request.path + if path.startswith("/mcp/"): + method = path[5:] # Remove '/mcp/' prefix + elif path == "/mcp": + method = "" + + params = data.get("params", data if not data.get("method") else {}) + request_id = data.get("id", 1) + + logger.debug( + f"Received MCP request: method={method}, path={request.path}, id={request_id}" + ) + + # Handle initialize + if method == "initialize": + init_result = { + "protocolVersion": "2024-11-05", + "capabilities": { + "tools": {}, + }, + "serverInfo": {"name": "regex-classifier", "version": "1.0.0"}, + } + + # For REST-style endpoints, return direct result + # For JSON-RPC style, return full JSON-RPC response + if request.path.startswith("/mcp/") and request.path != "/mcp": + # REST style - return direct result + return web.json_response(init_result) + else: + # JSON-RPC style + result = {"jsonrpc": "2.0", "id": request_id, "result": init_result} + return web.json_response(result) + + # Handle tools/list + elif method == "tools/list": + tools_list = await list_tools() + tools_data = [ + { + "name": tool.name, + "description": tool.description, + "inputSchema": tool.inputSchema, + } + for tool in tools_list + ] + + # For REST-style endpoints (path-based), return direct result + # For JSON-RPC style (method in body), return full JSON-RPC response + if request.path.startswith("/mcp/") and request.path != "/mcp": + # REST style - return direct result + return web.json_response({"tools": tools_data}) + else: + # JSON-RPC style + result = { + "jsonrpc": "2.0", + "id": request_id, + "result": {"tools": tools_data}, + } + return web.json_response(result) + + # Handle tools/call + elif method == "tools/call": + tool_name = params.get("name", "") + arguments = params.get("arguments", {}) + + tool_result = await call_tool(tool_name, arguments) + + # Convert TextContent to 
dict + content = [{"type": tc.type, "text": tc.text} for tc in tool_result] + + result_data = {"content": content, "isError": False} + + # For REST-style endpoints, return direct result + # For JSON-RPC style, return full JSON-RPC response + if request.path.startswith("/mcp/") and request.path != "/mcp": + # REST style - return direct result + return web.json_response(result_data) + else: + # JSON-RPC style + result = {"jsonrpc": "2.0", "id": request_id, "result": result_data} + return web.json_response(result) + + # Handle ping + elif method == "ping": + result = {"jsonrpc": "2.0", "id": request_id, "result": {}} + return web.json_response(result) + + else: + error = { + "jsonrpc": "2.0", + "id": request_id, + "error": {"code": -32601, "message": f"Method not found: {method}"}, + } + return web.json_response(error, status=404) + + except Exception as e: + logger.error(f"Error handling request: {e}", exc_info=True) + error = { + "jsonrpc": "2.0", + "id": request.get("id") if isinstance(request, dict) else None, + "error": {"code": -32603, "message": f"Internal error: {str(e)}"}, + } + return web.json_response(error, status=500) + + async def health_check(request): + """Health check endpoint.""" + return web.json_response({"status": "ok", "categories": CATEGORY_NAMES}) + + # Create web application + http_app = web.Application() + + # Main JSON-RPC endpoint (single endpoint style) + http_app.router.add_post("/mcp", handle_mcp_request) + + # REST-style endpoints (for Go client compatibility) + http_app.router.add_post("/mcp/initialize", handle_mcp_request) + http_app.router.add_post("/mcp/tools/list", handle_mcp_request) + http_app.router.add_post("/mcp/tools/call", handle_mcp_request) + http_app.router.add_post("/mcp/resources/list", handle_mcp_request) + http_app.router.add_post("/mcp/resources/read", handle_mcp_request) + http_app.router.add_post("/mcp/prompts/list", handle_mcp_request) + http_app.router.add_post("/mcp/prompts/get", handle_mcp_request) + 
http_app.router.add_post("/mcp/ping", handle_mcp_request) + + # Health check + http_app.router.add_get("/health", health_check) + + # Run the server + runner = web.AppRunner(http_app) + await runner.setup() + site = web.TCPSite(runner, "0.0.0.0", port) + await site.start() + + logger.info(f"Server is ready at http://0.0.0.0:{port}/mcp") + logger.info(f"Health check available at http://0.0.0.0:{port}/health") + + # Keep the server running + try: + while True: + await asyncio.sleep(3600) + except KeyboardInterrupt: + logger.info("Shutting down server...") + finally: + await runner.cleanup() + + +if __name__ == "__main__": + import asyncio + + # Parse command line arguments + parser = argparse.ArgumentParser(description="MCP Classification Server") + parser.add_argument( + "--http", action="store_true", help="Run in HTTP mode instead of stdio" + ) + parser.add_argument("--port", type=int, default=8090, help="HTTP port to listen on") + args = parser.parse_args() + + if args.http: + asyncio.run(main_http(args.port)) + else: + asyncio.run(main_stdio()) diff --git a/src/semantic-router/go.mod b/src/semantic-router/go.mod index a12034d2..f96c8109 100644 --- a/src/semantic-router/go.mod +++ b/src/semantic-router/go.mod @@ -14,6 +14,7 @@ replace ( require ( github.com/envoyproxy/go-control-plane/envoy v1.32.4 github.com/fsnotify/fsnotify v1.7.0 + github.com/mark3labs/mcp-go v0.42.0-beta.1 github.com/milvus-io/milvus-sdk-go/v2 v2.4.2 github.com/onsi/ginkgo/v2 v2.23.4 github.com/onsi/gomega v1.38.0 @@ -34,7 +35,9 @@ require ( ) require ( + github.com/bahlo/generic-list-go v0.2.0 // indirect github.com/beorn7/perks v1.0.1 // indirect + github.com/buger/jsonparser v1.1.1 // indirect github.com/cenkalti/backoff/v5 v5.0.3 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/cncf/xds/go v0.0.0-20250501225837-2ac532fd4443 // indirect @@ -56,10 +59,12 @@ require ( github.com/google/uuid v1.6.0 // indirect github.com/grpc-ecosystem/go-grpc-middleware v1.3.0 // 
indirect github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.2 // indirect + github.com/invopop/jsonschema v0.13.0 // indirect github.com/json-iterator/go v1.1.12 // indirect github.com/kr/pretty v0.3.1 // indirect github.com/kr/text v0.2.0 // indirect github.com/kylelemons/godebug v1.1.0 // indirect + github.com/mailru/easyjson v0.7.7 // indirect github.com/milvus-io/milvus-proto/go-api/v2 v2.4.10-0.20240819025435-512e3b98866a // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect @@ -70,11 +75,14 @@ require ( github.com/prometheus/common v0.65.0 // indirect github.com/prometheus/procfs v0.16.1 // indirect github.com/rogpeppe/go-internal v1.13.1 // indirect + github.com/spf13/cast v1.7.1 // indirect github.com/tidwall/gjson v1.14.4 // indirect github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.1 // indirect github.com/tidwall/sjson v1.2.5 // indirect + github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect github.com/x448/float16 v0.8.4 // indirect + github.com/yosida95/uritemplate/v3 v3.0.2 // indirect go.opentelemetry.io/auto/sdk v1.1.0 // indirect go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.38.0 // indirect go.opentelemetry.io/otel/metric v1.38.0 // indirect diff --git a/src/semantic-router/go.sum b/src/semantic-router/go.sum index af77d0b8..e5e53e8a 100644 --- a/src/semantic-router/go.sum +++ b/src/semantic-router/go.sum @@ -8,8 +8,12 @@ github.com/Shopify/goreferrer v0.0.0-20181106222321-ec9c9a553398/go.mod h1:a1uqR github.com/ajg/form v1.5.1/go.mod h1:uL1WgH+h2mgNtvBq0339dVnzXdBETtL2LeUXaIv25UY= github.com/armon/consul-api v0.0.0-20180202201655-eb2c6b5be1b6/go.mod h1:grANhF5doyWs3UAsr3K4I6qtAmlQcZDesFNEHPZAzj8= github.com/aymerick/raymond v2.0.3-0.20180322193309-b565731e1464+incompatible/go.mod h1:osfaiScAUVup+UC9Nfq76eWqDhXlp+4UYaA8uhTBO6g= +github.com/bahlo/generic-list-go v0.2.0 h1:5sz/EEAK+ls5wF+NeqDpk5+iNdMDXrh3z3nPnH1Wvgk= 
+github.com/bahlo/generic-list-go v0.2.0/go.mod h1:2KvAjgMlE5NNynlg/5iLrrCCZ2+5xWbdbCW3pNTGyYg= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= +github.com/buger/jsonparser v1.1.1 h1:2PnMjfWD7wBILjqQbt530v576A/cAbQvEW9gGIpYMUs= +github.com/buger/jsonparser v1.1.1/go.mod h1:6RYKKt7H4d4+iWqouImQ9R2FZql3VbhNgx27UK13J/0= github.com/cenkalti/backoff/v5 v5.0.3 h1:ZN+IMa753KfX5hd8vVaMixjnqRZ3y8CuJKRKj1xcsSM= github.com/cenkalti/backoff/v5 v5.0.3/go.mod h1:rkhZdG3JZukswDf7f0cwqPNk4K0sa+F97BxZthm/crw= github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= @@ -53,6 +57,8 @@ github.com/envoyproxy/protoc-gen-validate v1.2.1/go.mod h1:d/C80l/jxXLdfEIhX1W2T github.com/etcd-io/bbolt v1.3.3/go.mod h1:ZF2nL25h33cCyBtcyWeZ2/I3HQOfTP+0PIEvHjkjCrw= github.com/fasthttp-contrib/websocket v0.0.0-20160511215533-1f3b11f56072/go.mod h1:duJ4Jxv5lDcvg4QuQr0oowTf7dz4/CR8NtyCooz9HL8= github.com/fatih/structs v1.1.0/go.mod h1:9NiDSp5zOcgEDl+j00MP/WkGVPOlPRLejGD8Ga6PJ7M= +github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= +github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA= github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM= @@ -137,11 +143,14 @@ github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpO github.com/hydrogen18/memlistener v0.0.0-20200120041712-dcc25e7acd91/go.mod h1:qEIFzExnS6016fRpRfxrExeVn2gbClQA99gQhnIcdhE= github.com/imkira/go-interpol v1.1.0/go.mod h1:z0h2/2T3XF8kyEPpRgJ3kmNv+C43p+I/CoI+jC3w2iA= github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8= 
+github.com/invopop/jsonschema v0.13.0 h1:KvpoAJWEjR3uD9Kbm2HWJmqsEaHt8lBUpd0qHcIi21E= +github.com/invopop/jsonschema v0.13.0/go.mod h1:ffZ5Km5SWWRAIN6wbDXItl95euhFz2uON45H2qjYt+0= github.com/iris-contrib/blackfriday v2.0.0+incompatible/go.mod h1:UzZ2bDEoaSGPbkg6SAB4att1aAwTmVIx/5gCVqeyUdI= github.com/iris-contrib/go.uuid v2.0.0+incompatible/go.mod h1:iz2lgM/1UnEf1kP0L/+fafWORmlnuysV2EMP8MW+qe0= github.com/iris-contrib/jade v1.1.3/go.mod h1:H/geBymxJhShH5kecoiOCSssPX7QWYH7UaeZTSWddIk= github.com/iris-contrib/pongo2 v0.0.1/go.mod h1:Ssh+00+3GAZqSQb30AvBRNxBx7rf0GqwkjqxNd0u65g= github.com/iris-contrib/schema v0.0.1/go.mod h1:urYA3uvUNG1TIIjOSCzHr9/LmbQo8LrOcOqfqxa4hXw= +github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU= github.com/json-iterator/go v1.1.9/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= @@ -175,6 +184,10 @@ github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+ github.com/labstack/echo/v4 v4.5.0/go.mod h1:czIriw4a0C1dFun+ObrXp7ok03xON0N1awStJ6ArI7Y= github.com/labstack/gommon v0.3.0/go.mod h1:MULnywXg0yavhxWKc+lOruYdAhDwPK9wf0OL7NoOu+k= github.com/magiconair/properties v1.8.0/go.mod h1:PppfXfuXeibc/6YijjN8zIbojt8czPbwD3XqdrwzmxQ= +github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0= +github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc= +github.com/mark3labs/mcp-go v0.42.0-beta.1 h1:jXCUOg7vHwSuknzy4hPvOXASnzmLluM3AMx1rPh/OYM= +github.com/mark3labs/mcp-go v0.42.0-beta.1/go.mod h1:T7tUa2jO6MavG+3P25Oy/jR7iCeJPHImCZHRymCn39g= github.com/mattn/go-colorable v0.1.2/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE= github.com/mattn/go-colorable v0.1.8/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc= github.com/mattn/go-colorable v0.1.11/go.mod 
h1:u5H1YNBxpqRaxsYJYSkiCWKzEfiAb1Gb520KVy5xxl4= @@ -254,6 +267,8 @@ github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d/go.mod h1 github.com/smartystreets/goconvey v1.6.4/go.mod h1:syvi0/a8iFYH4r/RixwvyeAJjdLS9QV7WQ/tjFTllLA= github.com/spf13/afero v1.1.2/go.mod h1:j4pytiNVoe2o6bmDsKpLACNPDBIoEAkihy7loJ1B0CQ= github.com/spf13/cast v1.3.0/go.mod h1:Qx5cxh0v+4UWYiBimWS+eyWzqEqokIECu5etghLkUJE= +github.com/spf13/cast v1.7.1 h1:cuNEagBQEHWN1FnbGEjCXL2szYEXqfJPbP2HNUaca9Y= +github.com/spf13/cast v1.7.1/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo= github.com/spf13/cobra v0.0.5/go.mod h1:3K3wKZymM7VvHMDS9+Akkh4K60UwM26emMESw8tLCHU= github.com/spf13/jwalterweatherman v1.0.0/go.mod h1:cQK4TGJAtQXfYWX+Ddv3mKDzgVb68N+wFjFa4jdeBTo= github.com/spf13/pflag v1.0.3/go.mod h1:DYY7MBk1bdzusC3SYhjObp+wFpr4gzcvqqNjLnInEg4= @@ -291,6 +306,8 @@ github.com/valyala/fasthttp v1.6.0/go.mod h1:FstJa9V+Pj9vQ7OJie2qMHdwemEDaDiSdBn github.com/valyala/fasttemplate v1.0.1/go.mod h1:UQGH1tvbgY+Nz5t2n7tXsz52dQxojPUpymEIMZ47gx8= github.com/valyala/fasttemplate v1.2.1/go.mod h1:KHLXt3tVN2HBp8eijSv/kGJopbvo7S+qRAEEKiv+SiQ= github.com/valyala/tcplisten v0.0.0-20161114210144-ceec8f93295a/go.mod h1:v3UYOV9WzVtRmSR+PDvWpU/qWl4Wa5LApYYX4ZtKbio= +github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/fJgbpc= +github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw= github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM= github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg= github.com/xeipuuv/gojsonpointer v0.0.0-20180127040702-4e3ac2762d5f/go.mod h1:N2zxlSyiKSe5eX1tZViRH5QA0qijqEDrYZiPEAiq3wU= @@ -298,6 +315,8 @@ github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415/go.mod h1: github.com/xeipuuv/gojsonschema v1.2.0/go.mod h1:anYRn/JVcOK2ZgGU+IjEV4nwlhoK5sQluxsYJ78Id3Y= github.com/xordataexchange/crypt v0.0.3-0.20170626215501-b2862e3d0a77/go.mod 
h1:aYKd//L2LvnjZzWKhF00oedf4jCCReLcmhLdhm1A27Q= github.com/yalp/jsonpath v0.0.0-20180802001716-5cc68e5049a0/go.mod h1:/LWChgwKmvncFJFHJ7Gvn9wZArjbV5/FppcK2fKk/tI= +github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= +github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= github.com/yudai/gojsondiff v1.0.0/go.mod h1:AY32+k2cwILAkW1fbgxQ5mUmMiZFgLIV+FBNExI05xg= github.com/yudai/golcs v0.0.0-20170316035057-ecda9a501e82/go.mod h1:lgjkn3NuSvDfVJdfcVVdX+jpBxNmX4rDAzaS45IcYoM= github.com/yudai/pp v2.0.1+incompatible/go.mod h1:PuxR/8QJ7cyCkFp/aUDS+JY727OFEZkTdatxwunjIkc= diff --git a/src/semantic-router/pkg/config/config.go b/src/semantic-router/pkg/config/config.go index 6720c6a0..ad2e9220 100644 --- a/src/semantic-router/pkg/config/config.go +++ b/src/semantic-router/pkg/config/config.go @@ -28,6 +28,17 @@ type RouterConfig struct { UseModernBERT bool `yaml:"use_modernbert"` CategoryMappingPath string `yaml:"category_mapping_path"` } `yaml:"category_model"` + MCPCategoryModel struct { + Enabled bool `yaml:"enabled"` + TransportType string `yaml:"transport_type"` + Command string `yaml:"command,omitempty"` + Args []string `yaml:"args,omitempty"` + Env map[string]string `yaml:"env,omitempty"` + URL string `yaml:"url,omitempty"` + ToolName string `yaml:"tool_name,omitempty"` // Optional: will auto-discover if not specified + Threshold float32 `yaml:"threshold"` + TimeoutSeconds int `yaml:"timeout_seconds,omitempty"` + } `yaml:"mcp_category_model,omitempty"` PIIModel struct { ModelID string `yaml:"model_id"` Threshold float32 `yaml:"threshold"` @@ -574,6 +585,11 @@ func (c *RouterConfig) IsCategoryClassifierEnabled() bool { return c.Classifier.CategoryModel.ModelID != "" && c.Classifier.CategoryModel.CategoryMappingPath != "" } +// IsMCPCategoryClassifierEnabled checks if MCP-based category classification is enabled +func (c *RouterConfig) IsMCPCategoryClassifierEnabled() bool { + return 
c.Classifier.MCPCategoryModel.Enabled && c.Classifier.MCPCategoryModel.ToolName != "" +} + // GetPromptGuardConfig returns the prompt guard configuration func (c *RouterConfig) GetPromptGuardConfig() PromptGuardConfig { return c.PromptGuard diff --git a/src/semantic-router/pkg/config/validation_test.go b/src/semantic-router/pkg/config/validation_test.go index 8fa758e3..1189c054 100644 --- a/src/semantic-router/pkg/config/validation_test.go +++ b/src/semantic-router/pkg/config/validation_test.go @@ -273,3 +273,240 @@ var _ = Describe("IP Address Validation", func() { }) }) }) + +var _ = Describe("MCP Configuration Validation", func() { + Describe("IsMCPCategoryClassifierEnabled", func() { + var cfg *RouterConfig + + BeforeEach(func() { + cfg = &RouterConfig{} + }) + + Context("when MCP is fully configured", func() { + It("should return true", func() { + cfg.Classifier.MCPCategoryModel.Enabled = true + cfg.Classifier.MCPCategoryModel.ToolName = "classify_text" + + Expect(cfg.IsMCPCategoryClassifierEnabled()).To(BeTrue()) + }) + }) + + Context("when MCP is not enabled", func() { + It("should return false", func() { + cfg.Classifier.MCPCategoryModel.Enabled = false + cfg.Classifier.MCPCategoryModel.ToolName = "classify_text" + + Expect(cfg.IsMCPCategoryClassifierEnabled()).To(BeFalse()) + }) + }) + + Context("when MCP tool name is empty", func() { + It("should return false", func() { + cfg.Classifier.MCPCategoryModel.Enabled = true + cfg.Classifier.MCPCategoryModel.ToolName = "" + + Expect(cfg.IsMCPCategoryClassifierEnabled()).To(BeFalse()) + }) + }) + + Context("when both enabled and tool name are missing", func() { + It("should return false", func() { + cfg.Classifier.MCPCategoryModel.Enabled = false + cfg.Classifier.MCPCategoryModel.ToolName = "" + + Expect(cfg.IsMCPCategoryClassifierEnabled()).To(BeFalse()) + }) + }) + }) + + Describe("MCP Configuration Structure", func() { + var cfg *RouterConfig + + BeforeEach(func() { + cfg = &RouterConfig{} + }) + + 
Context("when configuring stdio transport", func() { + It("should accept valid stdio configuration", func() { + cfg.Classifier.MCPCategoryModel.Enabled = true + cfg.Classifier.MCPCategoryModel.TransportType = "stdio" + cfg.Classifier.MCPCategoryModel.Command = "python" + cfg.Classifier.MCPCategoryModel.Args = []string{"server.py"} + cfg.Classifier.MCPCategoryModel.ToolName = "classify_text" + cfg.Classifier.MCPCategoryModel.Threshold = 0.5 + cfg.Classifier.MCPCategoryModel.TimeoutSeconds = 30 + + Expect(cfg.Classifier.MCPCategoryModel.Enabled).To(BeTrue()) + Expect(cfg.Classifier.MCPCategoryModel.TransportType).To(Equal("stdio")) + Expect(cfg.Classifier.MCPCategoryModel.Command).To(Equal("python")) + Expect(cfg.Classifier.MCPCategoryModel.Args).To(HaveLen(1)) + Expect(cfg.Classifier.MCPCategoryModel.ToolName).To(Equal("classify_text")) + Expect(cfg.Classifier.MCPCategoryModel.Threshold).To(BeNumerically("==", 0.5)) + Expect(cfg.Classifier.MCPCategoryModel.TimeoutSeconds).To(Equal(30)) + }) + + It("should accept environment variables", func() { + cfg.Classifier.MCPCategoryModel.Env = map[string]string{ + "PYTHONPATH": "/app/lib", + "LOG_LEVEL": "debug", + } + + Expect(cfg.Classifier.MCPCategoryModel.Env).To(HaveLen(2)) + Expect(cfg.Classifier.MCPCategoryModel.Env["PYTHONPATH"]).To(Equal("/app/lib")) + Expect(cfg.Classifier.MCPCategoryModel.Env["LOG_LEVEL"]).To(Equal("debug")) + }) + }) + + Context("when configuring HTTP transport", func() { + It("should accept valid HTTP configuration", func() { + cfg.Classifier.MCPCategoryModel.Enabled = true + cfg.Classifier.MCPCategoryModel.TransportType = "http" + cfg.Classifier.MCPCategoryModel.URL = "http://localhost:8080/mcp" + cfg.Classifier.MCPCategoryModel.ToolName = "classify_text" + + Expect(cfg.Classifier.MCPCategoryModel.TransportType).To(Equal("http")) + Expect(cfg.Classifier.MCPCategoryModel.URL).To(Equal("http://localhost:8080/mcp")) + }) + }) + + Context("when threshold is not set", func() { + It("should default to 
zero", func() { + cfg.Classifier.MCPCategoryModel.Enabled = true + cfg.Classifier.MCPCategoryModel.ToolName = "classify_text" + + Expect(cfg.Classifier.MCPCategoryModel.Threshold).To(BeNumerically("==", 0.0)) + }) + }) + + Context("when configuring custom threshold", func() { + It("should accept threshold values between 0 and 1", func() { + testCases := []float32{0.0, 0.3, 0.5, 0.7, 0.9, 1.0} + + for _, threshold := range testCases { + cfg.Classifier.MCPCategoryModel.Threshold = threshold + Expect(cfg.Classifier.MCPCategoryModel.Threshold).To(BeNumerically("==", threshold)) + } + }) + }) + + Context("when timeout is not set", func() { + It("should default to zero", func() { + cfg.Classifier.MCPCategoryModel.Enabled = true + cfg.Classifier.MCPCategoryModel.ToolName = "classify_text" + + Expect(cfg.Classifier.MCPCategoryModel.TimeoutSeconds).To(Equal(0)) + }) + }) + }) + + Describe("MCP vs In-tree Classifier Priority", func() { + var cfg *RouterConfig + + BeforeEach(func() { + cfg = &RouterConfig{} + }) + + Context("when both in-tree and MCP are configured", func() { + It("should have both configurations available", func() { + // Configure in-tree classifier + cfg.Classifier.CategoryModel.ModelID = "/path/to/model" + cfg.Classifier.CategoryModel.CategoryMappingPath = "/path/to/mapping.json" + cfg.Classifier.CategoryModel.Threshold = 0.7 + + // Configure MCP classifier + cfg.Classifier.MCPCategoryModel.Enabled = true + cfg.Classifier.MCPCategoryModel.ToolName = "classify_text" + cfg.Classifier.MCPCategoryModel.Threshold = 0.5 + + // Both should be configured + Expect(cfg.Classifier.CategoryModel.ModelID).ToNot(BeEmpty()) + Expect(cfg.Classifier.MCPCategoryModel.Enabled).To(BeTrue()) + }) + }) + + Context("when only in-tree is configured", func() { + It("should not have MCP enabled", func() { + cfg.Classifier.CategoryModel.ModelID = "/path/to/model" + cfg.Classifier.CategoryModel.CategoryMappingPath = "/path/to/mapping.json" + + 
Expect(cfg.Classifier.CategoryModel.ModelID).ToNot(BeEmpty()) + Expect(cfg.IsMCPCategoryClassifierEnabled()).To(BeFalse()) + }) + }) + + Context("when only MCP is configured", func() { + It("should have MCP enabled and no in-tree model", func() { + cfg.Classifier.MCPCategoryModel.Enabled = true + cfg.Classifier.MCPCategoryModel.ToolName = "classify_text" + + Expect(cfg.IsMCPCategoryClassifierEnabled()).To(BeTrue()) + Expect(cfg.Classifier.CategoryModel.ModelID).To(BeEmpty()) + }) + }) + + Context("when neither is configured", func() { + It("should have neither enabled", func() { + Expect(cfg.Classifier.CategoryModel.ModelID).To(BeEmpty()) + Expect(cfg.IsMCPCategoryClassifierEnabled()).To(BeFalse()) + }) + }) + }) + + Describe("MCP Configuration Fields", func() { + var cfg *RouterConfig + + BeforeEach(func() { + cfg = &RouterConfig{} + }) + + It("should support all required fields for stdio transport", func() { + cfg.Classifier.MCPCategoryModel.Enabled = true + cfg.Classifier.MCPCategoryModel.TransportType = "stdio" + cfg.Classifier.MCPCategoryModel.Command = "python3" + cfg.Classifier.MCPCategoryModel.Args = []string{"-m", "server"} + cfg.Classifier.MCPCategoryModel.Env = map[string]string{"DEBUG": "1"} + cfg.Classifier.MCPCategoryModel.ToolName = "classify" + cfg.Classifier.MCPCategoryModel.Threshold = 0.6 + cfg.Classifier.MCPCategoryModel.TimeoutSeconds = 60 + + Expect(cfg.Classifier.MCPCategoryModel.Enabled).To(BeTrue()) + Expect(cfg.Classifier.MCPCategoryModel.TransportType).To(Equal("stdio")) + Expect(cfg.Classifier.MCPCategoryModel.Command).To(Equal("python3")) + Expect(cfg.Classifier.MCPCategoryModel.Args).To(Equal([]string{"-m", "server"})) + Expect(cfg.Classifier.MCPCategoryModel.Env).To(HaveKeyWithValue("DEBUG", "1")) + Expect(cfg.Classifier.MCPCategoryModel.ToolName).To(Equal("classify")) + Expect(cfg.Classifier.MCPCategoryModel.Threshold).To(BeNumerically("~", 0.6, 0.01)) + Expect(cfg.Classifier.MCPCategoryModel.TimeoutSeconds).To(Equal(60)) + }) + + 
It("should support all required fields for HTTP transport", func() { + cfg.Classifier.MCPCategoryModel.Enabled = true + cfg.Classifier.MCPCategoryModel.TransportType = "http" + cfg.Classifier.MCPCategoryModel.URL = "https://mcp-server:443/api" + cfg.Classifier.MCPCategoryModel.ToolName = "classify" + cfg.Classifier.MCPCategoryModel.Threshold = 0.8 + cfg.Classifier.MCPCategoryModel.TimeoutSeconds = 120 + + Expect(cfg.Classifier.MCPCategoryModel.Enabled).To(BeTrue()) + Expect(cfg.Classifier.MCPCategoryModel.TransportType).To(Equal("http")) + Expect(cfg.Classifier.MCPCategoryModel.URL).To(Equal("https://mcp-server:443/api")) + Expect(cfg.Classifier.MCPCategoryModel.ToolName).To(Equal("classify")) + Expect(cfg.Classifier.MCPCategoryModel.Threshold).To(BeNumerically("~", 0.8, 0.01)) + Expect(cfg.Classifier.MCPCategoryModel.TimeoutSeconds).To(Equal(120)) + }) + + It("should allow optional fields to be omitted", func() { + cfg.Classifier.MCPCategoryModel.Enabled = true + cfg.Classifier.MCPCategoryModel.TransportType = "stdio" + cfg.Classifier.MCPCategoryModel.Command = "server" + cfg.Classifier.MCPCategoryModel.ToolName = "classify" + + // Optional fields should have zero values + Expect(cfg.Classifier.MCPCategoryModel.Args).To(BeNil()) + Expect(cfg.Classifier.MCPCategoryModel.Env).To(BeNil()) + Expect(cfg.Classifier.MCPCategoryModel.URL).To(BeEmpty()) + Expect(cfg.Classifier.MCPCategoryModel.Threshold).To(BeNumerically("==", 0.0)) + Expect(cfg.Classifier.MCPCategoryModel.TimeoutSeconds).To(Equal(0)) + }) + }) +}) diff --git a/src/semantic-router/pkg/connectivity/mcp/api/types.go b/src/semantic-router/pkg/connectivity/mcp/api/types.go new file mode 100644 index 00000000..94b468ab --- /dev/null +++ b/src/semantic-router/pkg/connectivity/mcp/api/types.go @@ -0,0 +1,101 @@ +// Package api defines the MCP protocol contract for semantic router classification. 
+// +// This package provides strongly-typed definitions for the JSON messages exchanged between +// the Go client and MCP classification servers. These types ensure consistency and can be +// used by both client and server implementations. +// +// Protocol Version: 1.0 +// +// For Python MCP servers, use these JSON formats when implementing classification tools. +// For other language implementations, these types can serve as reference for the expected format. +package api + +// ClassifyRequest represents the arguments for the classify_text tool call +type ClassifyRequest struct { + Text string `json:"text"` // Text to classify + WithProbabilities bool `json:"with_probabilities,omitempty"` // Request full probability distribution +} + +// ClassifyResponse represents the response from the classify_text MCP tool. +// +// This is the core classification response format. The MCP server returns both classification +// results (class index and confidence) and optional routing information (model and use_reasoning). +// +// Example JSON: +// +// { +// "class": 3, +// "confidence": 0.85, +// "model": "openai/gpt-oss-20b", +// "use_reasoning": true +// } +type ClassifyResponse struct { + // Class is the 0-based index of the predicted category + Class int `json:"class"` + + // Confidence is the prediction confidence, ranging from 0.0 to 1.0 + Confidence float32 `json:"confidence"` + + // Model is the recommended model for routing this request (optional). + // If provided, the router will use this model instead of the default_model. + // Example: "openai/gpt-oss-20b", "anthropic/claude-3-opus" + Model string `json:"model,omitempty"` + + // UseReasoning indicates whether to enable reasoning mode for this request (optional). + // If nil/omitted, the router uses its default reasoning configuration. + // If true, enables reasoning mode. If false, disables reasoning mode. 
+ UseReasoning *bool `json:"use_reasoning,omitempty"` +} + +// ClassifyWithProbabilitiesResponse extends ClassifyResponse with full probability distribution. +// +// This format is used when the client requests detailed classification probabilities for all categories. +// The MCP server should return this format when the classify_text tool is called with +// "with_probabilities": true. +// +// Example JSON: +// +// { +// "class": 3, +// "confidence": 0.85, +// "probabilities": [0.05, 0.03, 0.07, 0.85, ...], +// "model": "openai/gpt-oss-20b", +// "use_reasoning": true +// } +type ClassifyWithProbabilitiesResponse struct { + // Class is the 0-based index of the predicted category + Class int `json:"class"` + + // Confidence is the prediction confidence, ranging from 0.0 to 1.0 + Confidence float32 `json:"confidence"` + + // Probabilities is the full probability distribution across all categories. + // The array length must match the number of categories, and values should sum to ~1.0. + Probabilities []float32 `json:"probabilities"` + + // Model is the recommended model for routing this request (optional) + Model string `json:"model,omitempty"` + + // UseReasoning indicates whether to enable reasoning mode for this request (optional) + UseReasoning *bool `json:"use_reasoning,omitempty"` +} + +// ListCategoriesResponse represents the response from the list_categories MCP tool. +// +// This format is used to retrieve the taxonomy of available categories from the MCP server. +// The client calls this tool during initialization to discover which categories the server supports. +// +// Example JSON: +// +// { +// "categories": ["business", "law", "medical", "technical", "general"] +// } +type ListCategoriesResponse struct { + // Categories is the ordered list of category names. + // The index position in this array corresponds to the "class" index in ClassifyResponse. 
+ // For example, if Categories = ["business", "law", "medical"], then: + // - class 0 = "business" + // - class 1 = "law" + // - class 2 = "medical" + Categories []string `json:"categories"` +} diff --git a/src/semantic-router/pkg/connectivity/mcp/factory.go b/src/semantic-router/pkg/connectivity/mcp/factory.go new file mode 100644 index 00000000..f120a9d4 --- /dev/null +++ b/src/semantic-router/pkg/connectivity/mcp/factory.go @@ -0,0 +1,171 @@ +package mcp + +import ( + "fmt" + "log" +) + +// ClientFactory creates MCP clients based on configuration +type ClientFactory struct{} + +// NewClientFactory creates a new client factory +func NewClientFactory() *ClientFactory { + return &ClientFactory{} +} + +// CreateClient creates an MCP client based on the configuration +func (f *ClientFactory) CreateClient(name string, config ClientConfig) (MCPClient, error) { + transportType := f.determineTransportType(config) + + switch transportType { + case string(TransportStdio): + return NewStdioClient(name, config), nil + case string(TransportStreamableHTTP), "http": + return NewHTTPClient(name, config), nil + default: + return nil, fmt.Errorf("unsupported transport type: %s", transportType) + } +} + +// determineTransportType determines the transport type from configuration +func (f *ClientFactory) determineTransportType(config ClientConfig) string { + if config.TransportType != "" { + return config.TransportType + } + + if config.Command != "" { + return string(TransportStdio) + } + + if config.URL != "" { + return string(TransportStreamableHTTP) + } + + return string(TransportStdio) +} + +// CreateClientWithDefaults creates a client with default configuration +func (f *ClientFactory) CreateClientWithDefaults(name string, transportType TransportType) (MCPClient, error) { + config := ClientConfig{ + TransportType: string(transportType), + Options: ClientOptions{ + LogEnabled: true, + }, + } + + return f.CreateClient(name, config) +} + +// NewClient is a convenience function 
that creates a client using the factory +func NewClient(name string, config ClientConfig) (MCPClient, error) { + factory := NewClientFactory() + return factory.CreateClient(name, config) +} + +// NewClientFromConfig creates a client from configuration with validation +func NewClientFromConfig(name string, config ClientConfig) (MCPClient, error) { + if err := validateConfig(config); err != nil { + return nil, fmt.Errorf("invalid configuration: %w", err) + } + + return NewClient(name, config) +} + +// validateConfig validates the client configuration +func validateConfig(config ClientConfig) error { + transportType := config.TransportType + if transportType == "" { + // Auto-detect transport type + if config.Command == "" && config.URL == "" { + return fmt.Errorf("either command or URL must be specified") + } + return nil + } + + switch TransportType(transportType) { + case TransportStdio: + if config.Command == "" { + return fmt.Errorf("command is required for stdio transport") + } + case TransportStreamableHTTP: + if config.URL == "" { + return fmt.Errorf("URL is required for %s transport", transportType) + } + default: + // Also accept "http" as an alias for "streamable-http" + if transportType == "http" { + if config.URL == "" { + return fmt.Errorf("URL is required for http transport") + } + } else { + return fmt.Errorf("unsupported transport type: %s", transportType) + } + } + + return nil +} + +// CreateMultipleClients creates multiple clients from a map of configurations +func CreateMultipleClients(configs map[string]ClientConfig) (map[string]MCPClient, error) { + factory := NewClientFactory() + clients := make(map[string]MCPClient) + + for name, config := range configs { + client, err := factory.CreateClient(name, config) + if err != nil { + // Close any already created clients on error + for _, c := range clients { + c.Close() + } + return nil, fmt.Errorf("failed to create client '%s': %w", name, err) + } + clients[name] = client + } + + return clients, nil 
+} + +// LoggingClientWrapper wraps any client with enhanced logging +type LoggingClientWrapper struct { + MCPClient + name string +} + +// NewLoggingClientWrapper creates a new logging wrapper +func NewLoggingClientWrapper(client MCPClient, name string) *LoggingClientWrapper { + wrapper := &LoggingClientWrapper{ + MCPClient: client, + name: name, + } + + // Set up enhanced logging + client.SetLogHandler(func(level LoggingLevel, message string) { + log.Printf("[%s] %s: %s", level, name, message) + }) + + return wrapper +} + +// Connect wraps the connection with additional logging +func (w *LoggingClientWrapper) Connect() error { + log.Printf("Connecting client: %s", w.name) + err := w.MCPClient.Connect() + if err != nil { + log.Printf("Failed to connect client %s: %v", w.name, err) + } else { + log.Printf("Successfully connected client: %s", w.name) + } + return err +} + +// Close wraps the close with additional logging +func (w *LoggingClientWrapper) Close() error { + log.Printf("Closing client: %s", w.name) + err := w.MCPClient.Close() + if err != nil { + log.Printf("Error closing client %s: %v", w.name, err) + } else { + log.Printf("Successfully closed client: %s", w.name) + } + return err +} diff --git a/src/semantic-router/pkg/connectivity/mcp/http_client.go b/src/semantic-router/pkg/connectivity/mcp/http_client.go new file mode 100644 index 00000000..432df544 --- /dev/null +++ b/src/semantic-router/pkg/connectivity/mcp/http_client.go @@ -0,0 +1,388 @@ +package mcp + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + + "github.com/mark3labs/mcp-go/mcp" +) + +const ( + // MCPProtocolVersion is the MCP protocol version supported by this implementation. + // This version is defined by the Model Context Protocol specification and must match + // the version expected by MCP servers. The version "2024-11-05" represents the current + // stable MCP specification release. 
This is not derived from mcp-go as the library + // does not export a version constant; instead, it's specified by the protocol itself. + MCPProtocolVersion = "2024-11-05" + // MCPClientVersion is the version of this MCP client implementation + MCPClientVersion = "1.0.0" +) + +// HTTPClient implements MCPClient for streamable HTTP transport +type HTTPClient struct { + *BaseClient + httpClient *http.Client + baseURL string +} + +// NewHTTPClient creates a new HTTP MCP client +func NewHTTPClient(name string, config ClientConfig) *HTTPClient { + baseClient := NewBaseClient(name, config) + return &HTTPClient{ + BaseClient: baseClient, + httpClient: &http.Client{ + Timeout: config.Timeout, + }, + baseURL: config.URL, + } +} + +// Connect establishes connection to the MCP server via HTTP +func (c *HTTPClient) Connect() error { + if c.connected { + return nil + } + + if c.config.URL == "" { + return fmt.Errorf("URL is required for HTTP transport") + } + + c.log(LoggingLevelInfo, fmt.Sprintf("Connecting to HTTP endpoint: %s", c.config.URL)) + + // Parse the URL to validate it + parsedURL, err := url.Parse(c.config.URL) + if err != nil { + return fmt.Errorf("invalid HTTP URL: %w", err) + } + + // Validate that it's an HTTP/HTTPS URL + if parsedURL.Scheme != "http" && parsedURL.Scheme != "https" { + return fmt.Errorf("HTTP URL must use http or https scheme") + } + + c.baseURL = c.config.URL + + // Test connectivity with a simple health check or initialize call + if err := c.testConnection(); err != nil { + return fmt.Errorf("failed to connect to HTTP endpoint: %w", err) + } + + c.connected = true + c.log(LoggingLevelInfo, "Successfully connected to HTTP endpoint") + + // Initialize capabilities + if err := c.initializeCapabilities(); err != nil { + c.log(LoggingLevelWarning, fmt.Sprintf("Failed to initialize capabilities: %v", err)) + } + + return nil +} + +// testConnection tests the HTTP connection +func (c *HTTPClient) testConnection() error { + ctx := 
context.Background() + if c.config.Timeout > 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, c.config.Timeout) + defer cancel() + } + + // Try to initialize the connection + initRequest := mcp.InitializeRequest{} + initRequest.Params.ProtocolVersion = MCPProtocolVersion + initRequest.Params.Capabilities = mcp.ClientCapabilities{ + Roots: &struct { + ListChanged bool `json:"listChanged,omitempty"` + }{}, + } + initRequest.Params.ClientInfo = mcp.Implementation{ + Name: "http-mcp-client", + Version: MCPClientVersion, + } + + _, err := c.sendRequest(ctx, "initialize", initRequest) + return err +} + +// initializeCapabilities initializes capabilities for HTTP client +func (c *HTTPClient) initializeCapabilities() error { + ctx := context.Background() + if c.config.Timeout > 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, c.config.Timeout) + defer cancel() + } + + // Load tools + if err := c.loadTools(ctx); err != nil { + c.log(LoggingLevelWarning, fmt.Sprintf("Failed to load tools: %v", err)) + } + + // Load resources + if err := c.loadResources(ctx); err != nil { + c.log(LoggingLevelWarning, fmt.Sprintf("Failed to load resources: %v", err)) + } + + // Load prompts + if err := c.loadPrompts(ctx); err != nil { + c.log(LoggingLevelWarning, fmt.Sprintf("Failed to load prompts: %v", err)) + } + + return nil +} + +// loadTools loads available tools from the HTTP MCP server +func (c *HTTPClient) loadTools(ctx context.Context) error { + request := mcp.ListToolsRequest{} + + response, err := c.sendRequest(ctx, "tools/list", request) + if err != nil { + return fmt.Errorf("failed to load tools: %w", err) + } + + var toolsResult mcp.ListToolsResult + if err := json.Unmarshal(response, &toolsResult); err != nil { + return fmt.Errorf("failed to parse tools response: %w", err) + } + + // Apply tool filtering + filteredTools := FilterTools(toolsResult.Tools, c.config.Options.ToolFilter) + + c.tools = filteredTools + 
c.log(LoggingLevelInfo, fmt.Sprintf("Loaded %d tools (filtered from %d)", len(c.tools), len(toolsResult.Tools))) + + return nil +} + +// loadResources loads available resources from the HTTP MCP server +func (c *HTTPClient) loadResources(ctx context.Context) error { + request := mcp.ListResourcesRequest{} + + response, err := c.sendRequest(ctx, "resources/list", request) + if err != nil { + return fmt.Errorf("failed to load resources: %w", err) + } + + var resourcesResult mcp.ListResourcesResult + if err := json.Unmarshal(response, &resourcesResult); err != nil { + return fmt.Errorf("failed to parse resources response: %w", err) + } + + c.resources = resourcesResult.Resources + c.log(LoggingLevelInfo, fmt.Sprintf("Loaded %d resources", len(c.resources))) + + return nil +} + +// loadPrompts loads available prompts from the HTTP MCP server +func (c *HTTPClient) loadPrompts(ctx context.Context) error { + request := mcp.ListPromptsRequest{} + + response, err := c.sendRequest(ctx, "prompts/list", request) + if err != nil { + return fmt.Errorf("failed to load prompts: %w", err) + } + + var promptsResult mcp.ListPromptsResult + if err := json.Unmarshal(response, &promptsResult); err != nil { + return fmt.Errorf("failed to parse prompts response: %w", err) + } + + c.prompts = promptsResult.Prompts + c.log(LoggingLevelInfo, fmt.Sprintf("Loaded %d prompts", len(c.prompts))) + + return nil +} + +// CallTool calls a tool on the MCP server via HTTP +func (c *HTTPClient) CallTool(ctx context.Context, name string, arguments map[string]interface{}) (*mcp.CallToolResult, error) { + if !c.connected { + return nil, fmt.Errorf("client not connected") + } + + // Check if tool exists and is allowed + var toolFound bool + for _, tool := range c.tools { + if tool.Name == name { + toolFound = true + break + } + } + + if !toolFound { + return nil, fmt.Errorf("tool '%s' not found or not allowed", name) + } + + c.log(LoggingLevelDebug, fmt.Sprintf("Calling tool via HTTP: %s", name)) + + 
request := mcp.CallToolRequest{} + request.Params.Name = name + request.Params.Arguments = arguments + + response, err := c.sendRequest(ctx, "tools/call", request) + if err != nil { + c.log(LoggingLevelError, fmt.Sprintf("Tool call failed: %v", err)) + return nil, fmt.Errorf("tool call failed: %w", err) + } + + var result mcp.CallToolResult + if err := json.Unmarshal(response, &result); err != nil { + return nil, fmt.Errorf("failed to parse tool call response: %w", err) + } + + c.log(LoggingLevelDebug, fmt.Sprintf("Tool call successful via HTTP: %s", name)) + return &result, nil +} + +// ReadResource reads a resource from the MCP server via HTTP +func (c *HTTPClient) ReadResource(ctx context.Context, uri string) (*mcp.ReadResourceResult, error) { + if !c.connected { + return nil, fmt.Errorf("client not connected") + } + + c.log(LoggingLevelDebug, fmt.Sprintf("Reading resource via HTTP: %s", uri)) + + request := mcp.ReadResourceRequest{} + request.Params.URI = uri + + response, err := c.sendRequest(ctx, "resources/read", request) + if err != nil { + c.log(LoggingLevelError, fmt.Sprintf("Resource read failed: %v", err)) + return nil, fmt.Errorf("resource read failed: %w", err) + } + + var result mcp.ReadResourceResult + if err := json.Unmarshal(response, &result); err != nil { + return nil, fmt.Errorf("failed to parse resource response: %w", err) + } + + c.log(LoggingLevelDebug, fmt.Sprintf("Resource read successful via HTTP: %s", uri)) + return &result, nil +} + +// GetPrompt gets a prompt from the MCP server via HTTP +func (c *HTTPClient) GetPrompt(ctx context.Context, name string, arguments map[string]interface{}) (*mcp.GetPromptResult, error) { + if !c.connected { + return nil, fmt.Errorf("client not connected") + } + + c.log(LoggingLevelDebug, fmt.Sprintf("Getting prompt via HTTP: %s", name)) + + request := mcp.GetPromptRequest{} + request.Params.Name = name + request.Params.Arguments = convertArgsToStringMap(arguments) + + response, err := c.sendRequest(ctx, 
"prompts/get", request) + if err != nil { + c.log(LoggingLevelError, fmt.Sprintf("Prompt get failed: %v", err)) + return nil, fmt.Errorf("prompt get failed: %w", err) + } + + var result mcp.GetPromptResult + if err := json.Unmarshal(response, &result); err != nil { + return nil, fmt.Errorf("failed to parse prompt response: %w", err) + } + + c.log(LoggingLevelDebug, fmt.Sprintf("Prompt get successful via HTTP: %s", name)) + return &result, nil +} + +// Ping sends a ping to the MCP server via HTTP +func (c *HTTPClient) Ping(ctx context.Context) error { + if !c.connected { + return fmt.Errorf("client not connected") + } + + _, err := c.sendRequest(ctx, "ping", nil) + if err != nil { + return fmt.Errorf("ping failed: %w", err) + } + + c.log(LoggingLevelDebug, "Ping successful via HTTP") + return nil +} + +// RefreshCapabilities reloads tools, resources, and prompts +func (c *HTTPClient) RefreshCapabilities(ctx context.Context) error { + if !c.connected { + return fmt.Errorf("client not connected") + } + + return c.initializeCapabilities() +} + +// sendRequest sends an HTTP request to the MCP server +func (c *HTTPClient) sendRequest(ctx context.Context, endpoint string, payload interface{}) ([]byte, error) { + // Construct URL + requestURL := fmt.Sprintf("%s/%s", c.baseURL, endpoint) + + // Prepare request body + var body io.Reader + if payload != nil { + jsonData, err := json.Marshal(payload) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + body = bytes.NewReader(jsonData) + } + + // Create HTTP request + req, err := http.NewRequestWithContext(ctx, "POST", requestURL, body) + if err != nil { + return nil, fmt.Errorf("failed to create HTTP request: %w", err) + } + + // Set headers + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json") + + // Add custom headers if provided + for key, value := range c.config.Headers { + req.Header.Set(key, value) + } + + // Make the request + resp, err := 
c.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("HTTP request failed: %w", err) + } + defer resp.Body.Close() + + // Read response body + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + + // Check status code + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return nil, fmt.Errorf("HTTP request failed with status %d: %s", resp.StatusCode, string(responseBody)) + } + + return responseBody, nil +} + +// convertArgsToStringMap converts arguments to string map for compatibility +func convertArgsToStringMap(arguments map[string]interface{}) map[string]string { + stringArgs := make(map[string]string) + for k, v := range arguments { + if str, ok := v.(string); ok { + stringArgs[k] = str + } else { + stringArgs[k] = fmt.Sprintf("%v", v) + } + } + return stringArgs +} + +// Close closes the connection to the MCP server +func (c *HTTPClient) Close() error { + c.connected = false + c.log(LoggingLevelInfo, "Disconnected from HTTP MCP server") + return nil +} diff --git a/src/semantic-router/pkg/connectivity/mcp/interface.go b/src/semantic-router/pkg/connectivity/mcp/interface.go new file mode 100644 index 00000000..112a797c --- /dev/null +++ b/src/semantic-router/pkg/connectivity/mcp/interface.go @@ -0,0 +1,126 @@ +package mcp + +import ( + "context" + "time" + + "github.com/mark3labs/mcp-go/mcp" +) + +// MCPClient defines the interface that all MCP client implementations must satisfy +type MCPClient interface { + // Connection management + Connect() error + Close() error + IsConnected() bool + Ping(ctx context.Context) error + + // Capability management + GetTools() []mcp.Tool + GetResources() []mcp.Resource + GetPrompts() []mcp.Prompt + RefreshCapabilities(ctx context.Context) error + + // Tool operations + CallTool(ctx context.Context, name string, arguments map[string]interface{}) (*mcp.CallToolResult, error) + + // Resource operations + ReadResource(ctx 
context.Context, uri string) (*mcp.ReadResourceResult, error) + + // Prompt operations + GetPrompt(ctx context.Context, name string, arguments map[string]interface{}) (*mcp.GetPromptResult, error) + + // Logging + SetLogHandler(handler func(LoggingLevel, string)) +} + +// BaseClient provides common functionality for all client implementations +type BaseClient struct { + name string + config ClientConfig + tools []mcp.Tool + resources []mcp.Resource + prompts []mcp.Prompt + logHandler func(LoggingLevel, string) + connected bool +} + +// NewBaseClient creates a new base client +func NewBaseClient(name string, config ClientConfig) *BaseClient { + return &BaseClient{ + name: name, + config: config, + connected: false, + logHandler: func(level LoggingLevel, message string) { + // Default log handler - can be overridden + }, + } +} + +// GetTools returns the available tools +func (c *BaseClient) GetTools() []mcp.Tool { + return c.tools +} + +// GetResources returns the available resources +func (c *BaseClient) GetResources() []mcp.Resource { + return c.resources +} + +// GetPrompts returns the available prompts +func (c *BaseClient) GetPrompts() []mcp.Prompt { + return c.prompts +} + +// IsConnected returns whether the client is connected +func (c *BaseClient) IsConnected() bool { + return c.connected +} + +// SetLogHandler sets the log handler function +func (c *BaseClient) SetLogHandler(handler func(LoggingLevel, string)) { + c.logHandler = handler +} + +// log writes a log message using the configured handler +func (c *BaseClient) log(level LoggingLevel, message string) { + if c.logHandler != nil { + c.logHandler(level, message) + } +} + +// ClientConfig represents client configuration +type ClientConfig struct { + Command string `json:"command,omitempty"` + Args []string `json:"args,omitempty"` + Env map[string]string `json:"env,omitempty"` + URL string `json:"url,omitempty"` + Headers map[string]string `json:"headers,omitempty"` + TransportType string 
`json:"transportType,omitempty"` + Timeout time.Duration `json:"timeout,omitempty"` + Options ClientOptions `json:"options"` +} + +// ToolFilter represents tool filtering configuration +type ToolFilter struct { + Mode string `json:"mode"` // "allow" or "block" + List []string `json:"list"` +} + +// ClientOptions represents client options +type ClientOptions struct { + PanicIfInvalid bool `json:"panicIfInvalid"` + LogEnabled bool `json:"logEnabled"` + AuthTokens []string `json:"authTokens"` + ToolFilter ToolFilter `json:"toolFilter"` +} + +// TransportType represents the transport type for MCP communication. +// Supported values: "stdio" for stdin/stdout and "streamable-http" for HTTP transport. +// Note: "http" is also accepted as an alias for "streamable-http" for convenience. +type TransportType string + +const ( + TransportStdio TransportType = "stdio" + TransportStreamableHTTP TransportType = "streamable-http" +) diff --git a/src/semantic-router/pkg/connectivity/mcp/stdio_client.go b/src/semantic-router/pkg/connectivity/mcp/stdio_client.go new file mode 100644 index 00000000..0f38ad38 --- /dev/null +++ b/src/semantic-router/pkg/connectivity/mcp/stdio_client.go @@ -0,0 +1,336 @@ +package mcp + +import ( + "context" + "fmt" + "os" + "time" + + "github.com/mark3labs/mcp-go/client" + "github.com/mark3labs/mcp-go/mcp" +) + +// StdioClient implements MCPClient for stdio transport +type StdioClient struct { + *BaseClient + mcpClient client.MCPClient +} + +// NewStdioClient creates a new stdio MCP client +func NewStdioClient(name string, config ClientConfig) *StdioClient { + baseClient := NewBaseClient(name, config) + return &StdioClient{ + BaseClient: baseClient, + } +} + +// Connect establishes connection to the MCP server via stdio +func (c *StdioClient) Connect() error { + if c.connected { + return nil + } + + c.log(LoggingLevelInfo, fmt.Sprintf("Connecting to MCP server with stdio transport")) + + if c.config.Command == "" { + return fmt.Errorf("command is 
required for stdio transport") + } + + // Prepare environment variables + env := os.Environ() + for key, value := range c.config.Env { + env = append(env, fmt.Sprintf("%s=%s", key, value)) + } + + c.log(LoggingLevelInfo, fmt.Sprintf("Starting MCP server: %s %v", c.config.Command, c.config.Args)) + c.log(LoggingLevelDebug, fmt.Sprintf("Environment variables: %v", c.config.Env)) + + // Create MCP client using the correct API + mcpClient, err := client.NewStdioMCPClient(c.config.Command, env, c.config.Args...) + if err != nil { + c.log(LoggingLevelError, fmt.Sprintf("Failed to create MCP client: %v", err)) + return fmt.Errorf("failed to create stdio MCP client: %w", err) + } + + c.mcpClient = mcpClient + + // Initialize the connection with proper context handling + ctx := context.Background() + timeout := 30 * time.Second // default timeout + if c.config.Timeout > 0 { + timeout = c.config.Timeout + } + + ctx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + + // Use a channel to handle the initialization asynchronously + initDone := make(chan error, 1) + + go func() { + defer func() { + if r := recover(); r != nil { + c.log(LoggingLevelError, fmt.Sprintf("Panic during MCP initialization: %v", r)) + initDone <- fmt.Errorf("panic during initialization: %v", r) + } + }() + + c.log(LoggingLevelDebug, fmt.Sprintf("Starting MCP client...")) + // Note: Connection starts automatically; Start() method is not required + + c.log(LoggingLevelDebug, fmt.Sprintf("Initializing MCP client...")) + // Initialize the client with a simple request (no params needed for basic initialization) + initResult, err := c.mcpClient.Initialize(ctx, mcp.InitializeRequest{}) + if err != nil { + c.log(LoggingLevelError, fmt.Sprintf("Failed to initialize MCP client: %v", err)) + initDone <- fmt.Errorf("failed to initialize MCP client: %w", err) + return + } + + c.log(LoggingLevelInfo, fmt.Sprintf("Initialized MCP client: %s v%s", initResult.ServerInfo.Name, initResult.ServerInfo.Version)) 
+ initDone <- nil + }() + + // Wait for initialization to complete or timeout + select { + case err := <-initDone: + if err != nil { + c.connected = false + return err + } + c.connected = true + case <-ctx.Done(): + c.connected = false + return fmt.Errorf("timeout connecting to MCP server after %v", timeout) + } + + // Load available capabilities in background + go func() { + capCtx, capCancel := context.WithTimeout(context.Background(), timeout) + defer capCancel() + + if err := c.loadCapabilities(capCtx); err != nil { + c.log(LoggingLevelWarning, fmt.Sprintf("Failed to load capabilities: %v", err)) + } + }() + + return nil +} + +// loadCapabilities loads tools, resources, and prompts from the MCP server +func (c *StdioClient) loadCapabilities(ctx context.Context) error { + // Load tools + if err := c.loadTools(ctx); err != nil { + return fmt.Errorf("failed to load tools: %w", err) + } + + // Load resources + if err := c.loadResources(ctx); err != nil { + c.log(LoggingLevelWarning, fmt.Sprintf("Failed to load resources: %v", err)) + } + + // Load prompts + if err := c.loadPrompts(ctx); err != nil { + c.log(LoggingLevelWarning, fmt.Sprintf("Failed to load prompts: %v", err)) + } + + return nil +} + +// loadTools loads available tools from the MCP server +func (c *StdioClient) loadTools(ctx context.Context) error { + if c.mcpClient == nil { + return fmt.Errorf("client not connected") + } + + toolsResult, err := c.mcpClient.ListTools(ctx, mcp.ListToolsRequest{}) + if err != nil { + return fmt.Errorf("failed to list tools: %w", err) + } + + // Apply tool filtering + filteredTools := FilterTools(toolsResult.Tools, c.config.Options.ToolFilter) + + c.tools = filteredTools + c.log(LoggingLevelInfo, fmt.Sprintf("Loaded %d tools (filtered from %d)", len(c.tools), len(toolsResult.Tools))) + + return nil +} + +// loadResources loads available resources from the MCP server +func (c *StdioClient) loadResources(ctx context.Context) error { + if c.mcpClient == nil { + return 
fmt.Errorf("client not connected") + } + + resourcesResult, err := c.mcpClient.ListResources(ctx, mcp.ListResourcesRequest{}) + if err != nil { + return fmt.Errorf("failed to list resources: %w", err) + } + + c.resources = resourcesResult.Resources + c.log(LoggingLevelInfo, fmt.Sprintf("Loaded %d resources", len(c.resources))) + + return nil +} + +// loadPrompts loads available prompts from the MCP server +func (c *StdioClient) loadPrompts(ctx context.Context) error { + if c.mcpClient == nil { + return fmt.Errorf("client not connected") + } + + promptsResult, err := c.mcpClient.ListPrompts(ctx, mcp.ListPromptsRequest{}) + if err != nil { + return fmt.Errorf("failed to list prompts: %w", err) + } + + c.prompts = promptsResult.Prompts + c.log(LoggingLevelInfo, fmt.Sprintf("Loaded %d prompts", len(c.prompts))) + + return nil +} + +// CallTool calls a tool on the MCP server +func (c *StdioClient) CallTool(ctx context.Context, name string, arguments map[string]interface{}) (*mcp.CallToolResult, error) { + if c.mcpClient == nil { + return nil, fmt.Errorf("client not connected") + } + + // Check if tool exists and is allowed + var toolFound bool + for _, tool := range c.tools { + if tool.Name == name { + toolFound = true + break + } + } + + if !toolFound { + return nil, fmt.Errorf("tool '%s' not found or not allowed", name) + } + + // Validate that the tool hasn't been flagged as malicious + c.log(LoggingLevelDebug, fmt.Sprintf("Calling tool: %s with arguments: %+v", name, arguments)) + + // Log the type of each argument for debugging + for key, value := range arguments { + c.log(LoggingLevelDebug, fmt.Sprintf("Argument %s: type=%T, value=%+v", key, value, value)) + } + + // For debugging, keep the original arguments as-is for the MCP call + // The MCP library should handle the conversion properly + callReq := mcp.CallToolRequest{} + callReq.Params.Name = name + callReq.Params.Arguments = arguments + + result, err := c.mcpClient.CallTool(ctx, callReq) + + if err != nil { 
+ c.log(LoggingLevelError, fmt.Sprintf("Tool call failed for %s: %v", name, err)) + c.log(LoggingLevelError, fmt.Sprintf("Arguments were: %+v", arguments)) + + // Log detailed argument types for debugging + for key, value := range arguments { + c.log(LoggingLevelError, fmt.Sprintf("Failed arg %s: type=%T, value=%+v", key, value, value)) + } + + return nil, fmt.Errorf("tool call failed: %w", err) + } + + c.log(LoggingLevelDebug, fmt.Sprintf("Tool call successful: %s", name)) + if result != nil && result.Content != nil { + c.log(LoggingLevelDebug, fmt.Sprintf("Tool result content length: %d items", len(result.Content))) + } + return result, nil +} + +// ReadResource reads a resource from the MCP server +func (c *StdioClient) ReadResource(ctx context.Context, uri string) (*mcp.ReadResourceResult, error) { + if c.mcpClient == nil { + return nil, fmt.Errorf("client not connected") + } + + c.log(LoggingLevelDebug, fmt.Sprintf("Reading resource: %s", uri)) + + readReq := mcp.ReadResourceRequest{} + readReq.Params.URI = uri + + result, err := c.mcpClient.ReadResource(ctx, readReq) + + if err != nil { + c.log(LoggingLevelError, fmt.Sprintf("Resource read failed: %v", err)) + return nil, fmt.Errorf("resource read failed: %w", err) + } + + c.log(LoggingLevelDebug, fmt.Sprintf("Resource read successful: %s", uri)) + return result, nil +} + +// GetPrompt gets a prompt from the MCP server +func (c *StdioClient) GetPrompt(ctx context.Context, name string, arguments map[string]interface{}) (*mcp.GetPromptResult, error) { + if c.mcpClient == nil { + return nil, fmt.Errorf("client not connected") + } + + c.log(LoggingLevelDebug, fmt.Sprintf("Getting prompt: %s", name)) + + // Convert map[string]interface{} to map[string]string for MCP library compatibility + stringArgs := make(map[string]string) + for k, v := range arguments { + if str, ok := v.(string); ok { + stringArgs[k] = str + } else { + stringArgs[k] = fmt.Sprintf("%v", v) + } + } + + getPromptReq := mcp.GetPromptRequest{} + 
getPromptReq.Params.Name = name + getPromptReq.Params.Arguments = stringArgs + + result, err := c.mcpClient.GetPrompt(ctx, getPromptReq) + + if err != nil { + c.log(LoggingLevelError, fmt.Sprintf("Prompt get failed: %v", err)) + return nil, fmt.Errorf("prompt get failed: %w", err) + } + + c.log(LoggingLevelDebug, fmt.Sprintf("Prompt get successful: %s", name)) + return result, nil +} + +// Ping sends a ping to the MCP server +func (c *StdioClient) Ping(ctx context.Context) error { + if c.mcpClient == nil { + return fmt.Errorf("client not connected") + } + + err := c.mcpClient.Ping(ctx) + return err +} + +// RefreshCapabilities reloads tools, resources, and prompts +func (c *StdioClient) RefreshCapabilities(ctx context.Context) error { + if !c.connected { + return fmt.Errorf("client not connected") + } + + return c.loadCapabilities(ctx) +} + +// Close closes the connection to the MCP server +func (c *StdioClient) Close() error { + var err error + + // Close MCP client if it exists + if c.mcpClient != nil { + err = c.mcpClient.Close() + c.mcpClient = nil + } + + c.connected = false + c.log(LoggingLevelInfo, "Disconnected from MCP server") + + return err +} diff --git a/src/semantic-router/pkg/connectivity/mcp/types.go b/src/semantic-router/pkg/connectivity/mcp/types.go new file mode 100644 index 00000000..99a2d786 --- /dev/null +++ b/src/semantic-router/pkg/connectivity/mcp/types.go @@ -0,0 +1,590 @@ +package mcp + +import ( + "encoding/json" + "fmt" + + "github.com/mark3labs/mcp-go/mcp" +) + +// JSONRPCVersion represents the JSON-RPC version +const JSONRPCVersion = "2.0" + +// JSONRPCRequest represents a JSON-RPC request +type JSONRPCRequest struct { + JSONRPC string `json:"jsonrpc"` + ID interface{} `json:"id,omitempty"` + Method string `json:"method"` + Params interface{} `json:"params,omitempty"` +} + +// JSONRPCResponse represents a JSON-RPC response +type JSONRPCResponse struct { + JSONRPC string `json:"jsonrpc"` + ID interface{} `json:"id,omitempty"` + Result 
interface{} `json:"result,omitempty"` + Error *JSONRPCError `json:"error,omitempty"` +} + +// JSONRPCError represents a JSON-RPC error +type JSONRPCError struct { + Code int `json:"code"` + Message string `json:"message"` + Data interface{} `json:"data,omitempty"` +} + +// ServerInfo represents MCP server information +type ServerInfo struct { + Name string `json:"name"` + Version string `json:"version"` + Protocol ProtocolVersion `json:"protocolVersion"` + Capabilities ServerCapabilities `json:"capabilities"` + Instructions string `json:"instructions,omitempty"` +} + +// ProtocolVersion represents the MCP protocol version +type ProtocolVersion struct { + Version string `json:"version"` +} + +// ServerCapabilities represents what the server can do +type ServerCapabilities struct { + Logging *LoggingCapability `json:"logging,omitempty"` + Prompts *PromptsCapability `json:"prompts,omitempty"` + Resources *ResourcesCapability `json:"resources,omitempty"` + Tools *ToolsCapability `json:"tools,omitempty"` + Sampling *SamplingCapability `json:"sampling,omitempty"` +} + +// LoggingCapability represents logging capabilities +type LoggingCapability struct{} + +// PromptsCapability represents prompts capabilities +type PromptsCapability struct { + ListChanged bool `json:"listChanged,omitempty"` +} + +// ResourcesCapability represents resources capabilities +type ResourcesCapability struct { + Subscribe bool `json:"subscribe,omitempty"` + ListChanged bool `json:"listChanged,omitempty"` +} + +// ToolsCapability represents tools capabilities +type ToolsCapability struct { + ListChanged bool `json:"listChanged,omitempty"` +} + +// SamplingCapability represents sampling capabilities +type SamplingCapability struct{} + +// ClientCapabilities represents client capabilities +type ClientCapabilities struct { + Roots *RootsCapability `json:"roots,omitempty"` + Sampling *SamplingCapability `json:"sampling,omitempty"` +} + +// RootsCapability represents roots capabilities +type 
RootsCapability struct { + ListChanged bool `json:"listChanged,omitempty"` +} + +// InitializeRequest represents an initialize request +type InitializeRequest struct { + ProtocolVersion string `json:"protocolVersion"` + Capabilities ClientCapabilities `json:"capabilities"` + ClientInfo ClientInfo `json:"clientInfo"` +} + +// ClientInfo represents client information +type ClientInfo struct { + Name string `json:"name"` + Version string `json:"version"` +} + +// InitializeResult represents the result of initialization +type InitializeResult struct { + ProtocolVersion string `json:"protocolVersion"` + Capabilities ServerCapabilities `json:"capabilities"` + ServerInfo ServerInfo `json:"serverInfo"` + Instructions string `json:"instructions,omitempty"` +} + +// Tool represents an MCP tool +type Tool struct { + Name string `json:"name"` + Description string `json:"description,omitempty"` + InputSchema ToolSchema `json:"inputSchema"` +} + +// ToolSchema represents a tool's input schema +type ToolSchema struct { + Type string `json:"type"` + Properties map[string]interface{} `json:"properties,omitempty"` + Required []string `json:"required,omitempty"` +} + +// ToolCallRequest represents a tool call request +type ToolCallRequest struct { + Name string `json:"name"` + Arguments map[string]interface{} `json:"arguments,omitempty"` +} + +// ToolResult represents the result of a tool call +type ToolResult struct { + Content []ContentBlock `json:"content"` + IsError bool `json:"isError,omitempty"` +} + +// ContentBlock represents a content block +type ContentBlock struct { + Type string `json:"type"` + Text string `json:"text,omitempty"` +} + +// TextContent represents text content +type TextContent struct { + Type string `json:"type"` + Text string `json:"text"` +} + +// ImageContent represents image content +type ImageContent struct { + Type string `json:"type"` + Data string `json:"data"` + MimeType string `json:"mimeType"` +} + +// Resource represents an MCP resource +type 
Resource struct { + URI string `json:"uri"` + Name string `json:"name"` + Description string `json:"description,omitempty"` + MimeType string `json:"mimeType,omitempty"` +} + +// ResourceContents represents resource contents +type ResourceContents struct { + URI string `json:"uri"` + MimeType string `json:"mimeType,omitempty"` + Contents []ContentBlock `json:"contents"` +} + +// Prompt represents an MCP prompt +type Prompt struct { + Name string `json:"name"` + Description string `json:"description,omitempty"` + Arguments []PromptArgument `json:"arguments,omitempty"` +} + +// PromptArgument represents a prompt argument +type PromptArgument struct { + Name string `json:"name"` + Description string `json:"description,omitempty"` + Required bool `json:"required,omitempty"` +} + +// PromptMessage represents a prompt message +type PromptMessage struct { + Role string `json:"role"` + Content []ContentBlock `json:"content"` +} + +// GetPromptResult represents the result of getting a prompt +type GetPromptResult struct { + Description string `json:"description,omitempty"` + Messages []PromptMessage `json:"messages"` +} + +// ListToolsRequest represents a list tools request +type ListToolsRequest struct { + Cursor string `json:"cursor,omitempty"` +} + +// ListToolsResult represents the result of listing tools +type ListToolsResult struct { + Tools []Tool `json:"tools"` + NextCursor string `json:"nextCursor,omitempty"` +} + +// ListResourcesRequest represents a list resources request +type ListResourcesRequest struct { + Cursor string `json:"cursor,omitempty"` +} + +// ListResourcesResult represents the result of listing resources +type ListResourcesResult struct { + Resources []Resource `json:"resources"` + NextCursor string `json:"nextCursor,omitempty"` +} + +// ListPromptsRequest represents a list prompts request +type ListPromptsRequest struct { + Cursor string `json:"cursor,omitempty"` +} + +// ListPromptsResult represents the result of listing prompts +type 
ListPromptsResult struct { + Prompts []Prompt `json:"prompts"` + NextCursor string `json:"nextCursor,omitempty"` +} + +// ReadResourceRequest represents a read resource request +type ReadResourceRequest struct { + URI string `json:"uri"` +} + +// ReadResourceResult represents the result of reading a resource +type ReadResourceResult struct { + Contents []ResourceContents `json:"contents"` +} + +// GetPromptRequest represents a get prompt request +type GetPromptRequest struct { + Name string `json:"name"` + Arguments map[string]interface{} `json:"arguments,omitempty"` +} + +// LoggingLevel represents logging levels +type LoggingLevel string + +const ( + LoggingLevelDebug LoggingLevel = "debug" + LoggingLevelInfo LoggingLevel = "info" + LoggingLevelNotice LoggingLevel = "notice" + LoggingLevelWarning LoggingLevel = "warning" + LoggingLevelError LoggingLevel = "error" + LoggingLevelCritical LoggingLevel = "critical" + LoggingLevelAlert LoggingLevel = "alert" + LoggingLevelEmergency LoggingLevel = "emergency" +) + +// LoggingMessage represents a logging message +type LoggingMessage struct { + Level LoggingLevel `json:"level"` + Data interface{} `json:"data"` + Logger string `json:"logger,omitempty"` +} + +// Notification represents an MCP notification +type Notification struct { + Method string `json:"method"` + Params interface{} `json:"params,omitempty"` +} + +// Progress represents progress information +type Progress struct { + ProgressToken interface{} `json:"progressToken,omitempty"` + Progress int `json:"progress,omitempty"` + Total int `json:"total,omitempty"` +} + +// PaginatedRequest represents a paginated request +type PaginatedRequest struct { + Cursor string `json:"cursor,omitempty"` +} + +// PaginatedResult represents a paginated result +type PaginatedResult struct { + NextCursor string `json:"nextCursor,omitempty"` +} + +// MCPError represents standard MCP error codes +type MCPError struct { + Code int `json:"code"` + Message string `json:"message"` + 
Data interface{} `json:"data,omitempty"` +} + +// Standard MCP error codes +const ( + // JSON-RPC standard errors + ParseError = -32700 + InvalidRequest = -32600 + MethodNotFound = -32601 + InvalidParams = -32602 + InternalError = -32603 + + // MCP specific errors + InvalidTool = -32000 + InvalidResource = -32001 + InvalidPrompt = -32002 + ResourceNotFound = -32003 + ToolNotFound = -32004 + PromptNotFound = -32005 +) + +// Common MCP methods +const ( + MethodInitialize = "initialize" + MethodInitialized = "notifications/initialized" + MethodPing = "ping" + MethodListTools = "tools/list" + MethodCallTool = "tools/call" + MethodListResources = "resources/list" + MethodReadResource = "resources/read" + MethodListPrompts = "prompts/list" + MethodGetPrompt = "prompts/get" + MethodLogging = "notifications/message" + MethodToolsChanged = "notifications/tools/list_changed" + MethodResourcesChanged = "notifications/resources/list_changed" + MethodPromptsChanged = "notifications/prompts/list_changed" +) + +// CreateJSONRPCRequest creates a new JSON-RPC request +func CreateJSONRPCRequest(id interface{}, method string, params interface{}) *JSONRPCRequest { + return &JSONRPCRequest{ + JSONRPC: JSONRPCVersion, + ID: id, + Method: method, + Params: params, + } +} + +// CreateJSONRPCResponse creates a new JSON-RPC response +func CreateJSONRPCResponse(id interface{}, result interface{}) *JSONRPCResponse { + return &JSONRPCResponse{ + JSONRPC: JSONRPCVersion, + ID: id, + Result: result, + } +} + +// CreateJSONRPCError creates a new JSON-RPC error response +func CreateJSONRPCError(id interface{}, code int, message string, data interface{}) *JSONRPCResponse { + return &JSONRPCResponse{ + JSONRPC: JSONRPCVersion, + ID: id, + Error: &JSONRPCError{ + Code: code, + Message: message, + Data: data, + }, + } +} + +// CreateNotification creates a new notification +func CreateNotification(method string, params interface{}) *JSONRPCRequest { + return &JSONRPCRequest{ + JSONRPC: JSONRPCVersion, 
// ImageContentBlock creates a content block of type "image".
//
// NOTE(review): the base64 payload is stored in the Text field and the
// mimeType argument is never used, so the MIME type is dropped from the
// result. ContentBlock as declared in this file only has Type and Text, while
// the separate ImageContent type carries Data and MimeType. Confirm whether
// callers rely on the current shape before changing it.
func ImageContentBlock(base64Data, mimeType string) ContentBlock {
	return ContentBlock{
		Type: "image",
		Text: base64Data, // base64-encoded image data
	}
}
tool.InputSchema.Properties, + "required": tool.InputSchema.Required, + } + } + + return openAITool +} + +// ConvertOpenAIToMCPCall converts an OpenAI tool call to MCP format +func ConvertOpenAIToMCPCall(openAICall OpenAIToolCall) (mcp.CallToolRequest, error) { + var arguments map[string]interface{} + if openAICall.Function.Arguments != "" { + if err := json.Unmarshal([]byte(openAICall.Function.Arguments), &arguments); err != nil { + argStr := openAICall.Function.Arguments + const maxLen = 200 + if len(argStr) > maxLen { + argStr = argStr[:maxLen] + "...(truncated)" + } + return mcp.CallToolRequest{}, fmt.Errorf("failed to parse arguments (%q): %w", argStr, err) + } + } + + // Convert map[string]interface{} to map[string]string for MCP library compatibility + stringArgs := make(map[string]string) + for k, v := range arguments { + if str, ok := v.(string); ok { + stringArgs[k] = str + } else { + stringArgs[k] = fmt.Sprintf("%v", v) + } + } + + // Convert string args to map[string]interface{} for MCP API + interfaceArgs := make(map[string]interface{}) + for k, v := range stringArgs { + interfaceArgs[k] = v + } + + req := mcp.CallToolRequest{} + req.Params.Name = openAICall.Function.Name + req.Params.Arguments = interfaceArgs + return req, nil +} + +// ConvertMCPResultToOpenAI converts an MCP tool result to OpenAI format +func ConvertMCPResultToOpenAI(result *mcp.CallToolResult) map[string]interface{} { + if result == nil { + return map[string]interface{}{ + "content": "No result", + } + } + + content := "" + if len(result.Content) > 0 { + firstContent := result.Content[0] + + // Use a type switch to match the actual types from mcp-go + switch c := firstContent.(type) { + case *mcp.TextContent: + content = c.Text + case mcp.TextContent: + content = c.Text + case *mcp.ImageContent: + content = fmt.Sprintf("Image: %s", c.Data) + case mcp.ImageContent: + content = fmt.Sprintf("Image: %s", c.Data) + case *mcp.EmbeddedResource: + content = fmt.Sprintf("Resource: %v", 
// ExtractAuthToken returns the token carried by an Authorization header value.
//
// Two forms are accepted: "Bearer <token>", in which case the "Bearer " prefix
// is stripped, and a bare "<token>", which is returned unchanged. An empty
// header yields "". A header that is exactly "Bearer " (prefix with no token)
// also yields "" rather than the prefix itself.
func ExtractAuthToken(authHeader string) string {
	const bearerPrefix = "Bearer "

	// >= rather than >: a prefix with nothing after it is an empty token,
	// not a token whose value is the literal "Bearer ".
	if len(authHeader) >= len(bearerPrefix) && authHeader[:len(bearerPrefix)] == bearerPrefix {
		return authHeader[len(bearerPrefix):]
	}

	return authHeader
}
b/src/semantic-router/pkg/extproc/test_utils_test.go @@ -88,6 +88,17 @@ func CreateTestConfig() *config.RouterConfig { UseModernBERT bool `yaml:"use_modernbert"` CategoryMappingPath string `yaml:"category_mapping_path"` } `yaml:"category_model"` + MCPCategoryModel struct { + Enabled bool `yaml:"enabled"` + TransportType string `yaml:"transport_type"` + Command string `yaml:"command,omitempty"` + Args []string `yaml:"args,omitempty"` + Env map[string]string `yaml:"env,omitempty"` + URL string `yaml:"url,omitempty"` + ToolName string `yaml:"tool_name,omitempty"` + Threshold float32 `yaml:"threshold"` + TimeoutSeconds int `yaml:"timeout_seconds,omitempty"` + } `yaml:"mcp_category_model,omitempty"` PIIModel struct { ModelID string `yaml:"model_id"` Threshold float32 `yaml:"threshold"` @@ -107,6 +118,19 @@ func CreateTestConfig() *config.RouterConfig { UseModernBERT: true, CategoryMappingPath: "../../../../models/category_classifier_modernbert-base_model/category_mapping.json", }, + MCPCategoryModel: struct { + Enabled bool `yaml:"enabled"` + TransportType string `yaml:"transport_type"` + Command string `yaml:"command,omitempty"` + Args []string `yaml:"args,omitempty"` + Env map[string]string `yaml:"env,omitempty"` + URL string `yaml:"url,omitempty"` + ToolName string `yaml:"tool_name,omitempty"` + Threshold float32 `yaml:"threshold"` + TimeoutSeconds int `yaml:"timeout_seconds,omitempty"` + }{ + Enabled: false, // MCP not used in tests + }, PIIModel: struct { ModelID string `yaml:"model_id"` Threshold float32 `yaml:"threshold"` diff --git a/src/semantic-router/pkg/services/classification.go b/src/semantic-router/pkg/services/classification.go index 63890c61..c0a52a09 100644 --- a/src/semantic-router/pkg/services/classification.go +++ b/src/semantic-router/pkg/services/classification.go @@ -83,7 +83,19 @@ func NewClassificationServiceWithAutoDiscovery(config *config.RouterConfig) (*Cl func createLegacyClassifier(config *config.RouterConfig) (*classification.Classifier, 
error) { // Load category mapping var categoryMapping *classification.CategoryMapping - if config.Classifier.CategoryModel.CategoryMappingPath != "" { + + // Check if we should load categories from MCP server + // Note: tool_name is optional and will be auto-discovered if not specified + useMCPCategories := config.Classifier.CategoryModel.ModelID == "" && + config.Classifier.MCPCategoryModel.Enabled + + if useMCPCategories { + // Categories will be loaded from MCP server during initialization + observability.Infof("Category mapping will be loaded from MCP server") + // Create empty mapping initially - will be populated during initialization + categoryMapping = nil + } else if config.Classifier.CategoryModel.CategoryMappingPath != "" { + // Load from file as usual var err error categoryMapping, err = classification.LoadCategoryMapping(config.Classifier.CategoryModel.CategoryMappingPath) if err != nil { diff --git a/src/semantic-router/pkg/utils/classification/classifier.go b/src/semantic-router/pkg/utils/classification/classifier.go index eba241f2..be3efeeb 100644 --- a/src/semantic-router/pkg/utils/classification/classifier.go +++ b/src/semantic-router/pkg/utils/classification/classifier.go @@ -197,7 +197,7 @@ type PIIAnalysisResult struct { // Classifier handles text classification, model selection, and jailbreak detection functionality type Classifier struct { - // Dependencies + // Dependencies - In-tree classifiers categoryInitializer CategoryInitializer categoryInference CategoryInference jailbreakInitializer JailbreakInitializer @@ -205,6 +205,10 @@ type Classifier struct { piiInitializer PIIInitializer piiInference PIIInference + // Dependencies - MCP-based classifiers + mcpCategoryInitializer MCPCategoryInitializer + mcpCategoryInference MCPCategoryInference + Config *config.RouterConfig CategoryMapping *CategoryMapping PIIMapping *PIIMapping @@ -245,10 +249,15 @@ func withPII(piiMapping *PIIMapping, piiInitializer PIIInitializer, piiInference // initModels 
initializes the models for the classifier func initModels(classifier *Classifier) (*Classifier, error) { + // Initialize either in-tree OR MCP-based category classifier if classifier.IsCategoryEnabled() { if err := classifier.initializeCategoryClassifier(); err != nil { return nil, err } + } else if classifier.IsMCPCategoryEnabled() { + if err := classifier.initializeMCPCategoryClassifier(); err != nil { + return nil, err + } } if classifier.IsJailbreakEnabled() { @@ -284,14 +293,32 @@ func newClassifierWithOptions(cfg *config.RouterConfig, options ...option) (*Cla return initModels(classifier) } -// NewClassifier creates a new classifier with model selection and jailbreak/PII detection capabilities +// NewClassifier creates a new classifier with model selection and jailbreak/PII detection capabilities. +// Both in-tree and MCP classifiers can be configured simultaneously for category classification. +// At runtime, in-tree classifier will be tried first, with MCP as a fallback, +// allowing flexible deployment scenarios such as gradual migration. 
func NewClassifier(cfg *config.RouterConfig, categoryMapping *CategoryMapping, piiMapping *PIIMapping, jailbreakMapping *JailbreakMapping) (*Classifier, error) { - return newClassifierWithOptions( - cfg, - withCategory(categoryMapping, createCategoryInitializer(cfg.Classifier.CategoryModel.UseModernBERT), createCategoryInference(cfg.Classifier.CategoryModel.UseModernBERT)), + options := []option{ withJailbreak(jailbreakMapping, createJailbreakInitializer(cfg.PromptGuard.UseModernBERT), createJailbreakInference(cfg.PromptGuard.UseModernBERT)), withPII(piiMapping, createPIIInitializer(), createPIIInference()), - ) + } + + // Add in-tree classifier if configured + if cfg.Classifier.CategoryModel.ModelID != "" { + options = append(options, withCategory(categoryMapping, createCategoryInitializer(cfg.Classifier.CategoryModel.UseModernBERT), createCategoryInference(cfg.Classifier.CategoryModel.UseModernBERT))) + } + + // Add MCP classifier if configured + // Note: Both in-tree and MCP classifiers can be configured simultaneously. + // At runtime, in-tree classifier will be tried first, with MCP as a fallback. + // This allows flexible deployment scenarios (e.g., gradual migration, A/B testing). + if cfg.Classifier.MCPCategoryModel.Enabled { + mcpInit := createMCPCategoryInitializer() + mcpInf := createMCPCategoryInference(mcpInit) + options = append(options, withMCPCategory(mcpInit, mcpInf)) + } + + return newClassifierWithOptions(cfg, options...) 
} // IsCategoryEnabled checks if category classification is properly configured @@ -315,6 +342,26 @@ func (c *Classifier) initializeCategoryClassifier() error { // ClassifyCategory performs category classification on the given text func (c *Classifier) ClassifyCategory(text string) (string, float64, error) { + // Try in-tree first if properly configured + if c.IsCategoryEnabled() && c.categoryInference != nil { + return c.classifyCategoryInTree(text) + } + + // If in-tree classifier was initialized but config is now invalid, return specific error + if c.categoryInference != nil && !c.IsCategoryEnabled() { + return "", 0.0, fmt.Errorf("category classification is not properly configured") + } + + // Fall back to MCP + if c.IsMCPCategoryEnabled() && c.mcpCategoryInference != nil { + return c.classifyCategoryMCP(text) + } + + return "", 0.0, fmt.Errorf("no category classification method available") +} + +// classifyCategoryInTree performs category classification using in-tree model +func (c *Classifier) classifyCategoryInTree(text string) (string, float64, error) { if !c.IsCategoryEnabled() { return "", 0.0, fmt.Errorf("category classification is not properly configured") } @@ -478,6 +525,26 @@ func (c *Classifier) initializePIIClassifier() error { // ClassifyCategoryWithEntropy performs category classification with entropy-based reasoning decision func (c *Classifier) ClassifyCategoryWithEntropy(text string) (string, float64, entropy.ReasoningDecision, error) { + // Try in-tree first if properly configured + if c.IsCategoryEnabled() && c.categoryInference != nil { + return c.classifyCategoryWithEntropyInTree(text) + } + + // If in-tree classifier was initialized but config is now invalid, return specific error + if c.categoryInference != nil && !c.IsCategoryEnabled() { + return "", 0.0, entropy.ReasoningDecision{}, fmt.Errorf("category classification is not properly configured") + } + + // Fall back to MCP + if c.IsMCPCategoryEnabled() && c.mcpCategoryInference != nil 
{ + return c.classifyCategoryWithEntropyMCP(text) + } + + return "", 0.0, entropy.ReasoningDecision{}, fmt.Errorf("no category classification method available") +} + +// classifyCategoryWithEntropyInTree performs category classification with entropy using in-tree model +func (c *Classifier) classifyCategoryWithEntropyInTree(text string) (string, float64, entropy.ReasoningDecision, error) { if !c.IsCategoryEnabled() { return "", 0.0, entropy.ReasoningDecision{}, fmt.Errorf("category classification is not properly configured") } diff --git a/src/semantic-router/pkg/utils/classification/mcp_classifier.go b/src/semantic-router/pkg/utils/classification/mcp_classifier.go new file mode 100644 index 00000000..a9587db1 --- /dev/null +++ b/src/semantic-router/pkg/utils/classification/mcp_classifier.go @@ -0,0 +1,677 @@ +package classification + +import ( + "context" + "encoding/json" + "fmt" + "strings" + "time" + + "github.com/mark3labs/mcp-go/mcp" + candle_binding "github.com/vllm-project/semantic-router/candle-binding" + "github.com/vllm-project/semantic-router/src/semantic-router/pkg/config" + mcpclient "github.com/vllm-project/semantic-router/src/semantic-router/pkg/connectivity/mcp" + "github.com/vllm-project/semantic-router/src/semantic-router/pkg/connectivity/mcp/api" + "github.com/vllm-project/semantic-router/src/semantic-router/pkg/metrics" + "github.com/vllm-project/semantic-router/src/semantic-router/pkg/observability" + "github.com/vllm-project/semantic-router/src/semantic-router/pkg/utils/entropy" +) + +const ( + // DefaultMCPThreshold is the default confidence threshold for MCP classification. + // For multi-class classification, a value of 0.5 means that a predicted class must have at least 50% confidence + // to be selected. Adjust this threshold as needed for your use case. 
+ DefaultMCPThreshold = 0.5 +) + +// MCPClassificationResult holds the classification result with routing information from MCP server +type MCPClassificationResult struct { + Class int + Confidence float32 + CategoryName string + Model string // Model recommended by MCP server + UseReasoning *bool // Whether to use reasoning (nil means use default) +} + +// MCPCategoryInitializer initializes MCP connection for category classification +type MCPCategoryInitializer interface { + Init(cfg *config.RouterConfig) error + Close() error +} + +// MCPCategoryInference performs classification via MCP +type MCPCategoryInference interface { + Classify(ctx context.Context, text string) (candle_binding.ClassResult, error) + ClassifyWithProbabilities(ctx context.Context, text string) (candle_binding.ClassResultWithProbs, error) + ListCategories(ctx context.Context) (*CategoryMapping, error) +} + +// MCPCategoryClassifier implements both MCPCategoryInitializer and MCPCategoryInference. +// +// Protocol Contract: +// This client relies on the MCP server to respect the protocol defined in the +// github.com/vllm-project/semantic-router/src/semantic-router/pkg/connectivity/mcp/api package. +// +// The MCP server must implement these tools: +// 1. list_categories - Returns api.ListCategoriesResponse +// 2. classify_text - Returns api.ClassifyResponse or api.ClassifyWithProbabilitiesResponse +// +// The MCP server controls both classification AND routing decisions. When the server returns +// "model" and "use_reasoning" in the classification response, the router will use those values. +// If not provided, the router falls back to the default_model configuration. +// +// For detailed type definitions and examples, see the api package documentation. 
+type MCPCategoryClassifier struct { + client mcpclient.MCPClient + toolName string + config *config.RouterConfig +} + +// Init initializes the MCP client connection +func (m *MCPCategoryClassifier) Init(cfg *config.RouterConfig) error { + if cfg == nil { + return fmt.Errorf("config is nil") + } + + // Validate MCP configuration + if !cfg.Classifier.MCPCategoryModel.Enabled { + return fmt.Errorf("MCP category classifier is not enabled") + } + + // Store config + m.config = cfg + + // Create MCP client configuration + mcpConfig := mcpclient.ClientConfig{ + TransportType: cfg.Classifier.MCPCategoryModel.TransportType, + Command: cfg.Classifier.MCPCategoryModel.Command, + Args: cfg.Classifier.MCPCategoryModel.Args, + Env: cfg.Classifier.MCPCategoryModel.Env, + URL: cfg.Classifier.MCPCategoryModel.URL, + Options: mcpclient.ClientOptions{ + LogEnabled: true, + }, + } + + // Set timeout if specified + if cfg.Classifier.MCPCategoryModel.TimeoutSeconds > 0 { + mcpConfig.Timeout = time.Duration(cfg.Classifier.MCPCategoryModel.TimeoutSeconds) * time.Second + } + + // Create MCP client + client, err := mcpclient.NewClient("category_classifier", mcpConfig) + if err != nil { + return fmt.Errorf("failed to create MCP client: %w", err) + } + + // Connect to MCP server + if err := client.Connect(); err != nil { + return fmt.Errorf("failed to connect to MCP server: %w", err) + } + + m.client = client + + // Discover classification tool + if err := m.discoverClassificationTool(); err != nil { + client.Close() + return fmt.Errorf("failed to discover classification tool: %w", err) + } + + observability.Infof("Successfully initialized MCP category classifier with tool '%s'", m.toolName) + return nil +} + +// discoverClassificationTool finds the appropriate classification tool from available MCP tools +func (m *MCPCategoryClassifier) discoverClassificationTool() error { + // If tool name is explicitly specified, use it + if m.config.Classifier.MCPCategoryModel.ToolName != "" { + 
m.toolName = m.config.Classifier.MCPCategoryModel.ToolName + observability.Infof("Using explicitly configured tool: %s", m.toolName) + return nil + } + + // Otherwise, auto-discover by listing available tools + tools := m.client.GetTools() + if len(tools) == 0 { + return fmt.Errorf("no tools available from MCP server") + } + + // Look for classification-related tools by common names + classificationToolNames := []string{ + "classify_text", + "classify", + "categorize", + "categorize_text", + } + + for _, toolName := range classificationToolNames { + for _, tool := range tools { + if tool.Name == toolName { + m.toolName = tool.Name + observability.Infof("Auto-discovered classification tool: %s - %s", m.toolName, tool.Description) + return nil + } + } + } + + // If no common name found, look for tools that mention "classif" in name or description + for _, tool := range tools { + lowerName := strings.ToLower(tool.Name) + lowerDesc := strings.ToLower(tool.Description) + if strings.Contains(lowerName, "classif") || strings.Contains(lowerDesc, "classif") { + m.toolName = tool.Name + observability.Infof("Auto-discovered classification tool by pattern match: %s - %s", m.toolName, tool.Description) + return nil + } + } + + // Log available tools for debugging + var toolNames []string + for _, tool := range tools { + toolNames = append(toolNames, tool.Name) + } + return fmt.Errorf("no classification tool found among available tools: %v", toolNames) +} + +// Close closes the MCP client connection +func (m *MCPCategoryClassifier) Close() error { + if m.client != nil { + return m.client.Close() + } + return nil +} + +// Classify performs category classification via MCP +func (m *MCPCategoryClassifier) Classify(ctx context.Context, text string) (candle_binding.ClassResult, error) { + if m.client == nil { + return candle_binding.ClassResult{}, fmt.Errorf("MCP client not initialized") + } + + // Prepare arguments for MCP tool call + arguments := map[string]interface{}{ + "text": 
text, + } + + // Call MCP tool + result, err := m.client.CallTool(ctx, m.toolName, arguments) + if err != nil { + return candle_binding.ClassResult{}, fmt.Errorf("MCP tool call failed: %w", err) + } + + // Check for errors in result + if result.IsError { + return candle_binding.ClassResult{}, fmt.Errorf("MCP tool returned error: %v", result.Content) + } + + // Parse response from first content block + if len(result.Content) == 0 { + return candle_binding.ClassResult{}, fmt.Errorf("MCP tool returned empty content") + } + + // Extract text content + var responseText string + firstContent := result.Content[0] + if textContent, ok := mcp.AsTextContent(firstContent); ok { + responseText = textContent.Text + } else { + return candle_binding.ClassResult{}, fmt.Errorf("MCP tool returned non-text content") + } + + // Parse JSON response using the API type + var response api.ClassifyResponse + if err := json.Unmarshal([]byte(responseText), &response); err != nil { + return candle_binding.ClassResult{}, fmt.Errorf("failed to parse MCP response: %w", err) + } + + return candle_binding.ClassResult{ + Class: response.Class, + Confidence: response.Confidence, + }, nil +} + +// ClassifyWithProbabilities performs category classification with full probability distribution via MCP +func (m *MCPCategoryClassifier) ClassifyWithProbabilities(ctx context.Context, text string) (candle_binding.ClassResultWithProbs, error) { + if m.client == nil { + return candle_binding.ClassResultWithProbs{}, fmt.Errorf("MCP client not initialized") + } + + // Prepare arguments for MCP tool call with probabilities request + arguments := map[string]interface{}{ + "text": text, + "with_probabilities": true, + } + + // Call MCP tool + result, err := m.client.CallTool(ctx, m.toolName, arguments) + if err != nil { + return candle_binding.ClassResultWithProbs{}, fmt.Errorf("MCP tool call failed: %w", err) + } + + // Check for errors in result + if result.IsError { + return candle_binding.ClassResultWithProbs{}, 
fmt.Errorf("MCP tool returned error: %v", result.Content) + } + + // Parse response from first content block + if len(result.Content) == 0 { + return candle_binding.ClassResultWithProbs{}, fmt.Errorf("MCP tool returned empty content") + } + + // Extract text content + var responseText string + firstContent := result.Content[0] + if textContent, ok := mcp.AsTextContent(firstContent); ok { + responseText = textContent.Text + } else { + return candle_binding.ClassResultWithProbs{}, fmt.Errorf("MCP tool returned non-text content") + } + + // Parse JSON response using the API type + var response api.ClassifyWithProbabilitiesResponse + if err := json.Unmarshal([]byte(responseText), &response); err != nil { + return candle_binding.ClassResultWithProbs{}, fmt.Errorf("failed to parse MCP response: %w", err) + } + + return candle_binding.ClassResultWithProbs{ + Class: response.Class, + Confidence: response.Confidence, + Probabilities: response.Probabilities, + }, nil +} + +// ListCategories retrieves the category mapping from the MCP server +func (m *MCPCategoryClassifier) ListCategories(ctx context.Context) (*CategoryMapping, error) { + if m.client == nil { + return nil, fmt.Errorf("MCP client not initialized") + } + + // Call the list_categories tool + result, err := m.client.CallTool(ctx, "list_categories", map[string]interface{}{}) + if err != nil { + return nil, fmt.Errorf("MCP list_categories call failed: %w", err) + } + + // Check for errors in result + if result.IsError { + return nil, fmt.Errorf("MCP tool returned error: %v", result.Content) + } + + // Parse response from first content block + if len(result.Content) == 0 { + return nil, fmt.Errorf("MCP tool returned empty content") + } + + // Extract text content + var responseText string + firstContent := result.Content[0] + if textContent, ok := mcp.AsTextContent(firstContent); ok { + responseText = textContent.Text + } else { + return nil, fmt.Errorf("MCP tool returned non-text content") + } + + // Parse JSON 
response using the API type + var response api.ListCategoriesResponse + if err := json.Unmarshal([]byte(responseText), &response); err != nil { + return nil, fmt.Errorf("failed to parse MCP categories response: %w", err) + } + + // Build CategoryMapping from the list + mapping := &CategoryMapping{ + CategoryToIdx: make(map[string]int), + IdxToCategory: make(map[string]string), + } + + for idx, category := range response.Categories { + mapping.CategoryToIdx[category] = idx + mapping.IdxToCategory[fmt.Sprintf("%d", idx)] = category + } + + observability.Infof("Loaded %d categories from MCP server: %v", len(response.Categories), response.Categories) + return mapping, nil +} + +// createMCPCategoryInitializer creates an MCP category initializer +func createMCPCategoryInitializer() MCPCategoryInitializer { + return &MCPCategoryClassifier{} +} + +// createMCPCategoryInference creates an MCP category inference from the initializer +func createMCPCategoryInference(initializer MCPCategoryInitializer) MCPCategoryInference { + if classifier, ok := initializer.(*MCPCategoryClassifier); ok { + return classifier + } + return nil +} + +// IsMCPCategoryEnabled checks if MCP-based category classification is properly configured. +// Note: tool_name is optional and will be auto-discovered during initialization if not specified. 
+func (c *Classifier) IsMCPCategoryEnabled() bool { + return c.Config.Classifier.MCPCategoryModel.Enabled +} + +// initializeMCPCategoryClassifier initializes the MCP category classification model +func (c *Classifier) initializeMCPCategoryClassifier() error { + if !c.IsMCPCategoryEnabled() { + return fmt.Errorf("MCP category classification is not properly configured") + } + + if c.mcpCategoryInitializer == nil { + return fmt.Errorf("MCP category initializer is not set") + } + + if err := c.mcpCategoryInitializer.Init(c.Config); err != nil { + return fmt.Errorf("failed to initialize MCP category classifier: %w", err) + } + + // If no in-tree category model is configured and no category mapping exists, + // load categories from the MCP server + if c.Config.Classifier.CategoryModel.ModelID == "" && c.CategoryMapping == nil { + observability.Infof("Loading category mapping from MCP server...") + + // Create a context with timeout for the list_categories call + ctx := context.Background() + if c.Config.Classifier.MCPCategoryModel.TimeoutSeconds > 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, time.Duration(c.Config.Classifier.MCPCategoryModel.TimeoutSeconds)*time.Second) + defer cancel() + } + + // Get categories from MCP server + categoryMapping, err := c.mcpCategoryInference.ListCategories(ctx) + if err != nil { + return fmt.Errorf("failed to load categories from MCP server: %w", err) + } + + // Store the category mapping + c.CategoryMapping = categoryMapping + observability.Infof("Successfully loaded %d categories from MCP server", c.CategoryMapping.GetCategoryCount()) + } + + observability.Infof("Successfully initialized MCP category classifier") + return nil +} + +// classifyCategoryMCP performs category classification using MCP +func (c *Classifier) classifyCategoryMCP(text string) (string, float64, error) { + result, err := c.classifyCategoryMCPWithRouting(text) + if err != nil { + return "", 0.0, err + } + return 
result.CategoryName, float64(result.Confidence), nil +} + +// classifyCategoryMCPWithRouting performs category classification using MCP and returns routing information +func (c *Classifier) classifyCategoryMCPWithRouting(text string) (*MCPClassificationResult, error) { + if !c.IsMCPCategoryEnabled() { + return nil, fmt.Errorf("MCP category classification is not properly configured") + } + + if c.mcpCategoryInference == nil { + return nil, fmt.Errorf("MCP category inference is not initialized") + } + + // Create context with timeout + ctx := context.Background() + if c.Config.Classifier.MCPCategoryModel.TimeoutSeconds > 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, time.Duration(c.Config.Classifier.MCPCategoryModel.TimeoutSeconds)*time.Second) + defer cancel() + } + + // Classify via MCP - need to call the raw client to get model/reasoning info + start := time.Now() + + // Get MCP classifier to access raw response + mcpClassifier, ok := c.mcpCategoryInference.(*MCPCategoryClassifier) + if !ok { + return nil, fmt.Errorf("MCP category inference is not MCPCategoryClassifier type") + } + + // Call MCP tool directly to get full response + arguments := map[string]interface{}{ + "text": text, + } + + mcpResult, err := mcpClassifier.client.CallTool(ctx, mcpClassifier.toolName, arguments) + metrics.RecordClassifierLatency("category_mcp", time.Since(start).Seconds()) + + if err != nil { + return nil, fmt.Errorf("MCP tool call failed: %w", err) + } + + if mcpResult.IsError { + return nil, fmt.Errorf("MCP tool returned error: %v", mcpResult.Content) + } + + if len(mcpResult.Content) == 0 { + return nil, fmt.Errorf("MCP tool returned empty content") + } + + // Extract text content + var responseText string + firstContent := mcpResult.Content[0] + if textContent, ok := mcp.AsTextContent(firstContent); ok { + responseText = textContent.Text + } else { + return nil, fmt.Errorf("MCP tool returned non-text content") + } + + // Parse JSON response with 
routing information using the API type + var response api.ClassifyResponse + if err := json.Unmarshal([]byte(responseText), &response); err != nil { + return nil, fmt.Errorf("failed to parse MCP response: %w", err) + } + + observability.Infof("MCP classification result: class=%d, confidence=%.4f, model=%s, use_reasoning=%v", + response.Class, response.Confidence, response.Model, response.UseReasoning) + + // Check threshold + threshold := c.Config.Classifier.MCPCategoryModel.Threshold + if threshold == 0 { + threshold = DefaultMCPThreshold + } + + if response.Confidence < threshold { + observability.Infof("MCP classification confidence (%.4f) below threshold (%.4f)", + response.Confidence, threshold) + return &MCPClassificationResult{ + Class: response.Class, + Confidence: response.Confidence, + Model: response.Model, + UseReasoning: response.UseReasoning, + }, nil + } + + // Map class index to category name + var categoryName string + if c.CategoryMapping != nil { + name, ok := c.CategoryMapping.GetCategoryFromIndex(response.Class) + if ok { + categoryName = c.translateMMLUToGeneric(name) + } else { + categoryName = fmt.Sprintf("category_%d", response.Class) + } + } else { + categoryName = fmt.Sprintf("category_%d", response.Class) + } + + metrics.RecordCategoryClassification(categoryName) + observability.Infof("MCP classified as category: %s (class=%d), routing: model=%s, reasoning=%v", + categoryName, response.Class, response.Model, response.UseReasoning) + + return &MCPClassificationResult{ + Class: response.Class, + Confidence: response.Confidence, + CategoryName: categoryName, + Model: response.Model, + UseReasoning: response.UseReasoning, + }, nil +} + +// classifyCategoryWithEntropyMCP performs category classification with entropy using MCP +func (c *Classifier) classifyCategoryWithEntropyMCP(text string) (string, float64, entropy.ReasoningDecision, error) { + if !c.IsMCPCategoryEnabled() { + return "", 0.0, entropy.ReasoningDecision{}, fmt.Errorf("MCP 
category classification is not properly configured") + } + + if c.mcpCategoryInference == nil { + return "", 0.0, entropy.ReasoningDecision{}, fmt.Errorf("MCP category inference is not initialized") + } + + // Create context with timeout + ctx := context.Background() + if c.Config.Classifier.MCPCategoryModel.TimeoutSeconds > 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, time.Duration(c.Config.Classifier.MCPCategoryModel.TimeoutSeconds)*time.Second) + defer cancel() + } + + // Get full probability distribution via MCP + start := time.Now() + result, err := c.mcpCategoryInference.ClassifyWithProbabilities(ctx, text) + metrics.RecordClassifierLatency("category_mcp", time.Since(start).Seconds()) + + if err != nil { + return "", 0.0, entropy.ReasoningDecision{}, fmt.Errorf("MCP classification error: %w", err) + } + + observability.Infof("MCP classification result: class=%d, confidence=%.4f, entropy_available=%t", + result.Class, result.Confidence, len(result.Probabilities) > 0) + + // Get category names for all classes and translate to generic names when configured + categoryNames := make([]string, len(result.Probabilities)) + for i := range result.Probabilities { + if c.CategoryMapping != nil { + if name, ok := c.CategoryMapping.GetCategoryFromIndex(i); ok { + categoryNames[i] = c.translateMMLUToGeneric(name) + } else { + categoryNames[i] = fmt.Sprintf("unknown_%d", i) + } + } else { + categoryNames[i] = fmt.Sprintf("category_%d", i) + } + } + + // Build category reasoning map from configuration + categoryReasoningMap := make(map[string]bool) + for _, category := range c.Config.Categories { + useReasoning := false + if len(category.ModelScores) > 0 && category.ModelScores[0].UseReasoning != nil { + useReasoning = *category.ModelScores[0].UseReasoning + } + categoryReasoningMap[strings.ToLower(category.Name)] = useReasoning + } + + // Determine threshold + threshold := c.Config.Classifier.MCPCategoryModel.Threshold + if threshold == 0 { + 
threshold = DefaultMCPThreshold + } + + // Make entropy-based reasoning decision + entropyStart := time.Now() + reasoningDecision := entropy.MakeEntropyBasedReasoningDecision( + result.Probabilities, + categoryNames, + categoryReasoningMap, + float64(threshold), + ) + entropyLatency := time.Since(entropyStart).Seconds() + + // Calculate entropy value for metrics + entropyValue := entropy.CalculateEntropy(result.Probabilities) + + // Determine top category for metrics + topCategory := "none" + if len(reasoningDecision.TopCategories) > 0 { + topCategory = reasoningDecision.TopCategories[0].Category + } + + // Validate probability distribution quality + probSum := float32(0.0) + for _, prob := range result.Probabilities { + probSum += prob + } + + // Record probability distribution quality checks + if probSum >= 0.99 && probSum <= 1.01 { + metrics.RecordProbabilityDistributionQuality("sum_check", "valid") + } else { + metrics.RecordProbabilityDistributionQuality("sum_check", "invalid") + observability.Warnf("MCP probability distribution sum is %.3f (should be ~1.0)", probSum) + } + + // Check for negative probabilities + hasNegative := false + for _, prob := range result.Probabilities { + if prob < 0 { + hasNegative = true + break + } + } + + if hasNegative { + metrics.RecordProbabilityDistributionQuality("negative_check", "invalid") + } else { + metrics.RecordProbabilityDistributionQuality("negative_check", "valid") + } + + // Calculate uncertainty level from entropy value + entropyResult := entropy.AnalyzeEntropy(result.Probabilities) + uncertaintyLevel := entropyResult.UncertaintyLevel + + // Record comprehensive entropy classification metrics + metrics.RecordEntropyClassificationMetrics( + topCategory, + uncertaintyLevel, + entropyValue, + reasoningDecision.Confidence, + reasoningDecision.UseReasoning, + reasoningDecision.DecisionReason, + topCategory, + entropyLatency, + ) + + // Check confidence threshold for category determination + if result.Confidence < 
threshold { + observability.Infof("MCP classification confidence (%.4f) below threshold (%.4f), but entropy analysis available", + result.Confidence, threshold) + + // Still return reasoning decision based on entropy even if confidence is low + return "", float64(result.Confidence), reasoningDecision, nil + } + + // Map class index to category name + var categoryName string + var genericCategory string + if c.CategoryMapping != nil { + name, ok := c.CategoryMapping.GetCategoryFromIndex(result.Class) + if ok { + categoryName = name + genericCategory = c.translateMMLUToGeneric(name) + } else { + categoryName = fmt.Sprintf("category_%d", result.Class) + genericCategory = categoryName + } + } else { + categoryName = fmt.Sprintf("category_%d", result.Class) + genericCategory = categoryName + } + + // Record the category classification metric + metrics.RecordCategoryClassification(genericCategory) + + observability.Infof("MCP classified as category: %s (mmlu=%s), reasoning_decision: use=%t, confidence=%.3f, reason=%s", + genericCategory, categoryName, reasoningDecision.UseReasoning, reasoningDecision.Confidence, reasoningDecision.DecisionReason) + + return genericCategory, float64(result.Confidence), reasoningDecision, nil +} + +// withMCPCategory creates an option function for MCP category classifier +func withMCPCategory(mcpInitializer MCPCategoryInitializer, mcpInference MCPCategoryInference) option { + return func(c *Classifier) { + c.mcpCategoryInitializer = mcpInitializer + c.mcpCategoryInference = mcpInference + } +} diff --git a/src/semantic-router/pkg/utils/classification/mcp_classifier_test.go b/src/semantic-router/pkg/utils/classification/mcp_classifier_test.go new file mode 100644 index 00000000..0d57f2f3 --- /dev/null +++ b/src/semantic-router/pkg/utils/classification/mcp_classifier_test.go @@ -0,0 +1,779 @@ +package classification + +import ( + "context" + "errors" + + . "github.com/onsi/ginkgo/v2" + . 
"github.com/onsi/gomega" + + "github.com/mark3labs/mcp-go/mcp" + "github.com/vllm-project/semantic-router/src/semantic-router/pkg/config" + mcpclient "github.com/vllm-project/semantic-router/src/semantic-router/pkg/connectivity/mcp" +) + +// MockMCPClient is a mock implementation of the MCP client for testing +type MockMCPClient struct { + connectError error + callToolResult *mcp.CallToolResult + callToolError error + closeError error + connected bool + getToolsResult []mcp.Tool +} + +func (m *MockMCPClient) Connect() error { + if m.connectError != nil { + return m.connectError + } + m.connected = true + return nil +} + +func (m *MockMCPClient) Close() error { + if m.closeError != nil { + return m.closeError + } + m.connected = false + return nil +} + +func (m *MockMCPClient) IsConnected() bool { + return m.connected +} + +func (m *MockMCPClient) Ping(ctx context.Context) error { + return nil +} + +func (m *MockMCPClient) GetTools() []mcp.Tool { + return m.getToolsResult +} + +func (m *MockMCPClient) GetResources() []mcp.Resource { + return nil +} + +func (m *MockMCPClient) GetPrompts() []mcp.Prompt { + return nil +} + +func (m *MockMCPClient) RefreshCapabilities(ctx context.Context) error { + return nil +} + +func (m *MockMCPClient) CallTool(ctx context.Context, name string, arguments map[string]interface{}) (*mcp.CallToolResult, error) { + if m.callToolError != nil { + return nil, m.callToolError + } + return m.callToolResult, nil +} + +func (m *MockMCPClient) ReadResource(ctx context.Context, uri string) (*mcp.ReadResourceResult, error) { + return nil, errors.New("not implemented") +} + +func (m *MockMCPClient) GetPrompt(ctx context.Context, name string, arguments map[string]interface{}) (*mcp.GetPromptResult, error) { + return nil, errors.New("not implemented") +} + +func (m *MockMCPClient) SetLogHandler(handler func(mcpclient.LoggingLevel, string)) { + // no-op for mock +} + +var _ mcpclient.MCPClient = (*MockMCPClient)(nil) + +var _ = Describe("MCP Category 
Classifier", func() { + var ( + mcpClassifier *MCPCategoryClassifier + mockClient *MockMCPClient + cfg *config.RouterConfig + ) + + BeforeEach(func() { + mockClient = &MockMCPClient{} + mcpClassifier = &MCPCategoryClassifier{} + cfg = &config.RouterConfig{} + cfg.Classifier.MCPCategoryModel.Enabled = true + cfg.Classifier.MCPCategoryModel.ToolName = "classify_text" + cfg.Classifier.MCPCategoryModel.TransportType = "stdio" + cfg.Classifier.MCPCategoryModel.Command = "python" + cfg.Classifier.MCPCategoryModel.Args = []string{"server.py"} + cfg.Classifier.MCPCategoryModel.TimeoutSeconds = 30 + }) + + Describe("Init", func() { + Context("when config is nil", func() { + It("should return error", func() { + err := mcpClassifier.Init(nil) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("config is nil")) + }) + }) + + Context("when MCP is not enabled", func() { + It("should return error", func() { + cfg.Classifier.MCPCategoryModel.Enabled = false + err := mcpClassifier.Init(cfg) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("not enabled")) + }) + }) + + // Note: tool_name is now optional and will be auto-discovered if not specified. + // The Init method will automatically discover classification tools from the MCP server + // by calling discoverClassificationTool(). 
+ + // Note: Full initialization test requires mocking NewClient and GetTools which is complex + // In real tests, we'd need dependency injection for the client factory + }) + + Describe("discoverClassificationTool", func() { + BeforeEach(func() { + mcpClassifier.client = mockClient + mcpClassifier.config = cfg + }) + + Context("when tool name is explicitly configured", func() { + It("should use the configured tool name", func() { + cfg.Classifier.MCPCategoryModel.ToolName = "my_classifier" + err := mcpClassifier.discoverClassificationTool() + Expect(err).ToNot(HaveOccurred()) + Expect(mcpClassifier.toolName).To(Equal("my_classifier")) + }) + }) + + Context("when tool name is not configured", func() { + BeforeEach(func() { + cfg.Classifier.MCPCategoryModel.ToolName = "" + }) + + It("should discover classify_text tool", func() { + mockClient.getToolsResult = []mcp.Tool{ + {Name: "some_other_tool", Description: "Other tool"}, + {Name: "classify_text", Description: "Classifies text into categories"}, + } + err := mcpClassifier.discoverClassificationTool() + Expect(err).ToNot(HaveOccurred()) + Expect(mcpClassifier.toolName).To(Equal("classify_text")) + }) + + It("should discover classify tool", func() { + mockClient.getToolsResult = []mcp.Tool{ + {Name: "classify", Description: "Classify text"}, + } + err := mcpClassifier.discoverClassificationTool() + Expect(err).ToNot(HaveOccurred()) + Expect(mcpClassifier.toolName).To(Equal("classify")) + }) + + It("should discover categorize tool", func() { + mockClient.getToolsResult = []mcp.Tool{ + {Name: "categorize", Description: "Categorize text"}, + } + err := mcpClassifier.discoverClassificationTool() + Expect(err).ToNot(HaveOccurred()) + Expect(mcpClassifier.toolName).To(Equal("categorize")) + }) + + It("should discover categorize_text tool", func() { + mockClient.getToolsResult = []mcp.Tool{ + {Name: "categorize_text", Description: "Categorize text into categories"}, + } + err := mcpClassifier.discoverClassificationTool() 
+ Expect(err).ToNot(HaveOccurred()) + Expect(mcpClassifier.toolName).To(Equal("categorize_text")) + }) + + It("should prioritize classify_text over other common names", func() { + mockClient.getToolsResult = []mcp.Tool{ + {Name: "categorize", Description: "Categorize"}, + {Name: "classify_text", Description: "Main classifier"}, + {Name: "classify", Description: "Classify"}, + } + err := mcpClassifier.discoverClassificationTool() + Expect(err).ToNot(HaveOccurred()) + Expect(mcpClassifier.toolName).To(Equal("classify_text")) + }) + + It("should prefer common names over pattern matching", func() { + mockClient.getToolsResult = []mcp.Tool{ + {Name: "my_classification_tool", Description: "Custom classifier"}, + {Name: "classify", Description: "Built-in classifier"}, + } + err := mcpClassifier.discoverClassificationTool() + Expect(err).ToNot(HaveOccurred()) + Expect(mcpClassifier.toolName).To(Equal("classify")) + }) + + It("should discover by pattern matching in name", func() { + mockClient.getToolsResult = []mcp.Tool{ + {Name: "text_classification", Description: "Some description"}, + } + err := mcpClassifier.discoverClassificationTool() + Expect(err).ToNot(HaveOccurred()) + Expect(mcpClassifier.toolName).To(Equal("text_classification")) + }) + + It("should discover by pattern matching in description", func() { + mockClient.getToolsResult = []mcp.Tool{ + {Name: "analyze_text", Description: "Tool for text classification"}, + } + err := mcpClassifier.discoverClassificationTool() + Expect(err).ToNot(HaveOccurred()) + Expect(mcpClassifier.toolName).To(Equal("analyze_text")) + }) + + It("should return error when no tools available", func() { + mockClient.getToolsResult = []mcp.Tool{} + err := mcpClassifier.discoverClassificationTool() + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("no tools available")) + }) + + It("should return error when no classification tool found", func() { + mockClient.getToolsResult = []mcp.Tool{ + {Name: "foo", 
Description: "Does foo"}, + {Name: "bar", Description: "Does bar"}, + } + err := mcpClassifier.discoverClassificationTool() + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("no classification tool found")) + }) + + It("should handle case-insensitive pattern matching", func() { + mockClient.getToolsResult = []mcp.Tool{ + {Name: "TextClassification", Description: "Classify documents"}, + } + err := mcpClassifier.discoverClassificationTool() + Expect(err).ToNot(HaveOccurred()) + Expect(mcpClassifier.toolName).To(Equal("TextClassification")) + }) + + It("should match 'classif' in description (case-insensitive)", func() { + mockClient.getToolsResult = []mcp.Tool{ + {Name: "my_tool", Description: "This tool performs Classification tasks"}, + } + err := mcpClassifier.discoverClassificationTool() + Expect(err).ToNot(HaveOccurred()) + Expect(mcpClassifier.toolName).To(Equal("my_tool")) + }) + + It("should log available tools when none match", func() { + mockClient.getToolsResult = []mcp.Tool{ + {Name: "tool1", Description: "Does something"}, + {Name: "tool2", Description: "Does another thing"}, + } + err := mcpClassifier.discoverClassificationTool() + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("tool1")) + Expect(err.Error()).To(ContainSubstring("tool2")) + }) + }) + + // Test suite summary: + // - Explicit configuration: ✓ (1 test) + // - Common tool names discovery: ✓ (4 tests - classify_text, classify, categorize, categorize_text) + // - Priority/precedence: ✓ (2 tests - classify_text first, common names over patterns) + // - Pattern matching: ✓ (4 tests - name, description, case-insensitive) + // - Error cases: ✓ (3 tests - no tools, no match, logging) + // Total: 14 comprehensive tests for auto-discovery + }) + + Describe("Close", func() { + Context("when client is nil", func() { + It("should not error", func() { + err := mcpClassifier.Close() + Expect(err).ToNot(HaveOccurred()) + }) + }) + + Context("when client 
exists", func() { + BeforeEach(func() { + mcpClassifier.client = mockClient + }) + + It("should close the client successfully", func() { + err := mcpClassifier.Close() + Expect(err).ToNot(HaveOccurred()) + Expect(mockClient.connected).To(BeFalse()) + }) + + It("should return error if close fails", func() { + mockClient.closeError = errors.New("close failed") + err := mcpClassifier.Close() + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("close failed")) + }) + }) + }) + + Describe("Classify", func() { + BeforeEach(func() { + mcpClassifier.client = mockClient + mcpClassifier.toolName = "classify_text" + }) + + Context("when client is not initialized", func() { + It("should return error", func() { + mcpClassifier.client = nil + _, err := mcpClassifier.Classify(context.Background(), "test") + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("not initialized")) + }) + }) + + Context("when MCP tool call fails", func() { + It("should return error", func() { + mockClient.callToolError = errors.New("tool call failed") + _, err := mcpClassifier.Classify(context.Background(), "test text") + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("tool call failed")) + }) + }) + + Context("when MCP tool returns error result", func() { + It("should return error", func() { + mockClient.callToolResult = &mcp.CallToolResult{ + IsError: true, + Content: []mcp.Content{mcp.TextContent{Type: "text", Text: "error message"}}, + } + _, err := mcpClassifier.Classify(context.Background(), "test text") + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("returned error")) + }) + }) + + Context("when MCP tool returns empty content", func() { + It("should return error", func() { + mockClient.callToolResult = &mcp.CallToolResult{ + IsError: false, + Content: []mcp.Content{}, + } + _, err := mcpClassifier.Classify(context.Background(), "test text") + Expect(err).To(HaveOccurred()) + 
Expect(err.Error()).To(ContainSubstring("empty content")) + }) + }) + + Context("when MCP tool returns valid classification", func() { + It("should return classification result", func() { + mockClient.callToolResult = &mcp.CallToolResult{ + IsError: false, + Content: []mcp.Content{ + mcp.TextContent{ + Type: "text", + Text: `{"class": 2, "confidence": 0.95, "model": "openai/gpt-oss-20b", "use_reasoning": true}`, + }, + }, + } + result, err := mcpClassifier.Classify(context.Background(), "test text") + Expect(err).ToNot(HaveOccurred()) + Expect(result.Class).To(Equal(2)) + Expect(result.Confidence).To(BeNumerically("~", 0.95, 0.001)) + }) + }) + + Context("when MCP tool returns classification with routing info", func() { + It("should parse model and use_reasoning fields", func() { + mockClient.callToolResult = &mcp.CallToolResult{ + IsError: false, + Content: []mcp.Content{ + mcp.TextContent{ + Type: "text", + Text: `{"class": 1, "confidence": 0.85, "model": "openai/gpt-oss-20b", "use_reasoning": false}`, + }, + }, + } + result, err := mcpClassifier.Classify(context.Background(), "test text") + Expect(err).ToNot(HaveOccurred()) + Expect(result.Class).To(Equal(1)) + Expect(result.Confidence).To(BeNumerically("~", 0.85, 0.001)) + }) + }) + + Context("when MCP tool returns invalid JSON", func() { + It("should return error", func() { + mockClient.callToolResult = &mcp.CallToolResult{ + IsError: false, + Content: []mcp.Content{ + mcp.TextContent{ + Type: "text", + Text: `invalid json`, + }, + }, + } + _, err := mcpClassifier.Classify(context.Background(), "test text") + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("failed to parse")) + }) + }) + }) + + Describe("ClassifyWithProbabilities", func() { + BeforeEach(func() { + mcpClassifier.client = mockClient + mcpClassifier.toolName = "classify_text" + }) + + Context("when client is not initialized", func() { + It("should return error", func() { + mcpClassifier.client = nil + _, err := 
mcpClassifier.ClassifyWithProbabilities(context.Background(), "test") + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("not initialized")) + }) + }) + + Context("when MCP tool returns valid result with probabilities", func() { + It("should return result with probability distribution", func() { + mockClient.callToolResult = &mcp.CallToolResult{ + IsError: false, + Content: []mcp.Content{ + mcp.TextContent{ + Type: "text", + Text: `{"class": 1, "confidence": 0.85, "probabilities": [0.10, 0.85, 0.05], "model": "openai/gpt-oss-20b", "use_reasoning": true}`, + }, + }, + } + result, err := mcpClassifier.ClassifyWithProbabilities(context.Background(), "test text") + Expect(err).ToNot(HaveOccurred()) + Expect(result.Class).To(Equal(1)) + Expect(result.Confidence).To(BeNumerically("~", 0.85, 0.001)) + Expect(result.Probabilities).To(HaveLen(3)) + Expect(result.Probabilities[0]).To(BeNumerically("~", 0.10, 0.001)) + Expect(result.Probabilities[1]).To(BeNumerically("~", 0.85, 0.001)) + Expect(result.Probabilities[2]).To(BeNumerically("~", 0.05, 0.001)) + }) + }) + }) + + Describe("ListCategories", func() { + BeforeEach(func() { + mcpClassifier.client = mockClient + }) + + Context("when client is not initialized", func() { + It("should return error", func() { + mcpClassifier.client = nil + _, err := mcpClassifier.ListCategories(context.Background()) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("not initialized")) + }) + }) + + Context("when MCP tool returns valid categories", func() { + It("should return category mapping", func() { + mockClient.callToolResult = &mcp.CallToolResult{ + IsError: false, + Content: []mcp.Content{ + mcp.TextContent{ + Type: "text", + Text: `{"categories": ["math", "science", "technology", "history", "general"]}`, + }, + }, + } + mapping, err := mcpClassifier.ListCategories(context.Background()) + Expect(err).ToNot(HaveOccurred()) + Expect(mapping).ToNot(BeNil()) + 
Expect(mapping.CategoryToIdx).To(HaveLen(5)) + Expect(mapping.CategoryToIdx["math"]).To(Equal(0)) + Expect(mapping.CategoryToIdx["science"]).To(Equal(1)) + Expect(mapping.CategoryToIdx["technology"]).To(Equal(2)) + Expect(mapping.CategoryToIdx["history"]).To(Equal(3)) + Expect(mapping.CategoryToIdx["general"]).To(Equal(4)) + Expect(mapping.IdxToCategory["0"]).To(Equal("math")) + Expect(mapping.IdxToCategory["4"]).To(Equal("general")) + }) + }) + + Context("when MCP tool returns error", func() { + It("should return error", func() { + mockClient.callToolResult = &mcp.CallToolResult{ + IsError: true, + Content: []mcp.Content{mcp.TextContent{Type: "text", Text: "error loading categories"}}, + } + _, err := mcpClassifier.ListCategories(context.Background()) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("returned error")) + }) + }) + + Context("when MCP tool returns invalid JSON", func() { + It("should return error", func() { + mockClient.callToolResult = &mcp.CallToolResult{ + IsError: false, + Content: []mcp.Content{ + mcp.TextContent{ + Type: "text", + Text: `invalid json`, + }, + }, + } + _, err := mcpClassifier.ListCategories(context.Background()) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("failed to parse")) + }) + }) + + Context("when MCP tool returns empty categories", func() { + It("should return empty mapping", func() { + mockClient.callToolResult = &mcp.CallToolResult{ + IsError: false, + Content: []mcp.Content{ + mcp.TextContent{ + Type: "text", + Text: `{"categories": []}`, + }, + }, + } + mapping, err := mcpClassifier.ListCategories(context.Background()) + Expect(err).ToNot(HaveOccurred()) + Expect(mapping).ToNot(BeNil()) + Expect(mapping.CategoryToIdx).To(HaveLen(0)) + Expect(mapping.IdxToCategory).To(HaveLen(0)) + }) + }) + }) +}) + +var _ = Describe("Classifier MCP Methods", func() { + var ( + classifier *Classifier + mockClient *MockMCPClient + ) + + BeforeEach(func() { + mockClient = 
&MockMCPClient{} + cfg := &config.RouterConfig{} + cfg.Classifier.MCPCategoryModel.Enabled = true + cfg.Classifier.MCPCategoryModel.ToolName = "classify_text" + cfg.Classifier.MCPCategoryModel.Threshold = 0.5 + cfg.Classifier.MCPCategoryModel.TimeoutSeconds = 30 + + // Create MCP classifier manually and inject mock client + mcpClassifier := &MCPCategoryClassifier{ + client: mockClient, + toolName: "classify_text", + config: cfg, + } + + classifier = &Classifier{ + Config: cfg, + mcpCategoryInitializer: mcpClassifier, + mcpCategoryInference: mcpClassifier, + CategoryMapping: &CategoryMapping{ + CategoryToIdx: map[string]int{"tech": 0, "sports": 1, "politics": 2}, + IdxToCategory: map[string]string{"0": "tech", "1": "sports", "2": "politics"}, + }, + } + }) + + Describe("IsMCPCategoryEnabled", func() { + It("should return true when properly configured", func() { + Expect(classifier.IsMCPCategoryEnabled()).To(BeTrue()) + }) + + It("should return false when not enabled", func() { + classifier.Config.Classifier.MCPCategoryModel.Enabled = false + Expect(classifier.IsMCPCategoryEnabled()).To(BeFalse()) + }) + + // Note: tool_name is now optional and will be auto-discovered if not specified. + // IsMCPCategoryEnabled only checks if MCP is enabled, not specific configuration details. + // Runtime checks (like initializer != nil or successful connection) are handled + // separately in the actual initialization and classification methods. 
+ }) + + Describe("classifyCategoryMCP", func() { + Context("when MCP is not enabled", func() { + It("should return error", func() { + classifier.Config.Classifier.MCPCategoryModel.Enabled = false + _, _, err := classifier.classifyCategoryMCP("test text") + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("not properly configured")) + }) + }) + + Context("when classification succeeds with high confidence", func() { + It("should return category name", func() { + mockClient.callToolResult = &mcp.CallToolResult{ + IsError: false, + Content: []mcp.Content{ + mcp.TextContent{ + Type: "text", + Text: `{"class": 2, "confidence": 0.95, "model": "openai/gpt-oss-20b", "use_reasoning": true}`, + }, + }, + } + + category, confidence, err := classifier.classifyCategoryMCP("test text") + Expect(err).ToNot(HaveOccurred()) + Expect(category).To(Equal("politics")) + Expect(confidence).To(BeNumerically("~", 0.95, 0.001)) + }) + }) + + Context("when confidence is below threshold", func() { + It("should return empty category", func() { + mockClient.callToolResult = &mcp.CallToolResult{ + IsError: false, + Content: []mcp.Content{ + mcp.TextContent{ + Type: "text", + Text: `{"class": 1, "confidence": 0.3, "model": "openai/gpt-oss-20b", "use_reasoning": false}`, + }, + }, + } + + category, confidence, err := classifier.classifyCategoryMCP("test text") + Expect(err).ToNot(HaveOccurred()) + Expect(category).To(Equal("")) + Expect(confidence).To(BeNumerically("~", 0.3, 0.001)) + }) + }) + + Context("when class index is not in mapping", func() { + It("should return generic category name", func() { + mockClient.callToolResult = &mcp.CallToolResult{ + IsError: false, + Content: []mcp.Content{ + mcp.TextContent{ + Type: "text", + Text: `{"class": 99, "confidence": 0.85, "model": "openai/gpt-oss-20b", "use_reasoning": true}`, + }, + }, + } + + category, confidence, err := classifier.classifyCategoryMCP("test text") + Expect(err).ToNot(HaveOccurred()) + 
Expect(category).To(Equal("category_99")) + Expect(confidence).To(BeNumerically("~", 0.85, 0.001)) + }) + }) + + Context("when MCP call fails", func() { + It("should return error", func() { + mockClient.callToolError = errors.New("network error") + + _, _, err := classifier.classifyCategoryMCP("test text") + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("MCP tool call failed")) + }) + }) + }) + + Describe("classifyCategoryWithEntropyMCP", func() { + BeforeEach(func() { + classifier.Config.Categories = []config.Category{ + {Name: "tech", ModelScores: []config.ModelScore{{Model: "phi4", Score: 0.8, UseReasoning: config.BoolPtr(false)}}}, + {Name: "sports", ModelScores: []config.ModelScore{{Model: "phi4", Score: 0.8, UseReasoning: config.BoolPtr(false)}}}, + {Name: "politics", ModelScores: []config.ModelScore{{Model: "deepseek-v31", Score: 0.9, UseReasoning: config.BoolPtr(true)}}}, + } + }) + + Context("when MCP returns probabilities", func() { + It("should return category with entropy decision", func() { + mockClient.callToolResult = &mcp.CallToolResult{ + IsError: false, + Content: []mcp.Content{ + mcp.TextContent{ + Type: "text", + Text: `{"class": 2, "confidence": 0.95, "probabilities": [0.02, 0.03, 0.95], "model": "openai/gpt-oss-20b", "use_reasoning": true}`, + }, + }, + } + + category, confidence, reasoningDecision, err := classifier.classifyCategoryWithEntropyMCP("test text") + Expect(err).ToNot(HaveOccurred()) + Expect(category).To(Equal("politics")) + Expect(confidence).To(BeNumerically("~", 0.95, 0.001)) + Expect(len(reasoningDecision.TopCategories)).To(BeNumerically(">", 0)) + }) + }) + + Context("when confidence is below threshold", func() { + It("should return empty category but provide entropy decision", func() { + mockClient.callToolResult = &mcp.CallToolResult{ + IsError: false, + Content: []mcp.Content{ + mcp.TextContent{ + Type: "text", + Text: `{"class": 0, "confidence": 0.3, "probabilities": [0.3, 0.35, 0.35], "model": 
"openai/gpt-oss-20b", "use_reasoning": false}`, + }, + }, + } + + category, confidence, reasoningDecision, err := classifier.classifyCategoryWithEntropyMCP("test text") + Expect(err).ToNot(HaveOccurred()) + Expect(category).To(Equal("")) + Expect(confidence).To(BeNumerically("~", 0.3, 0.001)) + Expect(len(reasoningDecision.TopCategories)).To(BeNumerically(">", 0)) + }) + }) + }) + + Describe("initializeMCPCategoryClassifier", func() { + Context("when MCP is not enabled", func() { + It("should return error", func() { + classifier.Config.Classifier.MCPCategoryModel.Enabled = false + err := classifier.initializeMCPCategoryClassifier() + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("not properly configured")) + }) + }) + + Context("when initializer is nil", func() { + It("should return error", func() { + classifier.mcpCategoryInitializer = nil + err := classifier.initializeMCPCategoryClassifier() + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("initializer is not set")) + }) + }) + }) +}) + +var _ = Describe("MCP Helper Functions", func() { + Describe("createMCPCategoryInitializer", func() { + It("should create MCPCategoryClassifier", func() { + initializer := createMCPCategoryInitializer() + Expect(initializer).ToNot(BeNil()) + _, ok := initializer.(*MCPCategoryClassifier) + Expect(ok).To(BeTrue()) + }) + }) + + Describe("createMCPCategoryInference", func() { + It("should create inference from initializer", func() { + initializer := &MCPCategoryClassifier{} + inference := createMCPCategoryInference(initializer) + Expect(inference).ToNot(BeNil()) + Expect(inference).To(Equal(initializer)) + }) + + It("should return nil for non-MCP initializer", func() { + type FakeInitializer struct{} + fakeInit := struct { + FakeInitializer + MCPCategoryInitializer + }{} + inference := createMCPCategoryInference(&fakeInit) + Expect(inference).To(BeNil()) + }) + }) + + Describe("withMCPCategory", func() { + It("should set MCP 
fields on classifier", func() { + classifier := &Classifier{} + initializer := &MCPCategoryClassifier{} + inference := createMCPCategoryInference(initializer) + + option := withMCPCategory(initializer, inference) + option(classifier) + + Expect(classifier.mcpCategoryInitializer).To(Equal(initializer)) + Expect(classifier.mcpCategoryInference).To(Equal(inference)) + }) + }) +})