661. Text classification using semantic embeddings (RAG-style)
772. Dynamic category discovery via list_categories
883. Intelligent routing decisions (model selection and reasoning control)
9- 4. FAISS vector database for similarity search
9+ 4. Milvus vector database for similarity search
1010
1111The server implements two MCP tools:
1212- 'list_categories': Returns available categories with per-category system prompts and descriptions
3838
3939 # HTTP mode (for semantic router)
4040 python server_embedding.py --http --port 8090
41+
42+ Prerequisites:
43+ - pip install pymilvus (includes Milvus Lite - no Docker needed!)
44+ - pip install torch transformers mcp
4145"""
4246
4347import argparse
4953from pathlib import Path
5054from typing import Any
5155
52- import faiss
5356import numpy as np
5457import torch
5558from mcp .server import Server
5659from mcp .server .stdio import stdio_server
5760from mcp .types import TextContent , Tool
5861from transformers import AutoModel , AutoTokenizer
62+ from pymilvus import MilvusClient
63+
5964
6065# Configure logging
6166logging .basicConfig (
114119
115120
116121class EmbeddingClassifier :
117- """Embedding-based text classifier using FAISS vector search."""
122+ """Embedding-based text classifier using Milvus vector search."""
118123
119124 def __init__ (
120125 self ,
121126 model_name : str = "Qwen/Qwen3-Embedding-0.6B" ,
122127 csv_path : str = "training_data.csv" ,
123- index_path : str = "faiss_index.bin " ,
128+ collection_name : str = "embedding_classifier " ,
124129 device : str = "auto" ,
130+ milvus_uri : str = "./milvus_data.db" ,
125131 ):
126132 """
127133 Initialize the embedding classifier.
128134
129135 Args:
130136 model_name: Name of the embedding model to use
131137 csv_path: Path to the CSV training data file
132- index_path: Path to save/load FAISS index
138+ collection_name: Name of the Milvus collection
133139 device: Device to use ("cuda", "cpu", or "auto" for auto-detection)
140+ milvus_uri: Milvus Lite database file path (default: "./milvus_data.db")
134141 """
135142 self .model_name = model_name
136143 self .csv_path = csv_path
137- self .index_path = index_path
144+ self .collection_name = collection_name
145+ self .milvus_uri = milvus_uri
138146
139147 logger .info (f"Initializing embedding model: { model_name } " )
140148 self .tokenizer = AutoTokenizer .from_pretrained (
@@ -161,7 +169,7 @@ def __init__(
161169 # Qwen3-Embedding-0.6B has embedding dimension of 1024
162170 self .embedding_dim = 1024
163171
164- self .index = None
172+ self .client = None
165173 self .category_names = list (CATEGORY_CONFIG .keys ())
166174 self .category_to_index = {
167175 name : idx for idx , name in enumerate (self .category_names )
@@ -172,21 +180,9 @@ def __init__(
172180 self .texts , self .categories = self ._load_csv_data ()
173181 logger .info (f"Loaded { len (self .texts )} training examples" )
174182
175- # Load or build FAISS index
176- if os .path .exists (index_path ):
177- logger .info (f"Loading existing FAISS index from { index_path } " )
178- self .index = faiss .read_index (index_path )
179- logger .info (f"Index loaded with { self .index .ntotal } vectors" )
180-
181- # Verify index size matches CSV
182- if self .index .ntotal != len (self .texts ):
183- logger .warning (
184- f"Index size ({ self .index .ntotal } ) doesn't match CSV ({ len (self .texts )} ). Rebuilding..."
185- )
186- self ._build_index ()
187- else :
188- logger .info ("No existing index found, building new one..." )
189- self ._build_index ()
183+ # Connect to Milvus Lite and initialize collection
184+ self ._connect_milvus ()
185+ self ._init_collection ()
190186
191187 def _encode_texts (self , texts : list [str ], batch_size : int = 8 ) -> np .ndarray :
192188 """
@@ -244,27 +240,83 @@ def _load_csv_data(self) -> tuple[list[str], list[str]]:
244240 logger .info (f"Loaded { len (texts )} training examples" )
245241 return texts , categories
246242
247- def _build_index (self ):
248- """Build FAISS index from loaded CSV data."""
249- logger .info ("Building FAISS index from training data..." )
250-
def _connect_milvus(self):
    """
    Open a Milvus Lite client for ``self.milvus_uri``.

    On success the client is stored on ``self.client``. Any exception
    raised while connecting is logged and then re-raised unchanged.
    """
    try:
        target_uri = self.milvus_uri
        logger.info(f"Connecting to Milvus Lite at {target_uri}")
        milvus_client = MilvusClient(target_uri)
        self.client = milvus_client
        logger.info("Successfully connected to Milvus Lite")
    except Exception as connect_error:
        # Log for the operator, but let the caller see the original error.
        logger.error(f"Failed to connect to Milvus Lite: {connect_error}")
        raise
