Description
Is your feature request related to a problem? Please describe.
The current semantic router only supports routing based on fine-tuned category models (ModernBERT) with a fixed set of 14 predefined categories (business, law, psychology, biology, chemistry, history, other, health, economics, math, physics, computer science, philosophy, engineering). This approach has significant limitations:
- Limited Category Coverage: Users cannot define custom categories beyond the 14 pre-trained ones, restricting flexibility for domain-specific use cases (e.g., travel, legal consulting, sports news)
- Dataset Dependency: Adding new categories requires retraining the ModernBERT classifier with labeled datasets, which is time-consuming and resource-intensive
- Inflexibility: Organizations with unique routing requirements cannot adapt the system to their specific needs without model retraining
- Scalability Issues: As business requirements evolve, the fixed category set becomes a bottleneck for rapid iteration
- No Zero-Shot Capability: Cannot add new categories on-the-fly without model retraining
Real-World Scenario:
A travel company wants to route travel-related queries to specialized travel models:
- Query: "Recommend a 3-day itinerary for Paris with museums and cafes"
- Expected: Match custom category "travel-planning" → Route to [travel-expert, general-assistant]
- Current: Classify as "other" (no specific category) → Route to general models
This results in:
- Poor routing accuracy (category too generic)
- Cannot leverage domain-specific models
- Requires expensive model retraining to add "travel" category
Describe the solution you'd like
Implement a similarity-based custom category routing mechanism that allows users to dynamically define categories using natural language descriptions and examples, without requiring model retraining. The solution should enable:
Core Concept: Define custom categories with natural language descriptions and examples, then use BERT embeddings to match queries semantically:
```yaml
custom_categories:
  enabled: true
  similarity_threshold: 0.75
  categories:
    - id: "travel-planning"
      description: "Travel planning, destination recommendations, itineraries"
      examples:
        - "Recommend a 3-day itinerary for Paris"
        - "Best time to visit Japan"
      candidate_models:
        - "travel-expert"
        - "general-assistant"
```
Key Features:
- Zero-Shot Capability: Add new categories without retraining models
- Semantic Matching: Use BERT embeddings for intelligent category matching
- Confidence Validation: Threshold and gap checks to ensure confident matches
- Fallback Support: Automatically fall back to the fine-tuned classifier when needed
- Backward Compatible: Existing ModernBERT routing continues to work
Architecture Overview
High-Level Flow:
```mermaid
flowchart TD
    A[User Query] --> B{custom_categories.enabled?}
    B -->|No| C[ModernBERT Classifier]
    B -->|Yes| D[Generate Query Embedding]
    D --> E[Calculate Similarity with Categories]
    E --> F{Top-1 Score >= Threshold?}
    F -->|No| C
    F -->|Yes| G{Gap >= Gap Threshold?}
    G -->|No| C
    G -->|Yes| H[Match Custom Category]
    H --> I[Get Candidate Models]
    C --> J[Get Category Models]
    I --> K[Select Best Model]
    J --> K
    K --> L[Return Selected Model]

    style A fill:#e1f5ff
    style L fill:#c8e6c9
    style D fill:#ffe0b2
    style H fill:#c8e6c9
    style C fill:#f8bbd0
```
Detailed Similarity Matching Process:
```mermaid
sequenceDiagram
    participant Query
    participant SM as SimilarityMatcher
    participant Cache as EmbeddingCache
    participant BERT as BERT Model
    participant Validator

    Query->>SM: "Recommend Paris itinerary"
    SM->>BERT: Generate query embedding
    BERT-->>SM: [0.12, -0.45, 0.78, ...] (384-dim)
    SM->>Cache: Get category embeddings
    Cache-->>SM: Pre-computed embeddings
    SM->>SM: Calculate cosine similarity
    Note over SM: travel-planning: 0.88<br/>legal-consulting: 0.42<br/>tech-support: 0.35
    SM->>Validator: Validate top match
    Validator->>Validator: Check threshold: 0.88 >= 0.75 ✓
    Validator->>Validator: Check gap: 0.88 - 0.42 = 0.46 >= 0.05 ✓
    Validator-->>SM: Confident match ✓
    SM-->>Query: Category: travel-planning<br/>Models: [travel-expert, general-assistant]
```
1. Dual-Mode Routing Architecture
Extend the existing config-based routing (not the SemanticRoute API, which is not yet implemented) to support two routing modes:
- Fine-tuned Mode (existing): Uses ModernBERT classifier for predefined 14 categories
- Similarity Mode (new): Uses embedding-based semantic similarity for user-defined custom categories
2. Custom Category Definition
Allow users to define custom categories in the configuration with:
```yaml
custom_categories:
  enabled: true
  similarity_threshold: 0.75   # Minimum cosine similarity for matching
  gap_threshold: 0.05          # Minimum gap between top-1 and top-2 to avoid ambiguity
  categories:
    - id: "travel"
      name: "Travel & Tourism"
      description: "Queries related to travel planning, destination recommendations, visa requirements, and tourism information"
      examples:
        - "Recommend a 3-day itinerary for Paris"
        - "What documents do I need for a Schengen visa?"
        - "Best time to visit Japan for cherry blossoms"
      model_scores:
        - model: openai/gpt-oss-20b
          score: 0.8
          use_reasoning: false
    - id: "legal_consulting"
      name: "Legal Consultation"
      description: "Questions about laws, regulations, and legal procedures"
      examples:
        - "What are the latest changes in civil law?"
        - "How to apply for legal aid?"
      model_scores:
        - model: openai/gpt-oss-20b
          score: 0.9
          use_reasoning: true
```
3. Embedding-Based Matching Pipeline
Implement a semantic similarity matching system:
- Initialization Phase:
  - Load the embedding model (default: `sentence-transformers/all-MiniLM-L12-v2` from the `bert_model` config)
  - Generate and cache normalized embeddings for all custom category descriptions and examples
  - Store category vectors in memory for fast retrieval
