Skip to content

Commit 3d63b17

Browse files
committed
use milvus vector database for mcp-classifier-server in examples
Signed-off-by: JackLCL <[email protected]>
1 parent 8b58623 commit 3d63b17

File tree

3 files changed

+142
-61
lines changed

3 files changed

+142
-61
lines changed

examples/mcp-classifier-server/README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ This directory contains **two MCP classification servers**:
1616
### 2. **Embedding-Based Server** (`server_embedding.py`) 🆕
1717

1818
-**High Accuracy** - Semantic understanding with Qwen3-Embedding-0.6B
19-
-**RAG-Style** - FAISS vector database with similarity search
19+
-**RAG-Style** - Milvus vector database with similarity search
2020
-**Flexible** - Handles paraphrases, synonyms, variations
2121
- 📝 **Best For**: Production use, high-accuracy requirements
2222

@@ -210,7 +210,7 @@ python3 server_embedding.py --http --port 8090
210210
### Features
211211

212212
- **Qwen3-Embedding-0.6B** model with 1024-dimensional embeddings
213-
- **FAISS vector database** for fast similarity search
213+
- **Milvus vector database** for fast similarity search
214214
- **RAG-style classification** using 95 training examples
215215
- **Same MCP protocol** as regex server (drop-in replacement)
216216
- **Higher accuracy** - Understands semantic meaning, not just patterns

examples/mcp-classifier-server/requirements_embedding.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ mcp>=1.0.0
66
# Embedding and Vector Search
77
transformers>=4.30.0
88
torch>=2.0.0
9-
faiss-cpu>=1.7.4 # Use faiss-gpu if you have GPU support
9+
pymilvus>=2.5.0
1010

1111
# HTTP server support (optional, for HTTP mode)
1212
aiohttp>=3.9.0

examples/mcp-classifier-server/server_embedding.py

Lines changed: 139 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
1. Text classification using semantic embeddings (RAG-style)
77
2. Dynamic category discovery via list_categories
88
3. Intelligent routing decisions (model selection and reasoning control)
9-
4. FAISS vector database for similarity search
9+
4. Milvus vector database for similarity search
1010
1111
The server implements two MCP tools:
1212
- 'list_categories': Returns available categories with per-category system prompts and descriptions
@@ -38,6 +38,10 @@
3838
3939
# HTTP mode (for semantic router)
4040
python server_embedding.py --http --port 8090
41+
42+
Prerequisites:
43+
- pip install pymilvus (includes Milvus Lite - no Docker needed!)
44+
- pip install torch transformers mcp
4145
"""
4246

4347
import argparse
@@ -49,13 +53,14 @@
4953
from pathlib import Path
5054
from typing import Any
5155

52-
import faiss
5356
import numpy as np
5457
import torch
5558
from mcp.server import Server
5659
from mcp.server.stdio import stdio_server
5760
from mcp.types import TextContent, Tool
5861
from transformers import AutoModel, AutoTokenizer
62+
from pymilvus import MilvusClient
63+
5964

6065
# Configure logging
6166
logging.basicConfig(
@@ -114,27 +119,30 @@
114119

115120

116121
class EmbeddingClassifier:
117-
"""Embedding-based text classifier using FAISS vector search."""
122+
"""Embedding-based text classifier using Milvus vector search."""
118123

119124
def __init__(
120125
self,
121126
model_name: str = "Qwen/Qwen3-Embedding-0.6B",
122127
csv_path: str = "training_data.csv",
123-
index_path: str = "faiss_index.bin",
128+
collection_name: str = "embedding_classifier",
124129
device: str = "auto",
130+
milvus_uri: str = "./milvus_data.db",
125131
):
126132
"""
127133
Initialize the embedding classifier.
128134
129135
Args:
130136
model_name: Name of the embedding model to use
131137
csv_path: Path to the CSV training data file
132-
index_path: Path to save/load FAISS index
138+
collection_name: Name of the Milvus collection
133139
device: Device to use ("cuda", "cpu", or "auto" for auto-detection)
140+
milvus_uri: Milvus Lite database file path (default: "./milvus_data.db")
134141
"""
135142
self.model_name = model_name
136143
self.csv_path = csv_path
137-
self.index_path = index_path
144+
self.collection_name = collection_name
145+
self.milvus_uri = milvus_uri
138146

139147
logger.info(f"Initializing embedding model: {model_name}")
140148
self.tokenizer = AutoTokenizer.from_pretrained(
@@ -161,7 +169,7 @@ def __init__(
161169
# Qwen3-Embedding-0.6B has embedding dimension of 1024
162170
self.embedding_dim = 1024
163171

164-
self.index = None
172+
self.client = None
165173
self.category_names = list(CATEGORY_CONFIG.keys())
166174
self.category_to_index = {
167175
name: idx for idx, name in enumerate(self.category_names)
@@ -172,21 +180,9 @@ def __init__(
172180
self.texts, self.categories = self._load_csv_data()
173181
logger.info(f"Loaded {len(self.texts)} training examples")
174182

175-
# Load or build FAISS index
176-
if os.path.exists(index_path):
177-
logger.info(f"Loading existing FAISS index from {index_path}")
178-
self.index = faiss.read_index(index_path)
179-
logger.info(f"Index loaded with {self.index.ntotal} vectors")
180-
181-
# Verify index size matches CSV
182-
if self.index.ntotal != len(self.texts):
183-
logger.warning(
184-
f"Index size ({self.index.ntotal}) doesn't match CSV ({len(self.texts)}). Rebuilding..."
185-
)
186-
self._build_index()
187-
else:
188-
logger.info("No existing index found, building new one...")
189-
self._build_index()
183+
# Connect to Milvus Lite and initialize collection
184+
self._connect_milvus()
185+
self._init_collection()
190186

191187
def _encode_texts(self, texts: list[str], batch_size: int = 8) -> np.ndarray:
192188
"""
@@ -244,27 +240,83 @@ def _load_csv_data(self) -> tuple[list[str], list[str]]:
244240
logger.info(f"Loaded {len(texts)} training examples")
245241
return texts, categories
246242

