Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
204 changes: 204 additions & 0 deletions config/config-mcp-classifier-example.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,204 @@
# Example Configuration for MCP-Based Category Classifier
#
# This configuration demonstrates how to use an external MCP (Model Context Protocol)
# service for category classification instead of the built-in Candle/ModernBERT models.
#
# Use cases:
# - Offload classification to a remote service
# - Use custom classification models not supported in-tree
# - Scale classification independently from the router
# - Integrate with existing ML infrastructure

# BERT model for semantic caching and tool selection
bert_model:
model_id: "models/all-MiniLM-L6-v2"
threshold: 0.85
use_cpu: true

# Classifier configuration
classifier:
# Disable in-tree category classifier (leave model_id empty)
category_model:
model_id: "" # Empty = disabled
threshold: 0.6
use_cpu: true
use_modernbert: false
category_mapping_path: ""

# Enable MCP-based category classifier
mcp_category_model:
enabled: true # Enable MCP classifier
transport_type: "stdio" # "stdio" or "http"

# For stdio transport: run a local Python MCP server
command: "python"
args: ["-m", "mcp_category_classifier"]
env:
PYTHONPATH: "/opt/ml/models"
MODEL_PATH: "/opt/ml/models/category_classifier"
LOG_LEVEL: "INFO"

# For http transport: use this instead
# transport_type: "http"
# url: "http://localhost:8080/mcp"

tool_name: "classify_text" # MCP tool name to call
threshold: 0.6 # Confidence threshold
timeout_seconds: 30 # Request timeout

# PII model configuration (unchanged)
pii_model:
model_id: "models/pii_classifier"
threshold: 0.7
use_cpu: true
pii_mapping_path: "models/pii_classifier/pii_type_mapping.json"

# Prompt guard configuration (unchanged)
prompt_guard:
enabled: true
model_id: "models/jailbreak_classifier"
threshold: 0.8
use_cpu: true
use_modernbert: true
jailbreak_mapping_path: "models/jailbreak_classifier/jailbreak_mapping.json"

# Categories for routing queries
categories:
- name: "math"
description: "Mathematical problems, equations, calculus, algebra, statistics"
model_scores:
- model: "deepseek/deepseek-r1:70b"
score: 0.95
use_reasoning: true
- model: "qwen/qwen3-235b"
score: 0.90
use_reasoning: true
mmlu_categories:
- "mathematics"
- "statistics"

- name: "coding"
description: "Programming, software development, debugging, algorithms"
model_scores:
- model: "deepseek/deepseek-r1-coder:33b"
score: 0.95
use_reasoning: true
- model: "meta/llama3.1-70b"
score: 0.85
use_reasoning: false
mmlu_categories:
- "computer_science"
- "engineering"

- name: "general"
description: "General knowledge, conversation, misc queries"
model_scores:
- model: "meta/llama3.1-70b"
score: 0.90
use_reasoning: false
- model: "qwen/qwen3-235b"
score: 0.85
use_reasoning: false

# Default model to use when category can't be determined
default_model: "meta/llama3.1-70b"

# vLLM endpoints configuration
vllm_endpoints:
- name: "deepseek-endpoint"
address: "10.0.1.10"
port: 8000
models:
- "deepseek/deepseek-r1:70b"
- "deepseek/deepseek-r1-coder:33b"
weight: 100

- name: "qwen-endpoint"
address: "10.0.1.11"
port: 8000
models:
- "qwen/qwen3-235b"
weight: 100

- name: "llama-endpoint"
address: "10.0.1.12"
port: 8000
models:
- "meta/llama3.1-70b"
weight: 100

# Semantic cache configuration (optional)
semantic_cache:
enabled: true
backend_type: "in-memory"
similarity_threshold: 0.90
max_entries: 1000
ttl_seconds: 3600
eviction_policy: "lru"

# Model-specific configuration
model_config:
"deepseek/deepseek-r1:70b":
reasoning_family: "deepseek"
pii_policy:
allow_by_default: false
pii_types_allowed: []

"deepseek/deepseek-r1-coder:33b":
reasoning_family: "deepseek"
pii_policy:
allow_by_default: false
pii_types_allowed: []

"qwen/qwen3-235b":
reasoning_family: "qwen3"
pii_policy:
allow_by_default: true

"meta/llama3.1-70b":
pii_policy:
allow_by_default: true

# Reasoning family configurations
reasoning_families:
deepseek:
type: "chat_template_kwargs"
parameter: "thinking"
qwen3:
type: "reasoning_effort"
parameter: "reasoning_effort"
gpt-oss:
type: "chat_template_kwargs"
parameter: "enable_thinking"