- Query Processing Phase:
  - Generate an embedding for the incoming user query
  - Calculate cosine similarity between the query embedding and all custom category embeddings
  - Rank categories by similarity score
  - Apply threshold and gap validation to ensure confident matching
- Decision Logic:
```
IF custom_categories.enabled AND custom categories exist:
    query_embedding = embedding_model.encode(query)
    similarities = calculate_cosine_similarity(query_embedding, all_category_embeddings)
    top1_category, top1_score = get_top_match(similarities)
    top2_score = get_second_match(similarities)
    IF top1_score >= similarity_threshold AND (top1_score - top2_score) >= gap_threshold:
        RETURN route_to_custom_category(top1_category)
    ELSE:
        FALLBACK to ModernBERT classification
ELSE:
    USE ModernBERT classification (existing behavior)
```
4. Fallback Mechanism
Ensure robust fallback to the existing ModernBERT classifier when:
- No custom categories are defined
- Similarity score is below threshold
- Top-2 categories have ambiguous scores (gap too small)
- Embedding model fails to initialize
This guarantees backward compatibility and system reliability.
5. Model Selection Integration
Once a category (custom or predefined) is determined, use the existing model selection logic:
- Retrieve `model_scores` from the matched category
- Select the best model based on score and TTFT (Time To First Token)
- Apply reasoning mode if configured (`use_reasoning: true`)
6. Configuration Refactoring
Refactor the existing config structure to support both modes:
```yaml
# Existing bert_model config (used for similarity-based routing)
bert_model:
  model_id: sentence-transformers/all-MiniLM-L12-v2
  threshold: 0.6
  use_cpu: true

# Existing classifier config (used for fine-tuned routing)
classifier:
  category_model:
    model_id: "models/category_classifier_modernbert-base_model"
    use_modernbert: true
    threshold: 0.6
    use_cpu: true
    category_mapping_path: "models/category_classifier_modernbert-base_model/category_mapping.json"

# NEW: Custom categories with similarity-based routing
custom_categories:
  enabled: true
  similarity_threshold: 0.75
  gap_threshold: 0.05
  categories:
    - id: "custom_category_1"
      name: "Category Name"
      description: "Detailed description"
      examples: ["example1", "example2"]
      model_scores:
        - model: model_name
          score: 0.8
          use_reasoning: false

# Existing predefined categories (used on fallback to ModernBERT)
categories:
  - name: business
    model_scores: [...]
  - name: math
    model_scores: [...]
  # ... other 12 categories
```
7. Implementation Components
1. SimilarityMatcher (`pkg/utils/similarity/matcher.go` - NEW)