247-
def _build_index(self):
248-
"""Build FAISS index from loaded CSV data."""
249-
logger.info("Building FAISS index from training data...")
250-
243+
def _connect_milvus(self):
244+
"""Connect to Milvus Lite."""
245+
try:
246+
logger.info(f"Connecting to Milvus Lite at {self.milvus_uri}")
247+
self.client = MilvusClient(self.milvus_uri)
248+
logger.info("Successfully connected to Milvus Lite")
249+
except Exception as e:
250+
logger.error(f"Failed to connect to Milvus Lite: {e}")
251+
raise
252+
253+
def _init_collection(self):
254+
"""Initialize or load Milvus collection."""
255+
# Check if collection exists
256+
if self.client.has_collection(self.collection_name):
257+
logger.info(f"Loading existing collection: {self.collection_name}")
258+
259+
# Check if we need to rebuild (verify count matches CSV)
260+
stats = self.client.get_collection_stats(self.collection_name)
261+
current_count = stats.get("row_count", 0)
262+
expected_count = len(self.texts)
263+
264+
if current_count != expected_count:
265+
logger.warning(
266+
f"Collection has {current_count} entities but CSV has {expected_count}. Rebuilding..."
267+
)
268+
self.client.drop_collection(self.collection_name)
269+
self._create_and_build_collection()
270+
else:
271+
logger.info(f"Collection loaded with {current_count} vectors")
272+
else:
273+
logger.info(f"Collection {self.collection_name} not found, creating new one...")
274+
self._create_and_build_collection()
275+
276+
def _create_and_build_collection(self):
277+
"""Create Milvus collection and insert embeddings."""
278+
logger.info("Creating new Milvus collection...")
279+
280+
# Create collection with schema
281+
self.client.create_collection(
282+
collection_name=self.collection_name,
283+
dimension=self.embedding_dim,
284+
metric_type="IP", # Inner Product for cosine similarity
285+
auto_id=False,
286+
primary_field_name="id",
287+
vector_field_name="embedding",
288+
)
289+
logger.info(f"Collection {self.collection_name} created")
290+
251291
# Generate embeddings
252292
logger.info(f"Generating embeddings for {len(self.texts)} examples...")
253293
embeddings = self._encode_texts(self.texts)
254-
294+
255295
# Normalize embeddings for cosine similarity
256-
faiss.normalize_L2(embeddings.astype("float32"))
257-
258-
# Build FAISS index
259-
logger.info(f"Creating FAISS index with dimension {self.embedding_dim}...")
260-
self.index = faiss.IndexFlatIP(self.embedding_dim) # Inner product for cosine
261-
self.index.add(embeddings)
262-
263-
# Save index
264-
logger.info(f"Saving index to {self.index_path}")
265-
faiss.write_index(self.index, self.index_path)
266-
267-
logger.info(f"Index built successfully with {self.index.ntotal} vectors")
296+
norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
297+
embeddings = embeddings / norms
298+
299+
# Prepare data for insertion
300+
data = []
301+
for i, (text, category, embedding) in enumerate(zip(self.texts, self.categories, embeddings)):
302+
data.append({
303+
"id": i,
304+
"embedding": embedding.tolist(),
305+
"category": category,
306+
"text": text,
307+
})
308+
309+
# Insert data in batches
310+
batch_size = 100
311+
logger.info("Inserting data into Milvus...")
312+
for i in range(0, len(data), batch_size):
313+
batch = data[i:i + batch_size]
314+
self.client.insert(
315+
collection_name=self.collection_name,
316+
data=batch
317+
)
318+
319+
logger.info(f"Collection built successfully with {len(data)} vectors")
268320

