From c9bcbce70f912907a28c56914a0224342ab324de Mon Sep 17 00:00:00 2001 From: JackLCL Date: Wed, 15 Oct 2025 23:29:16 +0800 Subject: [PATCH] use milvus vector database for mcp-classifier-server in examples Signed-off-by: JackLCL --- examples/mcp-classifier-server/README.md | 4 +- .../requirements_embedding.txt | 2 +- .../mcp-classifier-server/server_embedding.py | 191 +++++++++++++----- 3 files changed, 141 insertions(+), 56 deletions(-) diff --git a/examples/mcp-classifier-server/README.md b/examples/mcp-classifier-server/README.md index 4fef7f18..51901f2f 100644 --- a/examples/mcp-classifier-server/README.md +++ b/examples/mcp-classifier-server/README.md @@ -16,7 +16,7 @@ This directory contains **two MCP classification servers**: ### 2. **Embedding-Based Server** (`server_embedding.py`) 🆕 - ✅ **High Accuracy** - Semantic understanding with Qwen3-Embedding-0.6B -- ✅ **RAG-Style** - FAISS vector database with similarity search +- ✅ **RAG-Style** - Milvus vector database with similarity search - ✅ **Flexible** - Handles paraphrases, synonyms, variations - 📝 **Best For**: Production use, high-accuracy requirements @@ -210,7 +210,7 @@ python3 server_embedding.py --http --port 8090 ### Features - **Qwen3-Embedding-0.6B** model with 1024-dimensional embeddings -- **FAISS vector database** for fast similarity search +- **Milvus vector database** for fast similarity search - **RAG-style classification** using 95 training examples - **Same MCP protocol** as regex server (drop-in replacement) - **Higher accuracy** - Understands semantic meaning, not just patterns diff --git a/examples/mcp-classifier-server/requirements_embedding.txt b/examples/mcp-classifier-server/requirements_embedding.txt index c401dfe8..8c286d46 100644 --- a/examples/mcp-classifier-server/requirements_embedding.txt +++ b/examples/mcp-classifier-server/requirements_embedding.txt @@ -6,7 +6,7 @@ mcp>=1.0.0 # Embedding and Vector Search transformers>=4.30.0 torch>=2.0.0 -faiss-cpu>=1.7.4 # Use faiss-gpu if you have GPU support +pymilvus>=2.5.0 # HTTP server support (optional, for HTTP mode) aiohttp>=3.9.0 diff --git a/examples/mcp-classifier-server/server_embedding.py b/examples/mcp-classifier-server/server_embedding.py index 2090a4c8..19133318 100644 --- a/examples/mcp-classifier-server/server_embedding.py +++ b/examples/mcp-classifier-server/server_embedding.py @@ -6,7 +6,7 @@ 1. Text classification using semantic embeddings (RAG-style) 2. Dynamic category discovery via list_categories 3. Intelligent routing decisions (model selection and reasoning control) -4. FAISS vector database for similarity search +4. Milvus vector database for similarity search The server implements two MCP tools: - 'list_categories': Returns available categories with per-category system prompts and descriptions @@ -38,6 +38,10 @@ # HTTP mode (for semantic router) python server_embedding.py --http --port 8090 + +Prerequisites: + - pip install pymilvus (includes Milvus Lite - no Docker needed!) 
+ - pip install torch transformers mcp """ import argparse @@ -49,12 +53,12 @@ from pathlib import Path from typing import Any -import faiss import numpy as np import torch from mcp.server import Server from mcp.server.stdio import stdio_server from mcp.types import TextContent, Tool +from pymilvus import MilvusClient from transformers import AutoModel, AutoTokenizer # Configure logging @@ -114,14 +118,15 @@ class EmbeddingClassifier: - """Embedding-based text classifier using FAISS vector search.""" + """Embedding-based text classifier using Milvus vector search.""" def __init__( self, model_name: str = "Qwen/Qwen3-Embedding-0.6B", csv_path: str = "training_data.csv", - index_path: str = "faiss_index.bin", + collection_name: str = "embedding_classifier", device: str = "auto", + milvus_uri: str = "./milvus_data.db", ): """ Initialize the embedding classifier. @@ -129,12 +134,14 @@ def __init__( Args: model_name: Name of the embedding model to use csv_path: Path to the CSV training data file - index_path: Path to save/load FAISS index + collection_name: Name of the Milvus collection device: Device to use ("cuda", "cpu", or "auto" for auto-detection) + milvus_uri: Milvus Lite database file path (default: "./milvus_data.db") """ self.model_name = model_name self.csv_path = csv_path - self.index_path = index_path + self.collection_name = collection_name + self.milvus_uri = milvus_uri logger.info(f"Initializing embedding model: {model_name}") self.tokenizer = AutoTokenizer.from_pretrained( @@ -161,7 +168,7 @@ def __init__( # Qwen3-Embedding-0.6B has embedding dimension of 1024 self.embedding_dim = 1024 - self.index = None + self.client = None self.category_names = list(CATEGORY_CONFIG.keys()) self.category_to_index = { name: idx for idx, name in enumerate(self.category_names) @@ -172,21 +179,9 @@ def __init__( self.texts, self.categories = self._load_csv_data() logger.info(f"Loaded {len(self.texts)} training examples") - # Load or build FAISS index - if os.path.exists(index_path): - logger.info(f"Loading existing FAISS index from {index_path}") - self.index = faiss.read_index(index_path) - logger.info(f"Index loaded with {self.index.ntotal} vectors") - - # Verify index size matches CSV - if self.index.ntotal != len(self.texts): - logger.warning( - f"Index size ({self.index.ntotal}) doesn't match CSV ({len(self.texts)}). Rebuilding..." 
- ) - self._build_index() - else: - logger.info("No existing index found, building new one...") - self._build_index() + # Connect to Milvus Lite and initialize collection + self._connect_milvus() + self._init_collection() def _encode_texts(self, texts: list[str], batch_size: int = 8) -> np.ndarray: """ @@ -244,27 +239,86 @@ def _load_csv_data(self) -> tuple[list[str], list[str]]: logger.info(f"Loaded {len(texts)} training examples") return texts, categories - def _build_index(self): - """Build FAISS index from loaded CSV data.""" - logger.info("Building FAISS index from training data...") + def _connect_milvus(self): + """Connect to Milvus Lite.""" + try: + logger.info(f"Connecting to Milvus Lite at {self.milvus_uri}") + self.client = MilvusClient(self.milvus_uri) + logger.info("Successfully connected to Milvus Lite") + except Exception as e: + logger.error(f"Failed to connect to Milvus Lite: {e}") + raise + + def _init_collection(self): + """Initialize or load Milvus collection.""" + # Check if collection exists + if self.client.has_collection(self.collection_name): + logger.info(f"Loading existing collection: {self.collection_name}") + + # Check if we need to rebuild (verify count matches CSV) + stats = self.client.get_collection_stats(self.collection_name) + current_count = stats.get("row_count", 0) + expected_count = len(self.texts) + + if current_count != expected_count: + logger.warning( + f"Collection has {current_count} entities but CSV has {expected_count}. Rebuilding..." + ) + self.client.drop_collection(self.collection_name) + self._create_and_build_collection() + else: + logger.info(f"Collection loaded with {current_count} vectors") + else: + logger.info( + f"Collection {self.collection_name} not found, creating new one..." + ) + self._create_and_build_collection() + + def _create_and_build_collection(self): + """Create Milvus collection and insert embeddings.""" + logger.info("Creating new Milvus collection...") + + # Create collection with schema + self.client.create_collection( + collection_name=self.collection_name, + dimension=self.embedding_dim, + metric_type="IP", # Inner Product for cosine similarity + auto_id=False, + primary_field_name="id", + vector_field_name="embedding", + ) + logger.info(f"Collection {self.collection_name} created") # Generate embeddings logger.info(f"Generating embeddings for {len(self.texts)} examples...") embeddings = self._encode_texts(self.texts) # Normalize embeddings for cosine similarity - faiss.normalize_L2(embeddings.astype("float32")) - - # Build FAISS index - logger.info(f"Creating FAISS index with dimension {self.embedding_dim}...") - self.index = faiss.IndexFlatIP(self.embedding_dim) # Inner product for cosine - self.index.add(embeddings) + norms = np.linalg.norm(embeddings, axis=1, keepdims=True) + embeddings = embeddings / norms + + # Prepare data for insertion + data = [] + for i, (text, category, embedding) in enumerate( + zip(self.texts, self.categories, embeddings) + ): + data.append( + { + "id": i, + "embedding": embedding.tolist(), + "category": category, + "text": text, + } + ) - # Save index - logger.info(f"Saving index to {self.index_path}") - faiss.write_index(self.index, self.index_path) + # Insert data in batches + batch_size = 100 + logger.info("Inserting data into Milvus...") + for i in range(0, len(data), batch_size): + batch = data[i : i + batch_size] + self.client.insert(collection_name=self.collection_name, data=batch) - logger.info(f"Index built successfully with {self.index.ntotal} vectors") + 
logger.info(f"Collection built successfully with {len(data)} vectors") def classify( self, text: str, k: int = 20, with_probabilities: bool = False @@ -282,14 +336,27 @@ def classify( """ # Generate embedding for query text query_embedding = self._encode_texts([text]) - faiss.normalize_L2(query_embedding.astype("float32")) - # Search for k nearest neighbors - similarities, indices = self.index.search(query_embedding, k) + # Normalize embeddings for cosine similarity + norms = np.linalg.norm(query_embedding, axis=1, keepdims=True) + query_embedding = query_embedding / norms + + # Search for k nearest neighbors using Milvus + results = self.client.search( + collection_name=self.collection_name, + data=query_embedding.tolist(), + limit=k, + output_fields=["category", "text"], + ) + + # Extract results + neighbor_categories = [] + neighbor_similarities = [] - # Get categories of nearest neighbors - neighbor_categories = [self.categories[idx] for idx in indices[0]] - neighbor_similarities = similarities[0] + for hits in results: + for hit in hits: + neighbor_categories.append(hit.get("entity", {}).get("category")) + neighbor_similarities.append(hit.get("distance", 0)) # Calculate confidence scores for each category category_scores = {cat: 0.0 for cat in self.category_names} @@ -417,6 +484,7 @@ def _decide_routing( # For multi-process deployments, each process gets its own instance. classifier = None classifier_device = "auto" # Default device setting +classifier_milvus_uri = "./milvus_data.db" # Default Milvus Lite database path def get_classifier(): @@ -426,13 +494,14 @@ def get_classifier(): # Get script directory script_dir = Path(__file__).parent csv_path = script_dir / "training_data.csv" - index_path = script_dir / "faiss_index.bin" + milvus_uri = script_dir / classifier_milvus_uri classifier = EmbeddingClassifier( model_name="Qwen/Qwen3-Embedding-0.6B", csv_path=str(csv_path), - index_path=str(index_path), + collection_name="embedding_classifier", device=classifier_device, + milvus_uri=str(milvus_uri), ) return classifier @@ -529,26 +598,32 @@ async def call_tool(name: str, arguments: Any) -> list[TextContent]: ] -async def main_stdio(device: str = "auto"): +async def main_stdio(device: str = "auto", milvus_uri: str = "./milvus_data.db"): """Run the MCP server in stdio mode.""" - global classifier_device + global classifier_device, classifier_milvus_uri classifier_device = device + classifier_milvus_uri = milvus_uri logger.info("Starting Embedding-Based MCP Classification Server (stdio mode)") clf = get_classifier() logger.info(f"Available categories: {', '.join(clf.category_names)}") logger.info(f"Model: {clf.model_name}") logger.info(f"Device: {clf.device}") - logger.info(f"Index size: {clf.index.ntotal} vectors") + logger.info(f"Milvus Lite: {clf.milvus_uri}") + stats = clf.client.get_collection_stats(clf.collection_name) + logger.info(f"Collection size: {stats.get('row_count', 0)} vectors") async with stdio_server() as (read_stream, write_stream): await app.run(read_stream, write_stream, app.create_initialization_options()) -async def main_http(port: int = 8091, device: str = "auto"): +async def main_http( + port: int = 8091, device: str = "auto", milvus_uri: str = "./milvus_data.db" +): """Run the MCP server in HTTP mode.""" - global classifier_device + global classifier_device, classifier_milvus_uri classifier_device = device + classifier_milvus_uri = milvus_uri try: from aiohttp import web @@ -563,7 +638,9 @@ async def main_http(port: int = 8091, device: str = "auto"): 
logger.info(f"Available categories: {', '.join(clf.category_names)}") logger.info(f"Model: {clf.model_name}") logger.info(f"Device: {clf.device}") - logger.info(f"Index size: {clf.index.ntotal} vectors") + logger.info(f"Milvus Lite: {clf.milvus_uri}") + stats = clf.client.get_collection_stats(clf.collection_name) + logger.info(f"Collection size: {stats.get('row_count', 0)} vectors") logger.info(f"Listening on http://0.0.0.0:{port}/mcp") async def handle_mcp_request(request): @@ -672,12 +749,14 @@ async def handle_mcp_request(request): async def health_check(request): """Health check endpoint.""" clf = get_classifier() + stats = clf.client.get_collection_stats(clf.collection_name) return web.json_response( { "status": "ok", "categories": clf.category_names, "model": clf.model_name, - "index_size": clf.index.ntotal, + "collection_size": stats.get("row_count", 0), + "milvus_uri": clf.milvus_uri, } ) @@ -724,7 +803,7 @@ async def health_check(request): # Parse command line arguments parser = argparse.ArgumentParser( - description="MCP Embedding-Based Classification Server" + description="MCP Embedding-Based Classification Server (Milvus Lite)" ) parser.add_argument( "--http", action="store_true", help="Run in HTTP mode instead of stdio" @@ -737,9 +816,15 @@ async def health_check(request): choices=["auto", "cuda", "cpu"], help="Device to use for inference (auto=auto-detect, cuda=force GPU, cpu=force CPU)", ) + parser.add_argument( + "--milvus-uri", + type=str, + default="./milvus_data.db", + help="Milvus Lite database file path (default: ./milvus_data.db)", + ) args = parser.parse_args() if args.http: - asyncio.run(main_http(args.port, args.device)) + asyncio.run(main_http(args.port, args.device, args.milvus_uri)) else: - asyncio.run(main_stdio(args.device)) + asyncio.run(main_stdio(args.device, args.milvus_uri))