```go
package similarity

import (
	"fmt"
	"math"
	"sort"
	// plus the project's candle-binding (GetTextEmbedding) and
	// observability (logging) packages
)

type SimilarityMatcher struct {
	categories          []CustomCategory
	categoryEmbeddings  map[string][]float32 // Pre-computed embeddings
	similarityThreshold float32
	gapThreshold        float32
}

type CustomCategory struct {
	ID              string
	Name            string
	Description     string
	Examples        []string
	CandidateModels []string
}

type SimilarityMatchResult struct {
	CategoryID      string
	CategoryName    string
	Similarity      float32
	Gap             float32
	Confident       bool
	CandidateModels []string
}

// NewSimilarityMatcher creates a new similarity matcher
func NewSimilarityMatcher(config CustomCategoriesConfig) *SimilarityMatcher {
	return &SimilarityMatcher{
		categories:          config.Categories,
		categoryEmbeddings:  make(map[string][]float32),
		similarityThreshold: config.SimilarityThreshold,
		gapThreshold:        config.GapThreshold,
	}
}

// InitializeEmbeddings pre-computes and caches category embeddings
func (m *SimilarityMatcher) InitializeEmbeddings() error {
	for _, category := range m.categories {
		// Combine description and examples for better matching
		text := category.Description
		for _, example := range category.Examples {
			text += " " + example
		}

		// Generate embedding using the BERT model
		embedding, err := candle.GetTextEmbedding(text)
		if err != nil {
			return fmt.Errorf("failed to generate embedding for category %s: %w", category.ID, err)
		}

		// Normalize embedding for cosine similarity
		normalized := normalizeVector(embedding)
		m.categoryEmbeddings[category.ID] = normalized

		observability.Infof("Initialized embedding for category: %s (dim: %d)", category.ID, len(normalized))
	}
	return nil
}

// MatchQuery matches the query against all custom categories
func (m *SimilarityMatcher) MatchQuery(query string) (*SimilarityMatchResult, error) {
	// Generate and normalize the query embedding
	queryEmbedding, err := candle.GetTextEmbedding(query)
	if err != nil {
		return nil, fmt.Errorf("failed to generate query embedding: %w", err)
	}
	normalizedQuery := normalizeVector(queryEmbedding)

	// Calculate similarity with all categories
	similarities := make([]struct {
		categoryID string
		score      float32
	}, 0, len(m.categories))

	for _, category := range m.categories {
		categoryEmb := m.categoryEmbeddings[category.ID]
		similarity := cosineSimilarity(normalizedQuery, categoryEmb)
		similarities = append(similarities, struct {
			categoryID string
			score      float32
		}{category.ID, similarity})
	}

	// Sort by similarity (descending)
	sort.Slice(similarities, func(i, j int) bool {
		return similarities[i].score > similarities[j].score
	})

	// Get top-2 matches
	if len(similarities) == 0 {
		return nil, fmt.Errorf("no categories available")
	}
	top1 := similarities[0]
	top2Score := float32(0.0)
	if len(similarities) > 1 {
		top2Score = similarities[1].score
	}
	gap := top1.score - top2Score

	// Validate confidence
	confident := top1.score >= m.similarityThreshold && gap >= m.gapThreshold

	// Find category details
	var matchedCategory *CustomCategory
	for i := range m.categories {
		if m.categories[i].ID == top1.categoryID {
			matchedCategory = &m.categories[i]
			break
		}
	}
	if matchedCategory == nil {
		return nil, fmt.Errorf("category not found: %s", top1.categoryID)
	}

	return &SimilarityMatchResult{
		CategoryID:      matchedCategory.ID,
		CategoryName:    matchedCategory.Name,
		Similarity:      top1.score,
		Gap:             gap,
		Confident:       confident,
		CandidateModels: matchedCategory.CandidateModels,
	}, nil
}

// Helper functions

func normalizeVector(vec []float32) []float32 {
	var norm float32
	for _, v := range vec {
		norm += v * v
	}
	norm = float32(math.Sqrt(float64(norm)))
	if norm == 0 {
		return vec // avoid division by zero for an all-zero vector
	}
	normalized := make([]float32, len(vec))
	for i, v := range vec {
		normalized[i] = v / norm
	}
	return normalized
}

func cosineSimilarity(a, b []float32) float32 {
	if len(a) != len(b) {
		return 0.0
	}
	var dotProduct float32
	for i := range a {
		dotProduct += a[i] * b[i]
	}
	return dotProduct // vectors are normalized, so dot product == cosine similarity
}
```
2. Configuration Extension (`pkg/config/config.go`)