269321
def classify(
270322
self, text: str, k: int = 20, with_probabilities: bool = False
@@ -282,14 +334,27 @@ def classify(
282334
"""
283335
# Generate embedding for query text
284336
query_embedding = self._encode_texts([text])
285-
faiss.normalize_L2(query_embedding.astype("float32"))
286-
287-
# Search for k nearest neighbors
288-
similarities, indices = self.index.search(query_embedding, k)
337+
338+
# Normalize embeddings for cosine similarity
339+
norms = np.linalg.norm(query_embedding, axis=1, keepdims=True)
340+
query_embedding = query_embedding / norms
341+
342+
# Search for k nearest neighbors using Milvus
343+
results = self.client.search(
344+
collection_name=self.collection_name,
345+
data=query_embedding.tolist(),
346+
limit=k,
347+
output_fields=["category", "text"]
348+
)
289349

290-
# Get categories of nearest neighbors
291-
neighbor_categories = [self.categories[idx] for idx in indices[0]]
292-
neighbor_similarities = similarities[0]
350+
# Extract results
351+
neighbor_categories = []
352+
neighbor_similarities = []
353+
354+
for hits in results:
355+
for hit in hits:
356+
neighbor_categories.append(hit.get("entity", {}).get("category"))
357+
neighbor_similarities.append(hit.get("distance", 0))
293358

294359
# Calculate confidence scores for each category
295360
category_scores = {cat: 0.0 for cat in self.category_names}
@@ -417,6 +482,7 @@ def _decide_routing(
417482
# For multi-process deployments, each process gets its own instance.
418483
classifier = None
419484
classifier_device = "auto" # Default device setting
485+
classifier_milvus_uri = "./milvus_data.db" # Default Milvus Lite database path
420486

421487

422488
def get_classifier():
@@ -426,13 +492,14 @@ def get_classifier():
426492
# Get script directory
427493
script_dir = Path(__file__).parent
428494
csv_path = script_dir / "training_data.csv"
429-
index_path = script_dir / "faiss_index.bin"
495+
milvus_uri = script_dir / classifier_milvus_uri
430496

431497
classifier = EmbeddingClassifier(
432498
model_name="Qwen/Qwen3-Embedding-0.6B",
433499
csv_path=str(csv_path),
434-
index_path=str(index_path),
500+
collection_name="embedding_classifier",
435501
device=classifier_device,
502+
milvus_uri=str(milvus_uri),
436503
)
437504
return classifier
438505

@@ -529,26 +596,30 @@ async def call_tool(name: str, arguments: Any) -> list[TextContent]:
529596
]
530597

531598

532-
async def main_stdio(device: str = "auto"):
599+
async def main_stdio(device: str = "auto", milvus_uri: str = "./milvus_data.db"):
533600
"""Run the MCP server in stdio mode."""
534-
global classifier_device
601+
global classifier_device, classifier_milvus_uri
535602
classifier_device = device
603+
classifier_milvus_uri = milvus_uri
536604

537605
logger.info("Starting Embedding-Based MCP Classification Server (stdio mode)")
538606
clf = get_classifier()
539607
logger.info(f"Available categories: {', '.join(clf.category_names)}")
540608
logger.info(f"Model: {clf.model_name}")
541609
logger.info(f"Device: {clf.device}")
542-
logger.info(f"Index size: {clf.index.ntotal} vectors")
610+
logger.info(f"Milvus Lite: {clf.milvus_uri}")
611+
stats = clf.client.get_collection_stats(clf.collection_name)
612+
logger.info(f"Collection size: {stats.get('row_count', 0)} vectors")
543613

544614
async with stdio_server() as (read_stream, write_stream):
545615
await app.run(read_stream, write_stream, app.create_initialization_options())
546616

547617

548-
async def main_http(port: int = 8091, device: str = "auto"):
618+
async def main_http(port: int = 8091, device: str = "auto", milvus_uri: str = "./milvus_data.db"):
549619
"""Run the MCP server in HTTP mode."""
550-
global classifier_device
620+
global classifier_device, classifier_milvus_uri
551621
classifier_device = device
622+
classifier_milvus_uri = milvus_uri
552623

553624
try:
554625
from aiohttp import web
@@ -563,7 +634,9 @@ async def main_http(port: int = 8091, device: str = "auto"):
563634
logger.info(f"Available categories: {', '.join(clf.category_names)}")
564635
logger.info(f"Model: {clf.model_name}")
565636
logger.info(f"Device: {clf.device}")
566-
logger.info(f"Index size: {clf.index.ntotal} vectors")
637+
logger.info(f"Milvus Lite: {clf.milvus_uri}")
638+
stats = clf.client.get_collection_stats(clf.collection_name)
639+
logger.info(f"Collection size: {stats.get('row_count', 0)} vectors")
567640
logger.info(f"Listening on http://0.0.0.0:{port}/mcp")
568641

569642
async def handle_mcp_request(request):
@@ -672,12 +745,14 @@ async def handle_mcp_request(request):
672745
async def health_check(request):
673746
"""Health check endpoint."""
674747
clf = get_classifier()
748+
stats = clf.client.get_collection_stats(clf.collection_name)
675749
return web.json_response(
676750
{
677751
"status": "ok",
678752
"categories": clf.category_names,
679753
"model": clf.model_name,
680-
"index_size": clf.index.ntotal,
754+
"collection_size": stats.get("row_count", 0),
755+
"milvus_uri": clf.milvus_uri,
681756
}
682757
)
683758

@@ -724,7 +799,7 @@ async def health_check(request):
724799

725800
# Parse command line arguments
726801
parser = argparse.ArgumentParser(
727-
description="MCP Embedding-Based Classification Server"
802+
description="MCP Embedding-Based Classification Server (Milvus Lite)"
728803
)
729804
parser.add_argument(
730805
"--http", action="store_true", help="Run in HTTP mode instead of stdio"
@@ -737,9 +812,15 @@ async def health_check(request):
737812
choices=["auto", "cuda", "cpu"],
738813
help="Device to use for inference (auto=auto-detect, cuda=force GPU, cpu=force CPU)",
739814
)
815+
parser.add_argument(
816+
"--milvus-uri",
817+
type=str,
818+
default="./milvus_data.db",
819+
help="Milvus Lite database file path (default: ./milvus_data.db)",
820+
)
740821
args = parser.parse_args()
741822

742823
if args.http:
743-
asyncio.run(main_http(args.port, args.device))
824+
asyncio.run(main_http(args.port, args.device, args.milvus_uri))
744825
else:
745-
asyncio.run(main_stdio(args.device))
826+
asyncio.run(main_stdio(args.device, args.milvus_uri))

0 commit comments

Comments
 (0)