252+
def _init_collection(self):
    """
    Make sure the Milvus collection exists and matches the loaded CSV.

    A missing collection is created and populated. An existing collection
    is kept only when its reported row count equals the number of loaded
    training texts; otherwise it is dropped and rebuilt.
    """
    name = self.collection_name

    # Nothing stored yet: build from scratch.
    if not self.client.has_collection(name):
        logger.info(f"Collection {name} not found, creating new one...")
        self._create_and_build_collection()
        return

    logger.info(f"Loading existing collection: {name}")

    # Row count is the only freshness check: compare it with the CSV size.
    stored_rows = self.client.get_collection_stats(name).get("row_count", 0)
    csv_rows = len(self.texts)

    if stored_rows == csv_rows:
        logger.info(f"Collection loaded with {stored_rows} vectors")
        return

    logger.warning(
        f"Collection has {stored_rows} entities but CSV has {csv_rows}. Rebuilding..."
    )
    self.client.drop_collection(name)
    self._create_and_build_collection()
275+
def _create_and_build_collection(self):
    """
    Create the Milvus collection and fill it with the training embeddings.

    The collection is created with an inner-product ("IP") metric, an
    explicit integer primary key ``id`` and a vector field ``embedding``
    of size ``self.embedding_dim``. Every training text is embedded,
    L2-normalized and inserted together with its ``category`` and ``text``.

    Rows whose embedding has zero norm are left as all-zero vectors rather
    than being divided by zero, so no NaN values are written to Milvus.
    """
    logger.info("Creating new Milvus collection...")

    # Create collection with schema
    # NOTE(review): "category" and "text" are not declared here; presumably
    # they are stored via Milvus's dynamic-field support - confirm.
    self.client.create_collection(
        collection_name=self.collection_name,
        dimension=self.embedding_dim,
        metric_type="IP",  # Inner Product for cosine similarity
        auto_id=False,
        primary_field_name="id",
        vector_field_name="embedding",
    )
    logger.info(f"Collection {self.collection_name} created")

    # Generate embeddings
    logger.info(f"Generating embeddings for {len(self.texts)} examples...")
    embeddings = self._encode_texts(self.texts)

    # Normalize embeddings so inner product equals cosine similarity.
    # A zero-norm row would otherwise produce NaN; dividing it by 1 keeps
    # it as a harmless all-zero vector.
    norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
    safe_norms = np.where(norms > 0, norms, 1.0)
    embeddings = embeddings / safe_norms

    # Prepare data for insertion; the row position doubles as primary key.
    data = [
        {
            "id": row_id,
            "embedding": embedding.tolist(),
            "category": category,
            "text": text,
        }
        for row_id, (text, category, embedding) in enumerate(
            zip(self.texts, self.categories, embeddings)
        )
    ]

    # Insert data in batches
    batch_size = 100
    logger.info("Inserting data into Milvus...")
    for start in range(0, len(data), batch_size):
        self.client.insert(
            collection_name=self.collection_name,
            data=data[start:start + batch_size],
        )

    logger.info(f"Collection built successfully with {len(data)} vectors")