```go
type RouterConfig struct {
	// ... existing fields ...
	CustomCategories CustomCategoriesConfig `yaml:"custom_categories"`
}

type CustomCategoriesConfig struct {
	Enabled             bool             `yaml:"enabled"`
	SimilarityThreshold float32          `yaml:"similarity_threshold"`
	GapThreshold        float32          `yaml:"gap_threshold"`
	Categories          []CustomCategory `yaml:"categories"`
}

type CustomCategory struct {
	ID              string   `yaml:"id"`
	Name            string   `yaml:"name"`
	Description     string   `yaml:"description"`
	Examples        []string `yaml:"examples"`
	CandidateModels []string `yaml:"candidate_models"`
}
```
3. Router Integration (`pkg/extproc/request_handler.go`)

```go
func (r *OpenAIRouter) handleModelRouting(...) (*ext_proc.ProcessingResponse, error) {
	// ... existing code ...

	if originalModel == "auto" {
		var selectedModel string

		// Try similarity-based custom category matching first
		if r.Config.CustomCategories.Enabled && r.SimilarityMatcher != nil {
			matchResult, err := r.SimilarityMatcher.MatchQuery(userContent)
			if err == nil && matchResult.Confident {
				// Confident match with a custom category
				observability.Infof("Similarity match: category=%s, similarity=%.3f, gap=%.3f",
					matchResult.CategoryName, matchResult.Similarity, matchResult.Gap)

				// Select the best model from the candidates
				selectedModel = r.Classifier.SelectBestModelFromList(userContent, matchResult.CandidateModels)

				// Record metrics
				metrics.RecordSimilarityRouting(matchResult, selectedModel)
			} else {
				// Fallback to the ModernBERT classifier
				if err != nil {
					observability.Warnf("Similarity matching failed: %v, falling back to classifier", err)
				} else {
					observability.Infof("Low confidence match (similarity=%.3f, gap=%.3f), falling back to classifier",
						matchResult.Similarity, matchResult.Gap)
				}
				selectedModel = r.classifyAndSelectBestModel(userContent)
				metrics.RecordSimilarityFallback()
			}
		} else {
			// Use the existing category-only routing
			selectedModel = r.classifyAndSelectBestModel(userContent)
		}

		matchedModel = selectedModel
	}

	// ... rest of the code ...
}
```
4. Router Initialization (`pkg/extproc/router.go`)

