661. Text classification using semantic embeddings (RAG-style)
772. Dynamic category discovery via list_categories
883. Intelligent routing decisions (model selection and reasoning control)
9- 4. FAISS vector database for similarity search
9+ 4. Milvus vector database for similarity search
1010
1111The server implements two MCP tools:
1212- 'list_categories': Returns available categories with per-category system prompts and descriptions
3838
3939 # HTTP mode (for semantic router)
4040 python server_embedding.py --http --port 8090
41+
42+ Prerequisites:
43+ - pip install pymilvus (includes Milvus Lite - no Docker needed!)
44+ - pip install torch transformers mcp
4145"""
4246
4347import argparse
4953from pathlib import Path
5054from typing import Any
5155
52- import faiss
5356import numpy as np
5457import torch
5558from mcp .server import Server
5659from mcp .server .stdio import stdio_server
5760from mcp .types import TextContent , Tool
61+ from pymilvus import MilvusClient
5862from transformers import AutoModel , AutoTokenizer
5963
6064# Configure logging
114118
115119
116120class EmbeddingClassifier :
117- """Embedding-based text classifier using FAISS vector search."""
121+ """Embedding-based text classifier using Milvus vector search."""
118122
119123 def __init__ (
120124 self ,
121125 model_name : str = "Qwen/Qwen3-Embedding-0.6B" ,
122126 csv_path : str = "training_data.csv" ,
123- index_path : str = "faiss_index.bin " ,
127+ collection_name : str = "embedding_classifier " ,
124128 device : str = "auto" ,
129+ milvus_uri : str = "./milvus_data.db" ,
125130 ):
126131 """
127132 Initialize the embedding classifier.
128133
129134 Args:
130135 model_name: Name of the embedding model to use
131136 csv_path: Path to the CSV training data file
132- index_path: Path to save/load FAISS index
137+ collection_name: Name of the Milvus collection
133138 device: Device to use ("cuda", "cpu", or "auto" for auto-detection)
139+ milvus_uri: Milvus Lite database file path (default: "./milvus_data.db")
134140 """
135141 self .model_name = model_name
136142 self .csv_path = csv_path
137- self .index_path = index_path
143+ self .collection_name = collection_name
144+ self .milvus_uri = milvus_uri
138145
139146 logger .info (f"Initializing embedding model: { model_name } " )
140147 self .tokenizer = AutoTokenizer .from_pretrained (
@@ -161,7 +168,7 @@ def __init__(
161168 # Qwen3-Embedding-0.6B has embedding dimension of 1024
162169 self .embedding_dim = 1024
163170
164- self .index = None
171+ self .client = None
165172 self .category_names = list (CATEGORY_CONFIG .keys ())
166173 self .category_to_index = {
167174 name : idx for idx , name in enumerate (self .category_names )
@@ -172,21 +179,9 @@ def __init__(
172179 self .texts , self .categories = self ._load_csv_data ()
173180 logger .info (f"Loaded { len (self .texts )} training examples" )
174181
175- # Load or build FAISS index
176- if os .path .exists (index_path ):
177- logger .info (f"Loading existing FAISS index from { index_path } " )
178- self .index = faiss .read_index (index_path )
179- logger .info (f"Index loaded with { self .index .ntotal } vectors" )
180-
181- # Verify index size matches CSV
182- if self .index .ntotal != len (self .texts ):
183- logger .warning (
184- f"Index size ({ self .index .ntotal } ) doesn't match CSV ({ len (self .texts )} ). Rebuilding..."
185- )
186- self ._build_index ()
187- else :
188- logger .info ("No existing index found, building new one..." )
189- self ._build_index ()
182+ # Connect to Milvus Lite and initialize collection
183+ self ._connect_milvus ()
184+ self ._init_collection ()
190185
191186 def _encode_texts (self , texts : list [str ], batch_size : int = 8 ) -> np .ndarray :
192187 """
@@ -244,27 +239,86 @@ def _load_csv_data(self) -> tuple[list[str], list[str]]:
244239 logger .info (f"Loaded { len (texts )} training examples" )
245240 return texts , categories
246241
247- def _build_index (self ):
248- """Build FAISS index from loaded CSV data."""
249- logger .info ("Building FAISS index from training data..." )
242+ def _connect_milvus (self ):
243+ """Connect to Milvus Lite."""
244+ try :
245+ logger .info (f"Connecting to Milvus Lite at { self .milvus_uri } " )
246+ self .client = MilvusClient (self .milvus_uri )
247+ logger .info ("Successfully connected to Milvus Lite" )
248+ except Exception as e :
249+ logger .error (f"Failed to connect to Milvus Lite: { e } " )
250+ raise
251+
252+ def _init_collection (self ):
253+ """Initialize or load Milvus collection."""
254+ # Check if collection exists
255+ if self .client .has_collection (self .collection_name ):
256+ logger .info (f"Loading existing collection: { self .collection_name } " )
257+
258+ # Check if we need to rebuild (verify count matches CSV)
259+ stats = self .client .get_collection_stats (self .collection_name )
260+ current_count = stats .get ("row_count" , 0 )
261+ expected_count = len (self .texts )
262+
263+ if current_count != expected_count :
264+ logger .warning (
265+ f"Collection has { current_count } entities but CSV has { expected_count } . Rebuilding..."
266+ )
267+ self .client .drop_collection (self .collection_name )
268+ self ._create_and_build_collection ()
269+ else :
270+ logger .info (f"Collection loaded with { current_count } vectors" )
271+ else :
272+ logger .info (
273+ f"Collection { self .collection_name } not found, creating new one..."
274+ )
275+ self ._create_and_build_collection ()
276+
277+ def _create_and_build_collection (self ):
278+ """Create Milvus collection and insert embeddings."""
279+ logger .info ("Creating new Milvus collection..." )
280+
281+ # Create collection with schema
282+ self .client .create_collection (
283+ collection_name = self .collection_name ,
284+ dimension = self .embedding_dim ,
285+ metric_type = "IP" , # Inner Product for cosine similarity
286+ auto_id = False ,
287+ primary_field_name = "id" ,
288+ vector_field_name = "embedding" ,
289+ )
290+ logger .info (f"Collection { self .collection_name } created" )
250291
251292 # Generate embeddings
252293 logger .info (f"Generating embeddings for { len (self .texts )} examples..." )
253294 embeddings = self ._encode_texts (self .texts )
254295
255296 # Normalize embeddings for cosine similarity
256- faiss .normalize_L2 (embeddings .astype ("float32" ))
257-
258- # Build FAISS index
259- logger .info (f"Creating FAISS index with dimension { self .embedding_dim } ..." )
260- self .index = faiss .IndexFlatIP (self .embedding_dim ) # Inner product for cosine
261- self .index .add (embeddings )
297+ norms = np .linalg .norm (embeddings , axis = 1 , keepdims = True )
298+ embeddings = embeddings / norms
299+
300+ # Prepare data for insertion
301+ data = []
302+ for i , (text , category , embedding ) in enumerate (
303+ zip (self .texts , self .categories , embeddings )
304+ ):
305+ data .append (
306+ {
307+ "id" : i ,
308+ "embedding" : embedding .tolist (),
309+ "category" : category ,
310+ "text" : text ,
311+ }
312+ )
262313
263- # Save index
264- logger .info (f"Saving index to { self .index_path } " )
265- faiss .write_index (self .index , self .index_path )
314+ # Insert data in batches
315+ batch_size = 100
316+ logger .info ("Inserting data into Milvus..." )
317+ for i in range (0 , len (data ), batch_size ):
318+ batch = data [i : i + batch_size ]
319+ self .client .insert (collection_name = self .collection_name , data = batch )
266320
267- logger .info (f"Index built successfully with { self . index . ntotal } vectors" )
321+ logger .info (f"Collection built successfully with { len ( data ) } vectors" )
268322
269323 def classify (
270324 self , text : str , k : int = 20 , with_probabilities : bool = False
@@ -282,14 +336,27 @@ def classify(
282336 """
283337 # Generate embedding for query text
284338 query_embedding = self ._encode_texts ([text ])
285- faiss .normalize_L2 (query_embedding .astype ("float32" ))
286339
287- # Search for k nearest neighbors
288- similarities , indices = self .index .search (query_embedding , k )
340+ # Normalize embeddings for cosine similarity
341+ norms = np .linalg .norm (query_embedding , axis = 1 , keepdims = True )
342+ query_embedding = query_embedding / norms
343+
344+ # Search for k nearest neighbors using Milvus
345+ results = self .client .search (
346+ collection_name = self .collection_name ,
347+ data = query_embedding .tolist (),
348+ limit = k ,
349+ output_fields = ["category" , "text" ],
350+ )
351+
352+ # Extract results
353+ neighbor_categories = []
354+ neighbor_similarities = []
289355
290- # Get categories of nearest neighbors
291- neighbor_categories = [self .categories [idx ] for idx in indices [0 ]]
292- neighbor_similarities = similarities [0 ]
356+ for hits in results :
357+ for hit in hits :
358+ neighbor_categories .append (hit .get ("entity" , {}).get ("category" ))
359+ neighbor_similarities .append (hit .get ("distance" , 0 ))
293360
294361 # Calculate confidence scores for each category
295362 category_scores = {cat : 0.0 for cat in self .category_names }
@@ -417,6 +484,7 @@ def _decide_routing(
417484# For multi-process deployments, each process gets its own instance.
418485classifier = None
419486classifier_device = "auto" # Default device setting
487+ classifier_milvus_uri = "./milvus_data.db" # Default Milvus Lite database path
420488
421489
422490def get_classifier ():
@@ -426,13 +494,14 @@ def get_classifier():
426494 # Get script directory
427495 script_dir = Path (__file__ ).parent
428496 csv_path = script_dir / "training_data.csv"
429- index_path = script_dir / "faiss_index.bin"
497+ milvus_uri = script_dir / classifier_milvus_uri
430498
431499 classifier = EmbeddingClassifier (
432500 model_name = "Qwen/Qwen3-Embedding-0.6B" ,
433501 csv_path = str (csv_path ),
434- index_path = str ( index_path ) ,
502+ collection_name = "embedding_classifier" ,
435503 device = classifier_device ,
504+ milvus_uri = str (milvus_uri ),
436505 )
437506 return classifier
438507
@@ -529,26 +598,32 @@ async def call_tool(name: str, arguments: Any) -> list[TextContent]:
529598 ]
530599
531600
532- async def main_stdio (device : str = "auto" ):
601+ async def main_stdio (device : str = "auto" , milvus_uri : str = "./milvus_data.db" ):
533602 """Run the MCP server in stdio mode."""
534- global classifier_device
603+ global classifier_device , classifier_milvus_uri
535604 classifier_device = device
605+ classifier_milvus_uri = milvus_uri
536606
537607 logger .info ("Starting Embedding-Based MCP Classification Server (stdio mode)" )
538608 clf = get_classifier ()
539609 logger .info (f"Available categories: { ', ' .join (clf .category_names )} " )
540610 logger .info (f"Model: { clf .model_name } " )
541611 logger .info (f"Device: { clf .device } " )
542- logger .info (f"Index size: { clf .index .ntotal } vectors" )
612+ logger .info (f"Milvus Lite: { clf .milvus_uri } " )
613+ stats = clf .client .get_collection_stats (clf .collection_name )
614+ logger .info (f"Collection size: { stats .get ('row_count' , 0 )} vectors" )
543615
544616 async with stdio_server () as (read_stream , write_stream ):
545617 await app .run (read_stream , write_stream , app .create_initialization_options ())
546618
547619
548- async def main_http (port : int = 8091 , device : str = "auto" ):
620+ async def main_http (
621+ port : int = 8091 , device : str = "auto" , milvus_uri : str = "./milvus_data.db"
622+ ):
549623 """Run the MCP server in HTTP mode."""
550- global classifier_device
624+ global classifier_device , classifier_milvus_uri
551625 classifier_device = device
626+ classifier_milvus_uri = milvus_uri
552627
553628 try :
554629 from aiohttp import web
@@ -563,7 +638,9 @@ async def main_http(port: int = 8091, device: str = "auto"):
563638 logger .info (f"Available categories: { ', ' .join (clf .category_names )} " )
564639 logger .info (f"Model: { clf .model_name } " )
565640 logger .info (f"Device: { clf .device } " )
566- logger .info (f"Index size: { clf .index .ntotal } vectors" )
641+ logger .info (f"Milvus Lite: { clf .milvus_uri } " )
642+ stats = clf .client .get_collection_stats (clf .collection_name )
643+ logger .info (f"Collection size: { stats .get ('row_count' , 0 )} vectors" )
567644 logger .info (f"Listening on http://0.0.0.0:{ port } /mcp" )
568645
569646 async def handle_mcp_request (request ):
@@ -672,12 +749,14 @@ async def handle_mcp_request(request):
672749 async def health_check (request ):
673750 """Health check endpoint."""
674751 clf = get_classifier ()
752+ stats = clf .client .get_collection_stats (clf .collection_name )
675753 return web .json_response (
676754 {
677755 "status" : "ok" ,
678756 "categories" : clf .category_names ,
679757 "model" : clf .model_name ,
680- "index_size" : clf .index .ntotal ,
758+ "collection_size" : stats .get ("row_count" , 0 ),
759+ "milvus_uri" : clf .milvus_uri ,
681760 }
682761 )
683762
@@ -724,7 +803,7 @@ async def health_check(request):
724803
725804 # Parse command line arguments
726805 parser = argparse .ArgumentParser (
727- description = "MCP Embedding-Based Classification Server"
806+ description = "MCP Embedding-Based Classification Server (Milvus Lite) "
728807 )
729808 parser .add_argument (
730809 "--http" , action = "store_true" , help = "Run in HTTP mode instead of stdio"
@@ -737,9 +816,15 @@ async def health_check(request):
737816 choices = ["auto" , "cuda" , "cpu" ],
738817 help = "Device to use for inference (auto=auto-detect, cuda=force GPU, cpu=force CPU)" ,
739818 )
819+ parser .add_argument (
820+ "--milvus-uri" ,
821+ type = str ,
822+ default = "./milvus_data.db" ,
823+ help = "Milvus Lite database file path (default: ./milvus_data.db)" ,
824+ )
740825 args = parser .parse_args ()
741826
742827 if args .http :
743- asyncio .run (main_http (args .port , args .device ))
828+ asyncio .run (main_http (args .port , args .device , args . milvus_uri ))
744829 else :
745- asyncio .run (main_stdio (args .device ))
830+ asyncio .run (main_stdio (args .device , args . milvus_uri ))
0 commit comments