Skip to content

Commit c9bcbce

Browse files
committed
use milvus vector database for mcp-classifier-server in examples
Signed-off-by: JackLCL <[email protected]>
1 parent 8b58623 commit c9bcbce

File tree

3 files changed

+141
-56
lines changed

3 files changed

+141
-56
lines changed

examples/mcp-classifier-server/README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ This directory contains **two MCP classification servers**:
1616
### 2. **Embedding-Based Server** (`server_embedding.py`) 🆕
1717

1818
-**High Accuracy** - Semantic understanding with Qwen3-Embedding-0.6B
19-
-**RAG-Style** - FAISS vector database with similarity search
19+
-**RAG-Style** - Milvus vector database with similarity search
2020
-**Flexible** - Handles paraphrases, synonyms, variations
2121
- 📝 **Best For**: Production use, high-accuracy requirements
2222

@@ -210,7 +210,7 @@ python3 server_embedding.py --http --port 8090
210210
### Features
211211

212212
- **Qwen3-Embedding-0.6B** model with 1024-dimensional embeddings
213-
- **FAISS vector database** for fast similarity search
213+
- **Milvus vector database** for fast similarity search
214214
- **RAG-style classification** using 95 training examples
215215
- **Same MCP protocol** as regex server (drop-in replacement)
216216
- **Higher accuracy** - Understands semantic meaning, not just patterns

examples/mcp-classifier-server/requirements_embedding.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ mcp>=1.0.0
66
# Embedding and Vector Search
77
transformers>=4.30.0
88
torch>=2.0.0
9-
faiss-cpu>=1.7.4 # Use faiss-gpu if you have GPU support
9+
pymilvus>=2.5.0
1010

1111
# HTTP server support (optional, for HTTP mode)
1212
aiohttp>=3.9.0

examples/mcp-classifier-server/server_embedding.py

Lines changed: 138 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
1. Text classification using semantic embeddings (RAG-style)
77
2. Dynamic category discovery via list_categories
88
3. Intelligent routing decisions (model selection and reasoning control)
9-
4. FAISS vector database for similarity search
9+
4. Milvus vector database for similarity search
1010
1111
The server implements two MCP tools:
1212
- 'list_categories': Returns available categories with per-category system prompts and descriptions
@@ -38,6 +38,10 @@
3838
3939
# HTTP mode (for semantic router)
4040
python server_embedding.py --http --port 8090
41+
42+
Prerequisites:
43+
- pip install "pymilvus>=2.5.0" (bundles Milvus Lite, so no Docker or standalone Milvus server is needed; note that Milvus Lite runs on Linux and macOS only)
44+
- pip install torch transformers mcp numpy (plus aiohttp for --http mode), or simply: pip install -r requirements_embedding.txt
4145
"""
4246

4347
import argparse
@@ -49,12 +53,12 @@
4953
from pathlib import Path
5054
from typing import Any
5155

52-
import faiss
5356
import numpy as np
5457
import torch
5558
from mcp.server import Server
5659
from mcp.server.stdio import stdio_server
5760
from mcp.types import TextContent, Tool
61+
from pymilvus import MilvusClient
5862
from transformers import AutoModel, AutoTokenizer
5963

6064
# Configure logging
@@ -114,27 +118,30 @@
114118

115119

116120
class EmbeddingClassifier:
117-
"""Embedding-based text classifier using FAISS vector search."""
121+
"""Embedding-based text classifier using Milvus vector search."""
118122

119123
def __init__(
120124
self,
121125
model_name: str = "Qwen/Qwen3-Embedding-0.6B",
122126
csv_path: str = "training_data.csv",
123-
index_path: str = "faiss_index.bin",
127+
collection_name: str = "embedding_classifier",
124128
device: str = "auto",
129+
milvus_uri: str = "./milvus_data.db",
125130
):
126131
"""
127132
Initialize the embedding classifier.
128133
129134
Args:
130135
model_name: Name of the embedding model to use
131136
csv_path: Path to the CSV training data file
132-
index_path: Path to save/load FAISS index
137+
collection_name: Name of the Milvus collection
133138
device: Device to use ("cuda", "cpu", or "auto" for auto-detection)
139+
milvus_uri: Milvus Lite database file path (default: "./milvus_data.db")
134140
"""
135141
self.model_name = model_name
136142
self.csv_path = csv_path
137-
self.index_path = index_path
143+
self.collection_name = collection_name
144+
self.milvus_uri = milvus_uri
138145

139146
logger.info(f"Initializing embedding model: {model_name}")
140147
self.tokenizer = AutoTokenizer.from_pretrained(
@@ -161,7 +168,7 @@ def __init__(
161168
# Qwen3-Embedding-0.6B has embedding dimension of 1024
162169
self.embedding_dim = 1024
163170

164-
self.index = None
171+
self.client = None
165172
self.category_names = list(CATEGORY_CONFIG.keys())
166173
self.category_to_index = {
167174
name: idx for idx, name in enumerate(self.category_names)
@@ -172,21 +179,9 @@ def __init__(
172179
self.texts, self.categories = self._load_csv_data()
173180
logger.info(f"Loaded {len(self.texts)} training examples")
174181

175-
# Load or build FAISS index
176-
if os.path.exists(index_path):
177-
logger.info(f"Loading existing FAISS index from {index_path}")
178-
self.index = faiss.read_index(index_path)
179-
logger.info(f"Index loaded with {self.index.ntotal} vectors")
180-
181-
# Verify index size matches CSV
182-
if self.index.ntotal != len(self.texts):
183-
logger.warning(
184-
f"Index size ({self.index.ntotal}) doesn't match CSV ({len(self.texts)}). Rebuilding..."
185-
)
186-
self._build_index()
187-
else:
188-
logger.info("No existing index found, building new one...")
189-
self._build_index()
182+
# Connect to Milvus Lite and initialize collection
183+
self._connect_milvus()
184+
self._init_collection()
190185

191186
def _encode_texts(self, texts: list[str], batch_size: int = 8) -> np.ndarray:
192187
"""
@@ -244,27 +239,86 @@ def _load_csv_data(self) -> tuple[list[str], list[str]]:
244239
logger.info(f"Loaded {len(texts)} training examples")
245240
return texts, categories
246241

247-
def _build_index(self):
248-
"""Build FAISS index from loaded CSV data."""
249-
logger.info("Building FAISS index from training data...")
242+
def _connect_milvus(self):
243+
"""Connect to Milvus Lite."""
244+
try:
245+
logger.info(f"Connecting to Milvus Lite at {self.milvus_uri}")
246+
self.client = MilvusClient(self.milvus_uri)
247+
logger.info("Successfully connected to Milvus Lite")
248+
except Exception as e:
249+
logger.error(f"Failed to connect to Milvus Lite: {e}")
250+
raise
251+
252+
def _init_collection(self):
253+
"""Initialize or load Milvus collection."""
254+
# Check if collection exists
255+
if self.client.has_collection(self.collection_name):
256+
logger.info(f"Loading existing collection: {self.collection_name}")
257+
258+
# Check if we need to rebuild (verify count matches CSV)
259+
stats = self.client.get_collection_stats(self.collection_name)
260+
current_count = stats.get("row_count", 0)
261+
expected_count = len(self.texts)
262+
263+
if current_count != expected_count:
264+
logger.warning(
265+
f"Collection has {current_count} entities but CSV has {expected_count}. Rebuilding..."
266+
)
267+
self.client.drop_collection(self.collection_name)
268+
self._create_and_build_collection()
269+
else:
270+
logger.info(f"Collection loaded with {current_count} vectors")
271+
else:
272+
logger.info(
273+
f"Collection {self.collection_name} not found, creating new one..."
274+
)
275+
self._create_and_build_collection()
276+
277+
def _create_and_build_collection(self):
278+
"""Create Milvus collection and insert embeddings."""
279+
logger.info("Creating new Milvus collection...")
280+
281+
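# Create collection via quick setup (no explicit schema): an "id" primary key plus an "embedding" vector field; "category" and "text" are not declared and rely on the dynamic field (enabled by default in quick setup)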
282+
self.client.create_collection(
283+
collection_name=self.collection_name,
284+
dimension=self.embedding_dim,
285+
metric_type="IP", # Inner Product for cosine similarity
286+
auto_id=False,
287+
primary_field_name="id",
288+
vector_field_name="embedding",
289+
)
290+
logger.info(f"Collection {self.collection_name} created")
250291

251292
# Generate embeddings
252293
logger.info(f"Generating embeddings for {len(self.texts)} examples...")
253294
embeddings = self._encode_texts(self.texts)
254295

255296
# Normalize embeddings for cosine similarity
256-
faiss.normalize_L2(embeddings.astype("float32"))
257-
258-
# Build FAISS index
259-
logger.info(f"Creating FAISS index with dimension {self.embedding_dim}...")
260-
self.index = faiss.IndexFlatIP(self.embedding_dim) # Inner product for cosine
261-
self.index.add(embeddings)
297+
norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
298+
embeddings = embeddings / norms
299+
300+
# Prepare data for insertion
301+
data = []
302+
for i, (text, category, embedding) in enumerate(
303+
zip(self.texts, self.categories, embeddings)
304+
):
305+
data.append(
306+
{
307+
"id": i,
308+
"embedding": embedding.tolist(),
309+
"category": category,
310+
"text": text,
311+
}
312+
)
262313

263-
# Save index
264-
logger.info(f"Saving index to {self.index_path}")
265-
faiss.write_index(self.index, self.index_path)
314+
# Insert data in batches
315+
batch_size = 100
316+
logger.info("Inserting data into Milvus...")
317+
for i in range(0, len(data), batch_size):
318+
batch = data[i : i + batch_size]
319+
self.client.insert(collection_name=self.collection_name, data=batch)
266320

267-
logger.info(f"Index built successfully with {self.index.ntotal} vectors")
321+
logger.info(f"Collection built successfully with {len(data)} vectors")
268322

269323
def classify(
270324
self, text: str, k: int = 20, with_probabilities: bool = False
@@ -282,14 +336,27 @@ def classify(
282336
"""
283337
# Generate embedding for query text
284338
query_embedding = self._encode_texts([text])
285-
faiss.normalize_L2(query_embedding.astype("float32"))
286339

287-
# Search for k nearest neighbors
288-
similarities, indices = self.index.search(query_embedding, k)
340+
# L2-normalize the query embedding so that the inner-product (IP) score equals cosine similarity
341+
norms = np.linalg.norm(query_embedding, axis=1, keepdims=True)
342+
query_embedding = query_embedding / norms
343+
344+
# Search for k nearest neighbors using Milvus
345+
results = self.client.search(
346+
collection_name=self.collection_name,
347+
data=query_embedding.tolist(),
348+
limit=k,
349+
output_fields=["category", "text"],
350+
)
351+
352+
# Extract results
353+
neighbor_categories = []
354+
neighbor_similarities = []
289355

290-
# Get categories of nearest neighbors
291-
neighbor_categories = [self.categories[idx] for idx in indices[0]]
292-
neighbor_similarities = similarities[0]
356+
for hits in results:
357+
for hit in hits:
358+
neighbor_categories.append(hit.get("entity", {}).get("category"))
359+
neighbor_similarities.append(hit.get("distance", 0))
293360

294361
# Calculate confidence scores for each category
295362
category_scores = {cat: 0.0 for cat in self.category_names}
@@ -417,6 +484,7 @@ def _decide_routing(
417484
# For multi-process deployments, each process gets its own instance.
418485
classifier = None
419486
classifier_device = "auto" # Default device setting
487+
classifier_milvus_uri = "./milvus_data.db" # Default Milvus Lite database path
420488

421489

422490
def get_classifier():
@@ -426,13 +494,14 @@ def get_classifier():
426494
# Get script directory
427495
script_dir = Path(__file__).parent
428496
csv_path = script_dir / "training_data.csv"
429-
index_path = script_dir / "faiss_index.bin"
497+
milvus_uri = script_dir / classifier_milvus_uri
430498

431499
classifier = EmbeddingClassifier(
432500
model_name="Qwen/Qwen3-Embedding-0.6B",
433501
csv_path=str(csv_path),
434-
index_path=str(index_path),
502+
collection_name="embedding_classifier",
435503
device=classifier_device,
504+
milvus_uri=str(milvus_uri),
436505
)
437506
return classifier
438507

@@ -529,26 +598,32 @@ async def call_tool(name: str, arguments: Any) -> list[TextContent]:
529598
]
530599

531600

532-
async def main_stdio(device: str = "auto"):
601+
async def main_stdio(device: str = "auto", milvus_uri: str = "./milvus_data.db"):
533602
"""Run the MCP server in stdio mode."""
534-
global classifier_device
603+
global classifier_device, classifier_milvus_uri
535604
classifier_device = device
605+
classifier_milvus_uri = milvus_uri
536606

537607
logger.info("Starting Embedding-Based MCP Classification Server (stdio mode)")
538608
clf = get_classifier()
539609
logger.info(f"Available categories: {', '.join(clf.category_names)}")
540610
logger.info(f"Model: {clf.model_name}")
541611
logger.info(f"Device: {clf.device}")
542-
logger.info(f"Index size: {clf.index.ntotal} vectors")
612+
logger.info(f"Milvus Lite: {clf.milvus_uri}")
613+
stats = clf.client.get_collection_stats(clf.collection_name)
614+
logger.info(f"Collection size: {stats.get('row_count', 0)} vectors")
543615

544616
async with stdio_server() as (read_stream, write_stream):
545617
await app.run(read_stream, write_stream, app.create_initialization_options())
546618

547619

548-
async def main_http(port: int = 8091, device: str = "auto"):
620+
async def main_http(
621+
port: int = 8091, device: str = "auto", milvus_uri: str = "./milvus_data.db"
622+
):
549623
"""Run the MCP server in HTTP mode."""
550-
global classifier_device
624+
global classifier_device, classifier_milvus_uri
551625
classifier_device = device
626+
classifier_milvus_uri = milvus_uri
552627

553628
try:
554629
from aiohttp import web
@@ -563,7 +638,9 @@ async def main_http(port: int = 8091, device: str = "auto"):
563638
logger.info(f"Available categories: {', '.join(clf.category_names)}")
564639
logger.info(f"Model: {clf.model_name}")
565640
logger.info(f"Device: {clf.device}")
566-
logger.info(f"Index size: {clf.index.ntotal} vectors")
641+
logger.info(f"Milvus Lite: {clf.milvus_uri}")
642+
stats = clf.client.get_collection_stats(clf.collection_name)
643+
logger.info(f"Collection size: {stats.get('row_count', 0)} vectors")
567644
logger.info(f"Listening on http://0.0.0.0:{port}/mcp")
568645

569646
async def handle_mcp_request(request):
@@ -672,12 +749,14 @@ async def handle_mcp_request(request):
672749
async def health_check(request):
673750
"""Health check endpoint."""
674751
clf = get_classifier()
752+
stats = clf.client.get_collection_stats(clf.collection_name)
675753
return web.json_response(
676754
{
677755
"status": "ok",
678756
"categories": clf.category_names,
679757
"model": clf.model_name,
680-
"index_size": clf.index.ntotal,
758+
"collection_size": stats.get("row_count", 0),
759+
"milvus_uri": clf.milvus_uri,
681760
}
682761
)
683762

@@ -724,7 +803,7 @@ async def health_check(request):
724803

725804
# Parse command line arguments
726805
parser = argparse.ArgumentParser(
727-
description="MCP Embedding-Based Classification Server"
806+
description="MCP Embedding-Based Classification Server (Milvus Lite)"
728807
)
729808
parser.add_argument(
730809
"--http", action="store_true", help="Run in HTTP mode instead of stdio"
@@ -737,9 +816,15 @@ async def health_check(request):
737816
choices=["auto", "cuda", "cpu"],
738817
help="Device to use for inference (auto=auto-detect, cuda=force GPU, cpu=force CPU)",
739818
)
819+
parser.add_argument(
820+
"--milvus-uri",
821+
type=str,
822+
default="./milvus_data.db",
823+
help="Milvus Lite database file path (default: ./milvus_data.db)",
824+
)
740825
args = parser.parse_args()
741826

742827
if args.http:
743-
asyncio.run(main_http(args.port, args.device))
828+
asyncio.run(main_http(args.port, args.device, args.milvus_uri))
744829
else:
745-
asyncio.run(main_stdio(args.device))
830+
asyncio.run(main_stdio(args.device, args.milvus_uri))

0 commit comments

Comments
 (0)