```go
func NewOpenAIRouter(configPath string) (*OpenAIRouter, error) {
	// ... existing initialization code ...

	// Initialize the similarity matcher if custom categories are enabled
	var similarityMatcher *similarity.SimilarityMatcher
	if cfg.CustomCategories.Enabled {
		similarityMatcher = similarity.NewSimilarityMatcher(cfg.CustomCategories)

		// Pre-compute and cache category embeddings
		if err := similarityMatcher.InitializeEmbeddings(); err != nil {
			return nil, fmt.Errorf("failed to initialize similarity matcher: %w", err)
		}

		observability.Infof("Initialized similarity matcher with %d custom categories",
			len(cfg.CustomCategories.Categories))
	}

	router := &OpenAIRouter{
		Config:            cfg,
		Classifier:        classifier,
		SimilarityMatcher: similarityMatcher,
		// ... other fields ...
	}

	return router, nil
}
```
8. Observability & Monitoring
Add comprehensive logging and metrics:
- Log similarity scores for top-K categories
- Track custom category hit rate vs. fallback rate
- Monitor embedding generation latency
- Record ambiguous matches (small gap between top-2) for tuning
- Expose Prometheus metrics for routing decisions
Additional context
Design Reference: The detailed design is documented in customize-category.docx, which includes:
- Evaluation of `all-MiniLM-L12-v2` embedding model suitability
- Alternative high-performance embedding models (BGE, M3E, multilingual models)
- Best practices for category organization and maintenance
- Similarity matching strategies and threshold tuning
- Integration patterns and fallback mechanisms
Existing Infrastructure:
- BERT similarity model is already initialized in the system (`bert_model` config)
- Embedding generation functions exist in candle-binding (`get_text_embedding`, `calculate_similarity`)
- Tool selection already uses similarity-based matching (`src/semantic-router/pkg/tools/tools.go`)
- Config-based routing is the primary mechanism (SemanticRoute API examples exist but are not implemented)
Benefits:
- Zero-Shot Capability: Add new categories without retraining models
- Rapid Iteration: Update category definitions in minutes vs. hours/days for retraining
- Domain Flexibility: Support unlimited custom categories for diverse use cases
- Cost Efficiency: Eliminate expensive model retraining cycles
- Backward Compatible: Existing ModernBERT routing remains functional as fallback
Technical Considerations:
- Embedding model should be lightweight for low latency (all-MiniLM-L12-v2: 384-dim, ~120MB)
- Category embeddings should be pre-computed and cached in memory
- Cosine similarity calculation is efficient for <100 categories (brute-force acceptable)
- For >1000 categories, consider ANN libraries (FAISS, HNSWlib) for optimization
- Ensure embedding space consistency (same model for queries and categories)
Use Cases:
- Travel & Tourism: Route travel queries to specialized travel models without retraining
- Legal Consulting: Add legal domain categories for law firms
- Healthcare: Create medical specialty categories (cardiology, neurology, etc.)
- E-commerce: Product recommendation, customer support, order tracking categories
- Education: Course-specific categories for educational platforms
Performance Characteristics:
| Operation | Latency | Scalability |
|---|---|---|
| Embedding Generation | ~10-15ms | O(1) per query |
| Similarity Calculation | ~5-10ms | O(n), n = number of categories |
| Total (Similarity Mode) | ~15-25ms | Efficient for <100 categories |
| ModernBERT Classification | ~20-30ms | O(1) model inference |
Comparison with Fine-Tuned Classifier:
```mermaid
graph LR
    subgraph "Fine-Tuned Classifier"
        A1[Query] --> B1[ModernBERT Inference]
        B1 --> C1[14 Fixed Categories]
        C1 --> D1[Model Selection]
        D1 --> E1[Result]
        style B1 fill:#f8bbd0
        style C1 fill:#ffcdd2
    end

    subgraph "Similarity-Based Matcher"
        A2[Query] --> B2[BERT Embedding]
        B2 --> C2[Cosine Similarity]
        C2 --> D2[Unlimited Custom Categories]
        D2 --> E2[Model Selection]
        E2 --> F2[Result]
        style B2 fill:#ffe0b2
        style D2 fill:#c8e6c9
    end
```
Testing Requirements:
- Unit Tests:
  - Embedding generation and normalization
  - Cosine similarity calculation
  - Threshold validation (above/below threshold)
  - Gap validation (confident/ambiguous matches)
  - Fallback logic when no confident match
  - Category embedding cache management
