4 changes: 2 additions & 2 deletions examples/mcp-classifier-server/README.md
@@ -16,7 +16,7 @@ This directory contains **two MCP classification servers**:
### 2. **Embedding-Based Server** (`server_embedding.py`) 🆕

- ✅ **High Accuracy** - Semantic understanding with Qwen3-Embedding-0.6B
- ✅ **RAG-Style** - FAISS vector database with similarity search
- ✅ **RAG-Style** - Milvus vector database with similarity search
- ✅ **Flexible** - Handles paraphrases, synonyms, variations
- 📝 **Best For**: Production use, high-accuracy requirements
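For orientation, the "RAG-Style" bullet above boils down to a handful of `pymilvus` calls. The sketch below is illustrative and not part of this PR: the 4-dimensional vectors are stand-ins for the server's 1024-dimensional Qwen3 embeddings, and it assumes `pymilvus>=2.5.0` with bundled Milvus Lite.

```python
# Minimal Milvus Lite round trip (illustrative only; vectors are placeholders).
from pymilvus import MilvusClient

client = MilvusClient("./demo.db")  # Milvus Lite: a local file, no Docker
client.create_collection(
    collection_name="demo",
    dimension=4,
    metric_type="IP",  # inner product; equals cosine on unit vectors
    auto_id=False,
    primary_field_name="id",
    vector_field_name="embedding",
)
client.insert(
    collection_name="demo",
    data=[
        {"id": 0, "embedding": [1.0, 0.0, 0.0, 0.0], "category": "math", "text": "2+2"},
        {"id": 1, "embedding": [0.0, 1.0, 0.0, 0.0], "category": "code", "text": "def f(): ..."},
    ],
)
hits = client.search(
    collection_name="demo",
    data=[[0.9, 0.1, 0.0, 0.0]],  # a (roughly unit-length) query vector
    limit=1,
    output_fields=["category", "text"],
)
print(hits[0][0]["entity"]["category"])  # -> "math"
```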

@@ -210,7 +210,7 @@ python3 server_embedding.py --http --port 8090
### Features

- **Qwen3-Embedding-0.6B** model with 1024-dimensional embeddings
- **FAISS vector database** for fast similarity search
- **Milvus vector database** for fast similarity search
- **RAG-style classification** using 95 training examples
- **Same MCP protocol** as regex server (drop-in replacement)
- **Higher accuracy** - Understands semantic meaning, not just patterns
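The "RAG-style classification" bullet works by k-nearest-neighbor voting: retrieve the k most similar training examples, then accumulate each neighbor's similarity into its category's score. A compact sketch of that voting step follows; it approximates the `classify()` method in `server_embedding.py` (the exact confidence calculation is collapsed in this diff, so treat the details as assumptions).

```python
# Similarity-weighted k-NN voting (illustrative approximation of classify()).
def vote(categories: list[str], similarities: list[float]) -> tuple[str, float]:
    scores: dict[str, float] = {}
    for category, similarity in zip(categories, similarities):
        scores[category] = scores.get(category, 0.0) + similarity
    best = max(scores, key=scores.get)
    total = sum(scores.values()) or 1.0
    return best, scores[best] / total  # (category, normalized confidence)
```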
examples/mcp-classifier-server/requirements.txt
@@ -6,7 +6,7 @@ mcp>=1.0.0
# Embedding and Vector Search
transformers>=4.30.0
torch>=2.0.0
faiss-cpu>=1.7.4 # Use faiss-gpu if you have GPU support
pymilvus>=2.5.0

# HTTP server support (optional, for HTTP mode)
aiohttp>=3.9.0
191 changes: 138 additions & 53 deletions examples/mcp-classifier-server/server_embedding.py
@@ -6,7 +6,7 @@
1. Text classification using semantic embeddings (RAG-style)
2. Dynamic category discovery via list_categories
3. Intelligent routing decisions (model selection and reasoning control)
4. FAISS vector database for similarity search
4. Milvus vector database for similarity search

The server implements two MCP tools:
- 'list_categories': Returns available categories with per-category system prompts and descriptions
@@ -38,6 +38,10 @@

# HTTP mode (for semantic router)
python server_embedding.py --http --port 8090

Prerequisites:
- pip install pymilvus (includes Milvus Lite - no Docker needed!)
- pip install torch transformers mcp
"""

import argparse
@@ -49,12 +53,12 @@
from pathlib import Path
from typing import Any

import faiss
import numpy as np
import torch
from mcp.server import Server
from mcp.server.stdio import stdio_server
from mcp.types import TextContent, Tool
from pymilvus import MilvusClient
from transformers import AutoModel, AutoTokenizer

# Configure logging
@@ -114,27 +118,30 @@


class EmbeddingClassifier:
"""Embedding-based text classifier using FAISS vector search."""
"""Embedding-based text classifier using Milvus vector search."""

def __init__(
self,
model_name: str = "Qwen/Qwen3-Embedding-0.6B",
csv_path: str = "training_data.csv",
index_path: str = "faiss_index.bin",
collection_name: str = "embedding_classifier",
device: str = "auto",
milvus_uri: str = "./milvus_data.db",
):
"""
Initialize the embedding classifier.

Args:
model_name: Name of the embedding model to use
csv_path: Path to the CSV training data file
index_path: Path to save/load FAISS index
collection_name: Name of the Milvus collection
device: Device to use ("cuda", "cpu", or "auto" for auto-detection)
milvus_uri: Milvus Lite database file path (default: "./milvus_data.db")
"""
self.model_name = model_name
self.csv_path = csv_path
self.index_path = index_path
self.collection_name = collection_name
self.milvus_uri = milvus_uri

logger.info(f"Initializing embedding model: {model_name}")
self.tokenizer = AutoTokenizer.from_pretrained(
@@ -161,7 +168,7 @@ def __init__(
# Qwen3-Embedding-0.6B has embedding dimension of 1024
self.embedding_dim = 1024

self.index = None
self.client = None
self.category_names = list(CATEGORY_CONFIG.keys())
self.category_to_index = {
name: idx for idx, name in enumerate(self.category_names)
@@ -172,21 +179,9 @@ def __init__(
self.texts, self.categories = self._load_csv_data()
logger.info(f"Loaded {len(self.texts)} training examples")

# Load or build FAISS index
if os.path.exists(index_path):
logger.info(f"Loading existing FAISS index from {index_path}")
self.index = faiss.read_index(index_path)
logger.info(f"Index loaded with {self.index.ntotal} vectors")

# Verify index size matches CSV
if self.index.ntotal != len(self.texts):
logger.warning(
f"Index size ({self.index.ntotal}) doesn't match CSV ({len(self.texts)}). Rebuilding..."
)
self._build_index()
else:
logger.info("No existing index found, building new one...")
self._build_index()
# Connect to Milvus Lite and initialize collection
self._connect_milvus()
self._init_collection()

def _encode_texts(self, texts: list[str], batch_size: int = 8) -> np.ndarray:
"""
@@ -244,27 +239,86 @@ def _load_csv_data(self) -> tuple[list[str], list[str]]:
logger.info(f"Loaded {len(texts)} training examples")
return texts, categories

def _build_index(self):
"""Build FAISS index from loaded CSV data."""
logger.info("Building FAISS index from training data...")
def _connect_milvus(self):
"""Connect to Milvus Lite."""
try:
logger.info(f"Connecting to Milvus Lite at {self.milvus_uri}")
self.client = MilvusClient(self.milvus_uri)
logger.info("Successfully connected to Milvus Lite")
except Exception as e:
logger.error(f"Failed to connect to Milvus Lite: {e}")
raise

def _init_collection(self):
"""Initialize or load Milvus collection."""
# Check if collection exists
if self.client.has_collection(self.collection_name):
logger.info(f"Loading existing collection: {self.collection_name}")

# Check if we need to rebuild (verify count matches CSV)
stats = self.client.get_collection_stats(self.collection_name)
current_count = stats.get("row_count", 0)
expected_count = len(self.texts)

if current_count != expected_count:
logger.warning(
f"Collection has {current_count} entities but CSV has {expected_count}. Rebuilding..."
)
self.client.drop_collection(self.collection_name)
self._create_and_build_collection()
else:
logger.info(f"Collection loaded with {current_count} vectors")
else:
logger.info(
f"Collection {self.collection_name} not found, creating new one..."
)
self._create_and_build_collection()

def _create_and_build_collection(self):
"""Create Milvus collection and insert embeddings."""
logger.info("Creating new Milvus collection...")

# Create collection with schema
self.client.create_collection(
collection_name=self.collection_name,
dimension=self.embedding_dim,
metric_type="IP", # Inner Product for cosine similarity
auto_id=False,
primary_field_name="id",
vector_field_name="embedding",
)
logger.info(f"Collection {self.collection_name} created")

# Generate embeddings
logger.info(f"Generating embeddings for {len(self.texts)} examples...")
embeddings = self._encode_texts(self.texts)

# Normalize embeddings for cosine similarity
faiss.normalize_L2(embeddings.astype("float32"))

# Build FAISS index
logger.info(f"Creating FAISS index with dimension {self.embedding_dim}...")
self.index = faiss.IndexFlatIP(self.embedding_dim) # Inner product for cosine
self.index.add(embeddings)
norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
embeddings = embeddings / norms
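# (With unit-length vectors, the "IP" metric configured above is
# mathematically equivalent to cosine similarity.)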

# Prepare data for insertion
data = []
for i, (text, category, embedding) in enumerate(
zip(self.texts, self.categories, embeddings)
):
data.append(
{
"id": i,
"embedding": embedding.tolist(),
"category": category,
"text": text,
}
)

# Save index
logger.info(f"Saving index to {self.index_path}")
faiss.write_index(self.index, self.index_path)
# Insert data in batches
batch_size = 100
logger.info("Inserting data into Milvus...")
for i in range(0, len(data), batch_size):
batch = data[i : i + batch_size]
self.client.insert(collection_name=self.collection_name, data=batch)

logger.info(f"Index built successfully with {self.index.ntotal} vectors")
logger.info(f"Collection built successfully with {len(data)} vectors")

def classify(
self, text: str, k: int = 20, with_probabilities: bool = False
@@ -282,14 +336,27 @@ def classify(
"""
# Generate embedding for query text
query_embedding = self._encode_texts([text])
faiss.normalize_L2(query_embedding.astype("float32"))

# Search for k nearest neighbors
similarities, indices = self.index.search(query_embedding, k)
# Normalize embeddings for cosine similarity
norms = np.linalg.norm(query_embedding, axis=1, keepdims=True)
query_embedding = query_embedding / norms

# Search for k nearest neighbors using Milvus
results = self.client.search(
collection_name=self.collection_name,
data=query_embedding.tolist(),
limit=k,
output_fields=["category", "text"],
)

# Extract results
neighbor_categories = []
neighbor_similarities = []

# Get categories of nearest neighbors
neighbor_categories = [self.categories[idx] for idx in indices[0]]
neighbor_similarities = similarities[0]
for hits in results:
for hit in hits:
neighbor_categories.append(hit.get("entity", {}).get("category"))
neighbor_similarities.append(hit.get("distance", 0))
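# Milvus reports the raw metric value as "distance"; with metric_type="IP",
# larger means more similar, so it can be used directly as a similarity score.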

# Calculate confidence scores for each category
category_scores = {cat: 0.0 for cat in self.category_names}
@@ -417,6 +484,7 @@ def _decide_routing(
# For multi-process deployments, each process gets its own instance.
classifier = None
classifier_device = "auto" # Default device setting
classifier_milvus_uri = "./milvus_data.db" # Default Milvus Lite database path


def get_classifier():
@@ -426,13 +494,14 @@ def get_classifier():
# Get script directory
script_dir = Path(__file__).parent
csv_path = script_dir / "training_data.csv"
index_path = script_dir / "faiss_index.bin"
milvus_uri = script_dir / classifier_milvus_uri

classifier = EmbeddingClassifier(
model_name="Qwen/Qwen3-Embedding-0.6B",
csv_path=str(csv_path),
index_path=str(index_path),
collection_name="embedding_classifier",
device=classifier_device,
milvus_uri=str(milvus_uri),
)
return classifier

@@ -529,26 +598,32 @@ async def call_tool(name: str, arguments: Any) -> list[TextContent]:
]


async def main_stdio(device: str = "auto"):
async def main_stdio(device: str = "auto", milvus_uri: str = "./milvus_data.db"):
"""Run the MCP server in stdio mode."""
global classifier_device
global classifier_device, classifier_milvus_uri
classifier_device = device
classifier_milvus_uri = milvus_uri

logger.info("Starting Embedding-Based MCP Classification Server (stdio mode)")
clf = get_classifier()
logger.info(f"Available categories: {', '.join(clf.category_names)}")
logger.info(f"Model: {clf.model_name}")
logger.info(f"Device: {clf.device}")
logger.info(f"Index size: {clf.index.ntotal} vectors")
logger.info(f"Milvus Lite: {clf.milvus_uri}")
stats = clf.client.get_collection_stats(clf.collection_name)
logger.info(f"Collection size: {stats.get('row_count', 0)} vectors")

async with stdio_server() as (read_stream, write_stream):
await app.run(read_stream, write_stream, app.create_initialization_options())


async def main_http(port: int = 8091, device: str = "auto"):
async def main_http(
port: int = 8091, device: str = "auto", milvus_uri: str = "./milvus_data.db"
):
"""Run the MCP server in HTTP mode."""
global classifier_device
global classifier_device, classifier_milvus_uri
classifier_device = device
classifier_milvus_uri = milvus_uri

try:
from aiohttp import web
@@ -563,7 +638,9 @@ async def main_http(port: int = 8091, device: str = "auto"):
logger.info(f"Available categories: {', '.join(clf.category_names)}")
logger.info(f"Model: {clf.model_name}")
logger.info(f"Device: {clf.device}")
logger.info(f"Index size: {clf.index.ntotal} vectors")
logger.info(f"Milvus Lite: {clf.milvus_uri}")
stats = clf.client.get_collection_stats(clf.collection_name)
logger.info(f"Collection size: {stats.get('row_count', 0)} vectors")
logger.info(f"Listening on http://0.0.0.0:{port}/mcp")

async def handle_mcp_request(request):
@@ -672,12 +749,14 @@ async def handle_mcp_request(request):
async def health_check(request):
"""Health check endpoint."""
clf = get_classifier()
stats = clf.client.get_collection_stats(clf.collection_name)
return web.json_response(
{
"status": "ok",
"categories": clf.category_names,
"model": clf.model_name,
"index_size": clf.index.ntotal,
"collection_size": stats.get("row_count", 0),
"milvus_uri": clf.milvus_uri,
}
)
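A quick way to smoke-test HTTP mode is to hit the health endpoint above. The sketch below assumes the route is mounted at `/health` — the `aiohttp` route registration is collapsed in this diff, so the path is a guess — and uses the README's example port 8090.

```python
# Hypothetical smoke test; the /health path is an assumption (route setup
# is collapsed in this diff).
import json
import urllib.request

with urllib.request.urlopen("http://localhost:8090/health") as resp:
    info = json.load(resp)

print(info["status"], info["collection_size"], info["milvus_uri"])
```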

@@ -724,7 +803,7 @@ async def health_check(request):

# Parse command line arguments
parser = argparse.ArgumentParser(
description="MCP Embedding-Based Classification Server"
description="MCP Embedding-Based Classification Server (Milvus Lite)"
)
parser.add_argument(
"--http", action="store_true", help="Run in HTTP mode instead of stdio"
Expand All @@ -737,9 +816,15 @@ async def health_check(request):
choices=["auto", "cuda", "cpu"],
help="Device to use for inference (auto=auto-detect, cuda=force GPU, cpu=force CPU)",
)
parser.add_argument(
"--milvus-uri",
type=str,
default="./milvus_data.db",
help="Milvus Lite database file path (default: ./milvus_data.db)",
)
args = parser.parse_args()

if args.http:
asyncio.run(main_http(args.port, args.device))
asyncio.run(main_http(args.port, args.device, args.milvus_uri))
else:
asyncio.run(main_stdio(args.device))
asyncio.run(main_stdio(args.device, args.milvus_uri))
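Taken together, the new flag keeps the database location configurable in both modes: for example, `python3 server_embedding.py --http --port 8090 --milvus-uri ./milvus_data.db` serves over HTTP while persisting vectors to the given Milvus Lite file (all flags as defined above).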