268320
269321 def classify (
270322 self , text : str , k : int = 20 , with_probabilities : bool = False
@@ -282,14 +334,27 @@ def classify(
282334 """
283335 # Generate embedding for query text
284336 query_embedding = self ._encode_texts ([text ])
285- faiss .normalize_L2 (query_embedding .astype ("float32" ))
286-
287- # Search for k nearest neighbors
288- similarities , indices = self .index .search (query_embedding , k )
337+
338+ # Normalize embeddings for cosine similarity
339+ norms = np .linalg .norm (query_embedding , axis = 1 , keepdims = True )
340+ query_embedding = query_embedding / norms
341+
342+ # Search for k nearest neighbors using Milvus
343+ results = self .client .search (
344+ collection_name = self .collection_name ,
345+ data = query_embedding .tolist (),
346+ limit = k ,
347+ output_fields = ["category" , "text" ]
348+ )
289349
290- # Get categories of nearest neighbors
291- neighbor_categories = [self .categories [idx ] for idx in indices [0 ]]
292- neighbor_similarities = similarities [0 ]
350+ # Extract results
351+ neighbor_categories = []
352+ neighbor_similarities = []
353+
354+ for hits in results :
355+ for hit in hits :
356+ neighbor_categories .append (hit .get ("entity" , {}).get ("category" ))
357+ neighbor_similarities .append (hit .get ("distance" , 0 ))
293358
294359 # Calculate confidence scores for each category
295360 category_scores = {cat : 0.0 for cat in self .category_names }
@@ -417,6 +482,7 @@ def _decide_routing(
417482# For multi-process deployments, each process gets its own instance.
418483classifier = None
419484classifier_device = "auto" # Default device setting
485+ classifier_milvus_uri = "./milvus_data.db" # Default Milvus Lite database path
420486
421487
422488def get_classifier ():
@@ -426,13 +492,14 @@ def get_classifier():
426492 # Get script directory
427493 script_dir = Path (__file__ ).parent
428494 csv_path = script_dir / "training_data.csv"
429- index_path = script_dir / "faiss_index.bin"
495+ milvus_uri = script_dir / classifier_milvus_uri
430496
431497 classifier = EmbeddingClassifier (
432498 model_name = "Qwen/Qwen3-Embedding-0.6B" ,
433499 csv_path = str (csv_path ),
434- index_path = str ( index_path ) ,
500+ collection_name = "embedding_classifier" ,
435501 device = classifier_device ,
502+ milvus_uri = str (milvus_uri ),
436503 )
437504 return classifier
438505
@@ -529,26 +596,30 @@ async def call_tool(name: str, arguments: Any) -> list[TextContent]:
529596 ]
530597
531598
532- async def main_stdio (device : str = "auto" ):
async def main_stdio(device: str = "auto", milvus_uri: str = "./milvus_data.db"):
    """
    Serve the MCP classifier over stdio.

    Stores the requested device and Milvus URI in the module-level settings,
    obtains the classifier, logs a startup summary (categories, model,
    device, Milvus URI, collection row count) and then runs the MCP app on
    the stdio streams.

    Args:
        device: Device setting stored for the classifier.
        milvus_uri: Milvus Lite database path stored for the classifier.
    """
    global classifier_device, classifier_milvus_uri
    classifier_device, classifier_milvus_uri = device, milvus_uri

    logger.info("Starting Embedding-Based MCP Classification Server (stdio mode)")
    active_classifier = get_classifier()

    category_list = ", ".join(active_classifier.category_names)
    logger.info(f"Available categories: {category_list}")
    logger.info(f"Model: {active_classifier.model_name}")
    logger.info(f"Device: {active_classifier.device}")
    logger.info(f"Milvus Lite: {active_classifier.milvus_uri}")

    collection_stats = active_classifier.client.get_collection_stats(
        active_classifier.collection_name
    )
    row_count = collection_stats.get("row_count", 0)
    logger.info(f"Collection size: {row_count} vectors")

    async with stdio_server() as (read_stream, write_stream):
        init_options = app.create_initialization_options()
        await app.run(read_stream, write_stream, init_options)
546616
547617
548- async def main_http (port : int = 8091 , device : str = "auto" ):
618+ async def main_http (port : int = 8091 , device : str = "auto" , milvus_uri : str = "./milvus_data.db" ):
549619 """Run the MCP server in HTTP mode."""
550- global classifier_device
620+ global classifier_device , classifier_milvus_uri
551621 classifier_device = device
622+ classifier_milvus_uri = milvus_uri
552623
553624 try :
554625 from aiohttp import web
@@ -563,7 +634,9 @@ async def main_http(port: int = 8091, device: str = "auto"):
563634 logger .info (f"Available categories: { ', ' .join (clf .category_names )} " )
564635 logger .info (f"Model: { clf .model_name } " )
565636 logger .info (f"Device: { clf .device } " )
566- logger .info (f"Index size: { clf .index .ntotal } vectors" )
637+ logger .info (f"Milvus Lite: { clf .milvus_uri } " )
638+ stats = clf .client .get_collection_stats (clf .collection_name )
639+ logger .info (f"Collection size: { stats .get ('row_count' , 0 )} vectors" )
567640 logger .info (f"Listening on http://0.0.0.0:{ port } /mcp" )
568641
569642 async def handle_mcp_request (request ):
@@ -672,12 +745,14 @@ async def handle_mcp_request(request):
async def health_check(request):
    """
    Report server health as a JSON response.

    The payload carries a fixed "ok" status, the classifier's category
    names and model name, the collection's row count (0 when the stats do
    not include one) and the Milvus URI. ``request`` is not inspected.
    """
    active_classifier = get_classifier()
    collection_stats = active_classifier.client.get_collection_stats(
        active_classifier.collection_name
    )
    payload = {
        "status": "ok",
        "categories": active_classifier.category_names,
        "model": active_classifier.model_name,
        "collection_size": collection_stats.get("row_count", 0),
        "milvus_uri": active_classifier.milvus_uri,
    }
    return web.json_response(payload)
683758
@@ -724,7 +799,7 @@ async def health_check(request):
724799
725800 # Parse command line arguments
726801 parser = argparse .ArgumentParser (
727- description = "MCP Embedding-Based Classification Server"
802+ description = "MCP Embedding-Based Classification Server (Milvus Lite) "
728803 )
729804 parser .add_argument (
730805 "--http" , action = "store_true" , help = "Run in HTTP mode instead of stdio"
@@ -737,9 +812,15 @@ async def health_check(request):
737812 choices = ["auto" , "cuda" , "cpu" ],
738813 help = "Device to use for inference (auto=auto-detect, cuda=force GPU, cpu=force CPU)" ,
739814 )
815+ parser .add_argument (
816+ "--milvus-uri" ,
817+ type = str ,
818+ default = "./milvus_data.db" ,
819+ help = "Milvus Lite database file path (default: ./milvus_data.db)" ,
820+ )
740821 args = parser .parse_args ()
741822
742823 if args .http :
743- asyncio .run (main_http (args .port , args .device ))
824+ asyncio .run (main_http (args .port , args .device , args . milvus_uri ))
744825 else :
745- asyncio .run (main_stdio (args .device ))
826+ asyncio .run (main_stdio (args .device , args . milvus_uri ))
0 commit comments