- Integration Tests:
  - Similarity matching with a confident match
  - Similarity matching with low confidence → fallback to classifier
  - Similarity matching with a small gap → fallback to classifier
  - Disabled custom categories → use classifier only
  - Multiple categories with different similarity scores
- E2E Tests:
  - Real queries with expected category matches
  - Performance benchmarks (latency, throughput)
  - Stress tests with many custom categories
- Configuration Tests:
  - Valid custom category configuration
  - Invalid configuration handling (missing fields, invalid thresholds)
  - Dynamic category updates
Implementation Phases:
Phase 1: Core Implementation (2-3 weeks)
- Implement SimilarityMatcher with embedding generation
- Add configuration structures
- Implement cosine similarity calculation
- Add unit tests for similarity matching
Phase 2: Integration (1-2 weeks)
- Integrate with OpenAIRouter
- Implement fallback logic to ModernBERT classifier
- Add integration tests
Phase 3: Observability (1 week)
- Add metrics for similarity matching
- Add detailed logging
- Add performance benchmarks
- Add threshold tuning guidelines
Observability & Metrics:
```
# New metrics for similarity-based routing
similarity_routing_matches_total{category_id, confident}  # Counter of category matches
similarity_routing_score_histogram{category_id}           # Histogram of similarity scores
similarity_routing_gap_histogram{category_id}             # Histogram of gaps (top-1 vs top-2)
similarity_routing_duration_seconds{stage}                # Histogram of latency by stage
similarity_routing_fallback_total{reason}                 # Counter of fallbacks (low_score, small_gap, error)
similarity_embedding_cache_size                           # Gauge of cached embeddings
```
Example Logging:
```
[INFO] Custom categories enabled with 5 categories
[INFO] Initializing embeddings for category: travel-planning
[INFO] Initializing embeddings for category: legal-consulting
[INFO] Initializing embeddings for category: tech-support
[INFO] Initializing embeddings for category: healthcare-advice
[INFO] Initializing embeddings for category: financial-planning
[INFO] Similarity matcher initialized successfully

[INFO] Query: "Recommend a 3-day itinerary for Paris with museums"
[INFO] Generated query embedding (384 dimensions)
[INFO] Similarity scores:
[INFO]   - travel-planning: 0.88 (gap: 0.46)
[INFO]   - legal-consulting: 0.42
[INFO]   - tech-support: 0.35
[INFO]   - healthcare-advice: 0.28
[INFO]   - financial-planning: 0.25
[INFO] Confident match: travel-planning (similarity: 0.88 >= 0.75, gap: 0.46 >= 0.05)
[INFO] Candidate models: [travel-expert, general-assistant]
[INFO] Selected model: travel-expert (score: 0.90)
[INFO] Similarity routing latency: 18ms
```
Future Enhancements:
- Support multiple embedding models (BGE for Chinese, MPNet for higher accuracy)
- Implement category embedding versioning for model upgrades
- Add A/B testing framework for similarity threshold tuning
- Support hierarchical categories with parent-child relationships
- Integrate with vector databases (Milvus, Qdrant) for large-scale deployments (>1000 categories)
- Add category similarity visualization for debugging
- Support category aliases and synonyms
Related Files:
- `src/semantic-router/pkg/config/config.go` - Configuration structures
- `src/semantic-router/pkg/extproc/request_handler.go` - Request routing logic
- `src/semantic-router/pkg/utils/similarity/matcher.go` - Similarity matcher (NEW)
- `candle-binding/semantic-router.go` - BERT embedding functions
- `config/config.yaml` - Configuration example
Related Issues:
- #312: Support Similarity-Based Custom Category Routing for Dynamic Model Selection (this issue)
/area core
/milestone v0.1
/priority P0