# Tools configuration (optional)
tools:
enabled: false
top_k: 5
similarity_threshold: 0.7
tools_db_path: "config/tools_db.json"
fallback_to_empty: true

# API configuration
api:
batch_classification:
metrics:
enabled: true
sample_rate: 1.0

# Observability configuration
observability:
tracing:
enabled: false
provider: "opentelemetry"
exporter:
type: "otlp"
endpoint: "localhost:4317"
insecure: true
sampling:
type: "always_on"
resource:
service_name: "semantic-router"
service_version: "1.0.0"
deployment_environment: "production"

8 changes: 8 additions & 0 deletions src/semantic-router/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ replace (
require (
github.com/envoyproxy/go-control-plane/envoy v1.32.4
github.com/fsnotify/fsnotify v1.7.0
github.com/mark3labs/mcp-go v0.42.0-beta.1
github.com/milvus-io/milvus-sdk-go/v2 v2.4.2
github.com/onsi/ginkgo/v2 v2.23.4
github.com/onsi/gomega v1.38.0
Expand All @@ -34,7 +35,9 @@ require (
)

require (
github.com/bahlo/generic-list-go v0.2.0 // indirect
github.com/beorn7/perks v1.0.1 // indirect
github.com/buger/jsonparser v1.1.1 // indirect
github.com/cenkalti/backoff/v5 v5.0.3 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/cncf/xds/go v0.0.0-20250501225837-2ac532fd4443 // indirect
Expand All @@ -56,10 +59,12 @@ require (
github.com/google/uuid v1.6.0 // indirect
github.com/grpc-ecosystem/go-grpc-middleware v1.3.0 // indirect
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.2 // indirect
github.com/invopop/jsonschema v0.13.0 // indirect
github.com/json-iterator/go v1.1.12 // indirect
github.com/kr/pretty v0.3.1 // indirect
github.com/kr/text v0.2.0 // indirect
github.com/kylelemons/godebug v1.1.0 // indirect
github.com/mailru/easyjson v0.7.7 // indirect
github.com/milvus-io/milvus-proto/go-api/v2 v2.4.10-0.20240819025435-512e3b98866a // indirect
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect
Expand All @@ -70,11 +75,14 @@ require (
github.com/prometheus/common v0.65.0 // indirect
github.com/prometheus/procfs v0.16.1 // indirect
github.com/rogpeppe/go-internal v1.13.1 // indirect
github.com/spf13/cast v1.7.1 // indirect
github.com/tidwall/gjson v1.14.4 // indirect
github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.1 // indirect
github.com/tidwall/sjson v1.2.5 // indirect
github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect
github.com/x448/float16 v0.8.4 // indirect
github.com/yosida95/uritemplate/v3 v3.0.2 // indirect
go.opentelemetry.io/auto/sdk v1.1.0 // indirect
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.38.0 // indirect
go.opentelemetry.io/otel/metric v1.38.0 // indirect
Expand Down
19 changes: 19 additions & 0 deletions src/semantic-router/go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,12 @@ github.com/Shopify/goreferrer v0.0.0-20181106222321-ec9c9a553398/go.mod h1:a1uqR
github.com/ajg/form v1.5.1/go.mod h1:uL1WgH+h2mgNtvBq0339dVnzXdBETtL2LeUXaIv25UY=
github.com/armon/consul-api v0.0.0-20180202201655-eb2c6b5be1b6/go.mod h1:grANhF5doyWs3UAsr3K4I6qtAmlQcZDesFNEHPZAzj8=
github.com/aymerick/raymond v2.0.3-0.20180322193309-b565731e1464+incompatible/go.mod h1:osfaiScAUVup+UC9Nfq76eWqDhXlp+4UYaA8uhTBO6g=
github.com/bahlo/generic-list-go v0.2.0 h1:5sz/EEAK+ls5wF+NeqDpk5+iNdMDXrh3z3nPnH1Wvgk=
github.com/bahlo/generic-list-go v0.2.0/go.mod h1:2KvAjgMlE5NNynlg/5iLrrCCZ2+5xWbdbCW3pNTGyYg=
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
github.com/buger/jsonparser v1.1.1 h1:2PnMjfWD7wBILjqQbt530v576A/cAbQvEW9gGIpYMUs=
github.com/buger/jsonparser v1.1.1/go.mod h1:6RYKKt7H4d4+iWqouImQ9R2FZql3VbhNgx27UK13J/0=
github.com/cenkalti/backoff/v5 v5.0.3 h1:ZN+IMa753KfX5hd8vVaMixjnqRZ3y8CuJKRKj1xcsSM=
github.com/cenkalti/backoff/v5 v5.0.3/go.mod h1:rkhZdG3JZukswDf7f0cwqPNk4K0sa+F97BxZthm/crw=
github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU=
Expand Down Expand Up @@ -53,6 +57,8 @@ github.com/envoyproxy/protoc-gen-validate v1.2.1/go.mod h1:d/C80l/jxXLdfEIhX1W2T
github.com/etcd-io/bbolt v1.3.3/go.mod h1:ZF2nL25h33cCyBtcyWeZ2/I3HQOfTP+0PIEvHjkjCrw=
github.com/fasthttp-contrib/websocket v0.0.0-20160511215533-1f3b11f56072/go.mod h1:duJ4Jxv5lDcvg4QuQr0oowTf7dz4/CR8NtyCooz9HL8=
github.com/fatih/structs v1.1.0/go.mod h1:9NiDSp5zOcgEDl+j00MP/WkGVPOlPRLejGD8Ga6PJ7M=
github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8=
github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0=
github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo=
github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA=
github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM=
Expand Down Expand Up @@ -137,11 +143,14 @@ github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpO
github.com/hydrogen18/memlistener v0.0.0-20200120041712-dcc25e7acd91/go.mod h1:qEIFzExnS6016fRpRfxrExeVn2gbClQA99gQhnIcdhE=
github.com/imkira/go-interpol v1.1.0/go.mod h1:z0h2/2T3XF8kyEPpRgJ3kmNv+C43p+I/CoI+jC3w2iA=
github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8=
github.com/invopop/jsonschema v0.13.0 h1:KvpoAJWEjR3uD9Kbm2HWJmqsEaHt8lBUpd0qHcIi21E=
github.com/invopop/jsonschema v0.13.0/go.mod h1:ffZ5Km5SWWRAIN6wbDXItl95euhFz2uON45H2qjYt+0=
github.com/iris-contrib/blackfriday v2.0.0+incompatible/go.mod h1:UzZ2bDEoaSGPbkg6SAB4att1aAwTmVIx/5gCVqeyUdI=
github.com/iris-contrib/go.uuid v2.0.0+incompatible/go.mod h1:iz2lgM/1UnEf1kP0L/+fafWORmlnuysV2EMP8MW+qe0=
github.com/iris-contrib/jade v1.1.3/go.mod h1:H/geBymxJhShH5kecoiOCSssPX7QWYH7UaeZTSWddIk=
github.com/iris-contrib/pongo2 v0.0.1/go.mod h1:Ssh+00+3GAZqSQb30AvBRNxBx7rf0GqwkjqxNd0u65g=
github.com/iris-contrib/schema v0.0.1/go.mod h1:urYA3uvUNG1TIIjOSCzHr9/LmbQo8LrOcOqfqxa4hXw=
github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y=
github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU=
github.com/json-iterator/go v1.1.9/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
Expand Down Expand Up @@ -175,6 +184,10 @@ github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+
github.com/labstack/echo/v4 v4.5.0/go.mod h1:czIriw4a0C1dFun+ObrXp7ok03xON0N1awStJ6ArI7Y=
github.com/labstack/gommon v0.3.0/go.mod h1:MULnywXg0yavhxWKc+lOruYdAhDwPK9wf0OL7NoOu+k=
github.com/magiconair/properties v1.8.0/go.mod h1:PppfXfuXeibc/6YijjN8zIbojt8czPbwD3XqdrwzmxQ=
github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0=
github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc=
github.com/mark3labs/mcp-go v0.42.0-beta.1 h1:jXCUOg7vHwSuknzy4hPvOXASnzmLluM3AMx1rPh/OYM=
github.com/mark3labs/mcp-go v0.42.0-beta.1/go.mod h1:T7tUa2jO6MavG+3P25Oy/jR7iCeJPHImCZHRymCn39g=
github.com/mattn/go-colorable v0.1.2/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE=
github.com/mattn/go-colorable v0.1.8/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc=
github.com/mattn/go-colorable v0.1.11/go.mod h1:u5H1YNBxpqRaxsYJYSkiCWKzEfiAb1Gb520KVy5xxl4=
Expand Down Expand Up @@ -254,6 +267,8 @@ github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d/go.mod h1
github.com/smartystreets/goconvey v1.6.4/go.mod h1:syvi0/a8iFYH4r/RixwvyeAJjdLS9QV7WQ/tjFTllLA=
github.com/spf13/afero v1.1.2/go.mod h1:j4pytiNVoe2o6bmDsKpLACNPDBIoEAkihy7loJ1B0CQ=
github.com/spf13/cast v1.3.0/go.mod h1:Qx5cxh0v+4UWYiBimWS+eyWzqEqokIECu5etghLkUJE=
github.com/spf13/cast v1.7.1 h1:cuNEagBQEHWN1FnbGEjCXL2szYEXqfJPbP2HNUaca9Y=
github.com/spf13/cast v1.7.1/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo=
github.com/spf13/cobra v0.0.5/go.mod h1:3K3wKZymM7VvHMDS9+Akkh4K60UwM26emMESw8tLCHU=
github.com/spf13/jwalterweatherman v1.0.0/go.mod h1:cQK4TGJAtQXfYWX+Ddv3mKDzgVb68N+wFjFa4jdeBTo=
github.com/spf13/pflag v1.0.3/go.mod h1:DYY7MBk1bdzusC3SYhjObp+wFpr4gzcvqqNjLnInEg4=
Expand Down Expand Up @@ -291,13 +306,17 @@ github.com/valyala/fasthttp v1.6.0/go.mod h1:FstJa9V+Pj9vQ7OJie2qMHdwemEDaDiSdBn
github.com/valyala/fasttemplate v1.0.1/go.mod h1:UQGH1tvbgY+Nz5t2n7tXsz52dQxojPUpymEIMZ47gx8=
github.com/valyala/fasttemplate v1.2.1/go.mod h1:KHLXt3tVN2HBp8eijSv/kGJopbvo7S+qRAEEKiv+SiQ=
github.com/valyala/tcplisten v0.0.0-20161114210144-ceec8f93295a/go.mod h1:v3UYOV9WzVtRmSR+PDvWpU/qWl4Wa5LApYYX4ZtKbio=
github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/fJgbpc=
github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw=
github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM=
github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg=
github.com/xeipuuv/gojsonpointer v0.0.0-20180127040702-4e3ac2762d5f/go.mod h1:N2zxlSyiKSe5eX1tZViRH5QA0qijqEDrYZiPEAiq3wU=
github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415/go.mod h1:GwrjFmJcFw6At/Gs6z4yjiIwzuJ1/+UwLxMQDVQXShQ=
github.com/xeipuuv/gojsonschema v1.2.0/go.mod h1:anYRn/JVcOK2ZgGU+IjEV4nwlhoK5sQluxsYJ78Id3Y=
github.com/xordataexchange/crypt v0.0.3-0.20170626215501-b2862e3d0a77/go.mod h1:aYKd//L2LvnjZzWKhF00oedf4jCCReLcmhLdhm1A27Q=
github.com/yalp/jsonpath v0.0.0-20180802001716-5cc68e5049a0/go.mod h1:/LWChgwKmvncFJFHJ7Gvn9wZArjbV5/FppcK2fKk/tI=
github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4=
github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4=
github.com/yudai/gojsondiff v1.0.0/go.mod h1:AY32+k2cwILAkW1fbgxQ5mUmMiZFgLIV+FBNExI05xg=
github.com/yudai/golcs v0.0.0-20170316035057-ecda9a501e82/go.mod h1:lgjkn3NuSvDfVJdfcVVdX+jpBxNmX4rDAzaS45IcYoM=
github.com/yudai/pp v2.0.1+incompatible/go.mod h1:PuxR/8QJ7cyCkFp/aUDS+JY727OFEZkTdatxwunjIkc=
Expand Down
16 changes: 16 additions & 0 deletions src/semantic-router/pkg/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,17 @@ type RouterConfig struct {
UseModernBERT bool `yaml:"use_modernbert"`
CategoryMappingPath string `yaml:"category_mapping_path"`
} `yaml:"category_model"`
MCPCategoryModel struct {
Enabled bool `yaml:"enabled"`
TransportType string `yaml:"transport_type"`
Command string `yaml:"command,omitempty"`
Args []string `yaml:"args,omitempty"`
Env map[string]string `yaml:"env,omitempty"`
URL string `yaml:"url,omitempty"`
ToolName string `yaml:"tool_name"`
Threshold float32 `yaml:"threshold"`
TimeoutSeconds int `yaml:"timeout_seconds,omitempty"`
} `yaml:"mcp_category_model,omitempty"`
PIIModel struct {
ModelID string `yaml:"model_id"`
Threshold float32 `yaml:"threshold"`
Expand Down Expand Up @@ -574,6 +585,11 @@ func (c *RouterConfig) IsCategoryClassifierEnabled() bool {
return c.Classifier.CategoryModel.ModelID != "" && c.Classifier.CategoryModel.CategoryMappingPath != ""
}

// IsMCPCategoryClassifierEnabled checks if MCP-based category classification is enabled
func (c *RouterConfig) IsMCPCategoryClassifierEnabled() bool {
return c.Classifier.MCPCategoryModel.Enabled && c.Classifier.MCPCategoryModel.ToolName != ""
}

// GetPromptGuardConfig returns the prompt guard configuration
func (c *RouterConfig) GetPromptGuardConfig() PromptGuardConfig {
return c.PromptGuard
Expand Down
Loading
Loading