Support Similarity-Based Custom Category Routing for Dynamic Model Selection #312

@Xunzhuo

Is your feature request related to a problem? Please describe.

The current semantic router only supports routing based on fine-tuned category models (ModernBERT) with a fixed set of 14 predefined categories (business, law, psychology, biology, chemistry, history, other, health, economics, math, physics, computer science, philosophy, engineering). This approach has significant limitations:

  1. Limited Category Coverage: Users cannot define custom categories beyond the 14 pre-trained ones, restricting flexibility for domain-specific use cases (e.g., travel, legal consulting, sports news)
  2. Dataset Dependency: Adding new categories requires retraining the ModernBERT classifier with labeled datasets, which is time-consuming and resource-intensive
  3. Inflexibility: Organizations with unique routing requirements cannot adapt the system to their specific needs without model retraining
  4. Scalability Issues: As business requirements evolve, the fixed category set becomes a bottleneck for rapid iteration
  5. No Zero-Shot Capability: Cannot add new categories on-the-fly without model retraining

Real-World Scenario:
A travel company wants to route travel-related queries to specialized travel models:

  • Query: "Recommend a 3-day itinerary for Paris with museums and cafes"
  • Expected: Match custom category "travel-planning" → Route to [travel-expert, general-assistant]
  • Current: Classify as "other" (no specific category) → Route to general models

This results in:

  • Poor routing accuracy (the matched category is too generic)
  • No way to leverage domain-specific models
  • Expensive model retraining just to add a "travel" category

Describe the solution you'd like

Implement a similarity-based custom category routing mechanism that allows users to dynamically define categories using natural language descriptions and examples, without requiring model retraining.

Core Concept: Define custom categories with natural language descriptions and examples, then use BERT embeddings to match queries semantically:

custom_categories:
  enabled: true
  similarity_threshold: 0.75
  categories:
    - id: "travel-planning"
      description: "Travel planning, destination recommendations, itineraries"
      examples:
        - "Recommend a 3-day itinerary for Paris"
        - "Best time to visit Japan"
      candidate_models:
        - "travel-expert"
        - "general-assistant"

Key Features:

  1. Zero-Shot Capability: Add new categories without retraining models
  2. Semantic Matching: Use BERT embeddings for intelligent category matching
  3. Confidence Validation: Threshold and gap checks to ensure confident matches
  4. Fallback Support: Automatically fall back to the fine-tuned classifier when needed
  5. Backward Compatible: Existing ModernBERT routing continues to work

Architecture Overview

High-Level Flow:

flowchart TD
    A[User Query] --> B{custom_categories.enabled?}
    B -->|No| C[ModernBERT Classifier]
    B -->|Yes| D[Generate Query Embedding]

    D --> E[Calculate Similarity with Categories]
    E --> F{Top-1 Score >= Threshold?}

    F -->|No| C
    F -->|Yes| G{Gap >= Gap Threshold?}

    G -->|No| C
    G -->|Yes| H[Match Custom Category]

    H --> I[Get Candidate Models]
    C --> J[Get Category Models]

    I --> K[Select Best Model]
    J --> K

    K --> L[Return Selected Model]

    style A fill:#e1f5ff
    style L fill:#c8e6c9
    style D fill:#ffe0b2
    style H fill:#c8e6c9
    style C fill:#f8bbd0

Detailed Similarity Matching Process:

sequenceDiagram
    participant Query
    participant SM as SimilarityMatcher
    participant Cache as EmbeddingCache
    participant BERT as BERT Model
    participant Validator

    Query->>SM: "Recommend Paris itinerary"
    SM->>BERT: Generate query embedding
    BERT-->>SM: [0.12, -0.45, 0.78, ...] (384-dim)

    SM->>Cache: Get category embeddings
    Cache-->>SM: Pre-computed embeddings

    SM->>SM: Calculate cosine similarity
    Note over SM: travel-planning: 0.88<br/>legal-consulting: 0.42<br/>tech-support: 0.35

    SM->>Validator: Validate top match
    Validator->>Validator: Check threshold: 0.88 >= 0.75 ✓
    Validator->>Validator: Check gap: 0.88 - 0.42 = 0.46 >= 0.05 ✓
    Validator-->>SM: Confident match ✓

    SM-->>Query: Category: travel-planning<br/>Models: [travel-expert, general-assistant]

1. Dual-Mode Routing Architecture

Extend the existing config-based routing (not the SemanticRoute API, which is not yet implemented) to support two routing modes:

  • Fine-tuned Mode (existing): Uses the ModernBERT classifier for the 14 predefined categories
  • Similarity Mode (new): Uses embedding-based semantic similarity for user-defined custom categories

2. Custom Category Definition

Allow users to define custom categories in the configuration with:

custom_categories:
  enabled: true
  similarity_threshold: 0.75  # Minimum cosine similarity for matching
  gap_threshold: 0.05         # Minimum gap between top-1 and top-2 to avoid ambiguity
  
  categories:
    - id: "travel"
      name: "Travel & Tourism"
      description: "Queries related to travel planning, destination recommendations, visa requirements, and tourism information"
      examples:
        - "Recommend a 3-day itinerary for Paris"
        - "What documents do I need for a Schengen visa?"
        - "Best time to visit Japan for cherry blossoms"
      model_scores:
        - model: openai/gpt-oss-20b
          score: 0.8
          use_reasoning: false
    
    - id: "legal_consulting"
      name: "Legal Consultation"
      description: "Questions about laws, regulations, and legal procedures"
      examples:
        - "What are the latest changes in civil law?"
        - "How to apply for legal aid?"
      model_scores:
        - model: openai/gpt-oss-20b
          score: 0.9
          use_reasoning: true

3. Embedding-Based Matching Pipeline

Implement a semantic similarity matching system:

  1. Initialization Phase:

    • Load embedding model (default: sentence-transformers/all-MiniLM-L12-v2 from config bert_model)
    • Generate and cache normalized embeddings for all custom category descriptions and examples
    • Store category vectors in memory for fast retrieval
  2. Query Processing Phase:

    • Generate embedding for incoming user query
    • Calculate cosine similarity between query embedding and all custom category embeddings
    • Rank categories by similarity score
    • Apply threshold and gap validation to ensure confident matching
  3. Decision Logic:

    IF custom_categories.enabled AND custom categories exist:
        query_embedding = embedding_model.encode(query)
        similarities = calculate_cosine_similarity(query_embedding, all_category_embeddings)
        top1_category, top1_score = get_top_match(similarities)
        top2_score = get_second_match(similarities)
        
        IF top1_score >= similarity_threshold AND (top1_score - top2_score) >= gap_threshold:
            RETURN route_to_custom_category(top1_category)
        ELSE:
            FALLBACK to ModernBERT classification
    ELSE:
        USE ModernBERT classification (existing behavior)
    

4. Fallback Mechanism

Ensure robust fallback to the existing ModernBERT classifier when:

  • No custom categories are defined
  • Similarity score is below threshold
  • Top-2 categories have ambiguous scores (gap too small)
  • Embedding model fails to initialize

This guarantees backward compatibility and system reliability.

5. Model Selection Integration

Once a category (custom or predefined) is determined, use the existing model selection logic:

  • Retrieve model_scores from the matched category
  • Select the best model based on score and TTFT (Time To First Token)
  • Apply reasoning mode if configured (use_reasoning: true)

6. Configuration Refactoring

Refactor the existing config structure to support both modes:

# Existing bert_model config (used for similarity-based routing)
bert_model:
  model_id: sentence-transformers/all-MiniLM-L12-v2
  threshold: 0.6
  use_cpu: true

# Existing classifier config (used for fine-tuned routing)
classifier:
  category_model:
    model_id: "models/category_classifier_modernbert-base_model"
    use_modernbert: true
    threshold: 0.6
    use_cpu: true
    category_mapping_path: "models/category_classifier_modernbert-base_model/category_mapping.json"

# NEW: Custom categories with similarity-based routing
custom_categories:
  enabled: true
  similarity_threshold: 0.75
  gap_threshold: 0.05
  categories:
    - id: "custom_category_1"
      name: "Category Name"
      description: "Detailed description"
      examples: ["example1", "example2"]
      model_scores:
        - model: model_name
          score: 0.8
          use_reasoning: false

# Existing predefined categories (used when falling back to ModernBERT)
categories:
  - name: business
    model_scores: [...]
  - name: math
    model_scores: [...]
  # ... other 12 categories

7. Implementation Components

1. SimilarityMatcher (pkg/utils/similarity/matcher.go - NEW)

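package similarity

// Standard-library imports used below. The project-specific imports for the
// candle binding and observability packages are omitted here.
import (
    "fmt"
    "math"
    "sort"
)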

type SimilarityMatcher struct {
    categories          []CustomCategory
    categoryEmbeddings  map[string][]float32  // Pre-computed embeddings
    similarityThreshold float32
    gapThreshold        float32
}

type CustomCategory struct {
    ID              string
    Name            string
    Description     string
    Examples        []string
    CandidateModels []string
}

type SimilarityMatchResult struct {
    CategoryID   string
    CategoryName string
    Similarity   float32
    Gap          float32
    Confident    bool
    CandidateModels []string
}

// NewSimilarityMatcher creates a new similarity matcher
func NewSimilarityMatcher(config CustomCategoriesConfig) *SimilarityMatcher {
    return &SimilarityMatcher{
        categories:          config.Categories,
        categoryEmbeddings:  make(map[string][]float32),
        similarityThreshold: config.SimilarityThreshold,
        gapThreshold:        config.GapThreshold,
    }
}

// InitializeEmbeddings pre-computes and caches category embeddings
func (m *SimilarityMatcher) InitializeEmbeddings() error {
    for _, category := range m.categories {
        // Combine description and examples for better matching
        text := category.Description
        for _, example := range category.Examples {
            text += " " + example
        }

        // Generate embedding using BERT model
        embedding, err := candle.GetTextEmbedding(text)
        if err != nil {
            return fmt.Errorf("failed to generate embedding for category %s: %w", category.ID, err)
        }

        // Normalize embedding for cosine similarity
        normalized := normalizeVector(embedding)
        m.categoryEmbeddings[category.ID] = normalized

        observability.Infof("Initialized embedding for category: %s (dim: %d)", category.ID, len(normalized))
    }

    return nil
}

// MatchQuery matches the query against all custom categories
func (m *SimilarityMatcher) MatchQuery(query string) (*SimilarityMatchResult, error) {
    // Generate query embedding
    queryEmbedding, err := candle.GetTextEmbedding(query)
    if err != nil {
        return nil, fmt.Errorf("failed to generate query embedding: %w", err)
    }

    // Normalize query embedding
    normalizedQuery := normalizeVector(queryEmbedding)

    // Calculate similarity with all categories
    similarities := make([]struct {
        categoryID string
        score      float32
    }, 0, len(m.categories))

    for _, category := range m.categories {
        categoryEmb := m.categoryEmbeddings[category.ID]
        similarity := cosineSimilarity(normalizedQuery, categoryEmb)
        similarities = append(similarities, struct {
            categoryID string
            score      float32
        }{category.ID, similarity})
    }

    // Sort by similarity (descending)
    sort.Slice(similarities, func(i, j int) bool {
        return similarities[i].score > similarities[j].score
    })

    // Get top-2 matches
    if len(similarities) == 0 {
        return nil, fmt.Errorf("no categories available")
    }

    top1 := similarities[0]
    top2Score := float32(0.0)
    if len(similarities) > 1 {
        top2Score = similarities[1].score
    }

    gap := top1.score - top2Score

    // Validate confidence
    confident := top1.score >= m.similarityThreshold && gap >= m.gapThreshold

    // Find category details
    var matchedCategory *CustomCategory
    for i := range m.categories {
        if m.categories[i].ID == top1.categoryID {
            matchedCategory = &m.categories[i]
            break
        }
    }

    if matchedCategory == nil {
        return nil, fmt.Errorf("category not found: %s", top1.categoryID)
    }

    return &SimilarityMatchResult{
        CategoryID:      matchedCategory.ID,
        CategoryName:    matchedCategory.Name,
        Similarity:      top1.score,
        Gap:             gap,
        Confident:       confident,
        CandidateModels: matchedCategory.CandidateModels,
    }, nil
}

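// Helper functions

// normalizeVector scales vec to unit length. A zero vector is returned as all
// zeros so that the division below cannot produce NaN values.
func normalizeVector(vec []float32) []float32 {
    var norm float32
    for _, v := range vec {
        norm += v * v
    }
    norm = float32(math.Sqrt(float64(norm)))

    normalized := make([]float32, len(vec))
    if norm == 0 {
        return normalized
    }
    for i, v := range vec {
        normalized[i] = v / norm
    }
    return normalized
}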

func cosineSimilarity(a, b []float32) float32 {
    if len(a) != len(b) {
        return 0.0
    }

    var dotProduct float32
    for i := range a {
        dotProduct += a[i] * b[i]
    }

    return dotProduct  // Already normalized, so dot product = cosine similarity
}

2. Configuration Extension (pkg/config/config.go)

type RouterConfig struct {
    // ... existing fields ...
    CustomCategories CustomCategoriesConfig `yaml:"custom_categories"`
}

type CustomCategoriesConfig struct {
    Enabled             bool             `yaml:"enabled"`
    SimilarityThreshold float32          `yaml:"similarity_threshold"`
    GapThreshold        float32          `yaml:"gap_threshold"`
    Categories          []CustomCategory `yaml:"categories"`
}

type CustomCategory struct {
    ID              string   `yaml:"id"`
    Name            string   `yaml:"name"`
    Description     string   `yaml:"description"`
    Examples        []string `yaml:"examples"`
    CandidateModels []string `yaml:"candidate_models"`
}
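
A validation step at load time would cover the "invalid configuration handling" tests listed later in this issue. The sketch below is one possible shape, not an existing API: the `Validate` method and the accepted ranges (threshold in (0, 1], gap in [0, 1)) are assumptions. Duplicate IDs are rejected because the matcher keys its embedding cache by category ID.

```go
package main

import "fmt"

// Trimmed copies of the config structs above, limited to the validated fields.
type CustomCategory struct {
	ID          string
	Description string
	Examples    []string
}

type CustomCategoriesConfig struct {
	Enabled             bool
	SimilarityThreshold float32
	GapThreshold        float32
	Categories          []CustomCategory
}

// Validate reports the first problem found in an enabled configuration.
// The accepted threshold ranges are assumptions chosen for this sketch.
func (c CustomCategoriesConfig) Validate() error {
	if !c.Enabled {
		return nil // a disabled block is never used, so nothing to check
	}
	if c.SimilarityThreshold <= 0 || c.SimilarityThreshold > 1 {
		return fmt.Errorf("similarity_threshold must be in (0, 1], got %v", c.SimilarityThreshold)
	}
	if c.GapThreshold < 0 || c.GapThreshold >= 1 {
		return fmt.Errorf("gap_threshold must be in [0, 1), got %v", c.GapThreshold)
	}
	seen := make(map[string]bool, len(c.Categories))
	for i, cat := range c.Categories {
		if cat.ID == "" {
			return fmt.Errorf("categories[%d]: id is required", i)
		}
		if seen[cat.ID] {
			return fmt.Errorf("duplicate category id %q", cat.ID)
		}
		seen[cat.ID] = true
		if cat.Description == "" && len(cat.Examples) == 0 {
			return fmt.Errorf("category %q needs a description or at least one example", cat.ID)
		}
	}
	return nil
}

func main() {
	cfg := CustomCategoriesConfig{
		Enabled:             true,
		SimilarityThreshold: 0.75,
		GapThreshold:        0.05,
		Categories: []CustomCategory{
			{ID: "travel", Description: "Travel planning and tourism"},
			{ID: "travel", Description: "Duplicate entry"},
		},
	}
	fmt.Println(cfg.Validate())
}
```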

3. Router Integration (pkg/extproc/request_handler.go)

func (r *OpenAIRouter) handleModelRouting(...) (*ext_proc.ProcessingResponse, error) {
    // ... existing code ...

    if originalModel == "auto" {
        var selectedModel string

        // Try similarity-based custom category matching first
        if r.Config.CustomCategories.Enabled && r.SimilarityMatcher != nil {
            matchResult, err := r.SimilarityMatcher.MatchQuery(userContent)

            if err == nil && matchResult.Confident {
                // Confident match with custom category
                observability.Infof("Similarity match: category=%s, similarity=%.3f, gap=%.3f",
                    matchResult.CategoryName, matchResult.Similarity, matchResult.Gap)

                // Select best model from candidates
                selectedModel = r.Classifier.SelectBestModelFromList(userContent, matchResult.CandidateModels)

                // Record metrics
                metrics.RecordSimilarityRouting(matchResult, selectedModel)
            } else {
                // Fallback to ModernBERT classifier
                if err != nil {
                    observability.Warnf("Similarity matching failed: %v, falling back to classifier", err)
                } else {
                    observability.Infof("Low confidence match (similarity=%.3f, gap=%.3f), falling back to classifier",
                        matchResult.Similarity, matchResult.Gap)
                }

                selectedModel = r.classifyAndSelectBestModel(userContent)
                metrics.RecordSimilarityFallback()
            }
        } else {
            // Use existing category-only routing
            selectedModel = r.classifyAndSelectBestModel(userContent)
        }

        matchedModel = selectedModel
    }

    // ... rest of the code ...
}

4. Router Initialization (pkg/extproc/router.go)

func NewOpenAIRouter(configPath string) (*OpenAIRouter, error) {
    // ... existing initialization code ...

    // Initialize similarity matcher if custom categories are enabled
    var similarityMatcher *similarity.SimilarityMatcher
    if cfg.CustomCategories.Enabled {
        similarityMatcher = similarity.NewSimilarityMatcher(cfg.CustomCategories)

        // Pre-compute and cache category embeddings
        if err := similarityMatcher.InitializeEmbeddings(); err != nil {
            return nil, fmt.Errorf("failed to initialize similarity matcher: %w", err)
        }

        observability.Infof("Initialized similarity matcher with %d custom categories",
            len(cfg.CustomCategories.Categories))
    }

    router := &OpenAIRouter{
        Config:            cfg,
        Classifier:        classifier,
        SimilarityMatcher: similarityMatcher,
        // ... other fields ...
    }

    return router, nil
}

8. Observability & Monitoring

Add comprehensive logging and metrics:

  • Log similarity scores for top-K categories
  • Track custom category hit rate vs. fallback rate
  • Monitor embedding generation latency
  • Record ambiguous matches (small gap between top-2) for tuning
  • Expose Prometheus metrics for routing decisions

Additional context

Design Reference: The detailed design is documented in customize-category.docx, which includes:

  • Evaluation of all-MiniLM-L12-v2 embedding model suitability
  • Alternative high-performance embedding models (BGE, M3E, multilingual models)
  • Best practices for category organization and maintenance
  • Similarity matching strategies and threshold tuning
  • Integration patterns and fallback mechanisms

Existing Infrastructure:

  • BERT similarity model is already initialized in the system (bert_model config)
  • Embedding generation functions exist in candle-binding (get_text_embedding, calculate_similarity)
  • Tool selection already uses similarity-based matching (src/semantic-router/pkg/tools/tools.go)
  • Config-based routing is the primary mechanism (SemanticRoute API examples exist but are not implemented)

Benefits:

  1. Zero-Shot Capability: Add new categories without retraining models
  2. Rapid Iteration: Update category definitions in minutes vs. hours/days for retraining
  3. Domain Flexibility: Support unlimited custom categories for diverse use cases
  4. Cost Efficiency: Eliminate expensive model retraining cycles
  5. Backward Compatible: Existing ModernBERT routing remains functional as fallback

Technical Considerations:

  • Embedding model should be lightweight for low latency (all-MiniLM-L12-v2: 384-dim, ~120MB)
  • Category embeddings should be pre-computed and cached in memory
  • Cosine similarity calculation is efficient for <100 categories (brute-force acceptable)
  • For >1000 categories, consider ANN libraries (FAISS, HNSWlib) for optimization
  • Ensure embedding space consistency (same model for queries and categories)
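
Since the confidence check only needs the best and second-best scores, the brute-force scan does not have to sort all categories. The following sketch finds both in a single pass; `topTwo` is a hypothetical helper, and it returns 0 as the runner-up when fewer than two scores exist, which matches the default used in `MatchQuery` above.

```go
package main

import "fmt"

// topTwo scans scores once and returns the index of the best score together
// with the best and second-best values. bestIdx is -1 for an empty slice.
func topTwo(scores []float32) (bestIdx int, best, second float32) {
	bestIdx = -1
	haveSecond := false
	for i, s := range scores {
		switch {
		case bestIdx == -1:
			bestIdx, best = i, s
		case s > best:
			// The previous best becomes the runner-up.
			second, haveSecond = best, true
			bestIdx, best = i, s
		case !haveSecond || s > second:
			second, haveSecond = s, true
		}
	}
	return bestIdx, best, second
}

func main() {
	// Scores from the example log, in category order rather than sorted.
	scores := []float32{0.42, 0.88, 0.35, 0.28, 0.25}
	idx, best, second := topTwo(scores)
	fmt.Printf("%d %.2f %.2f %.2f\n", idx, best, second, best-second) // prints: 1 0.88 0.42 0.46
}
```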

Use Cases:

  1. Travel & Tourism: Route travel queries to specialized travel models without retraining
  2. Legal Consulting: Add legal domain categories for law firms
  3. Healthcare: Create medical specialty categories (cardiology, neurology, etc.)
  4. E-commerce: Product recommendation, customer support, order tracking categories
  5. Education: Course-specific categories for educational platforms

Performance Characteristics:

| Operation | Latency | Scalability |
| --- | --- | --- |
| Embedding Generation | ~10-15ms | O(1) per query |
| Similarity Calculation | ~5-10ms | O(n) where n=categories |
| Total (Similarity Mode) | ~15-25ms | Efficient for <100 categories |
| ModernBERT Classification | ~20-30ms | O(1) model inference |

Comparison with Fine-Tuned Classifier:

graph LR
    subgraph "Fine-Tuned Classifier"
        A1[Query] --> B1[ModernBERT Inference]
        B1 --> C1[14 Fixed Categories]
        C1 --> D1[Model Selection]
        D1 --> E1[Result]

        style B1 fill:#f8bbd0
        style C1 fill:#ffcdd2
    end

    subgraph "Similarity-Based Matcher"
        A2[Query] --> B2[BERT Embedding]
        B2 --> C2[Cosine Similarity]
        C2 --> D2[Unlimited Custom Categories]
        D2 --> E2[Model Selection]
        E2 --> F2[Result]

        style B2 fill:#ffe0b2
        style D2 fill:#c8e6c9
    end

Testing Requirements:

  1. Unit Tests:

    • Embedding generation and normalization
    • Cosine similarity calculation
    • Threshold validation (above/below threshold)
    • Gap validation (confident/ambiguous matches)
    • Fallback logic when no confident match
    • Category embedding cache management
  2. Integration Tests:

    • Similarity matching with confident match
    • Similarity matching with low confidence → fallback to classifier
    • Similarity matching with small gap → fallback to classifier
    • Disabled custom categories → use classifier only
    • Multiple categories with different similarity scores
  3. E2E Tests:

    • Real queries with expected category matches
    • Performance benchmarks (latency, throughput)
    • Stress tests with many custom categories
  4. Configuration Tests:

    • Valid custom category configuration
    • Invalid configuration handling (missing fields, invalid thresholds)
    • Dynamic category updates

Implementation Phases:

Phase 1: Core Implementation (2-3 weeks)

  • Implement SimilarityMatcher with embedding generation
  • Add configuration structures
  • Implement cosine similarity calculation
  • Add unit tests for similarity matching

Phase 2: Integration (1-2 weeks)

  • Integrate with OpenAIRouter
  • Implement fallback logic to ModernBERT classifier
  • Add integration tests

Phase 3: Observability (1 week)

  • Add metrics for similarity matching
  • Add detailed logging
  • Add performance benchmarks
  • Add threshold tuning guidelines

Observability & Metrics:

# New metrics for similarity-based routing
similarity_routing_matches_total{category_id, confident}  # Counter of category matches
similarity_routing_score_histogram{category_id}           # Histogram of similarity scores
similarity_routing_gap_histogram{category_id}             # Histogram of gaps (top-1 vs top-2)
similarity_routing_duration_seconds{stage}                # Histogram of latency by stage
similarity_routing_fallback_total{reason}                 # Counter of fallbacks (low_score, small_gap, error)
similarity_embedding_cache_size                           # Gauge of cached embeddings

Example Logging:

[INFO] Custom categories enabled with 5 categories
[INFO] Initializing embeddings for category: travel-planning
[INFO] Initializing embeddings for category: legal-consulting
[INFO] Initializing embeddings for category: tech-support
[INFO] Initializing embeddings for category: healthcare-advice
[INFO] Initializing embeddings for category: financial-planning
[INFO] Similarity matcher initialized successfully
[INFO]
[INFO] Query: "Recommend a 3-day itinerary for Paris with museums"
[INFO] Generated query embedding (384 dimensions)
[INFO] Similarity scores:
[INFO]   - travel-planning: 0.88 (gap: 0.46)
[INFO]   - legal-consulting: 0.42
[INFO]   - tech-support: 0.35
[INFO]   - healthcare-advice: 0.28
[INFO]   - financial-planning: 0.25
[INFO] Confident match: travel-planning (similarity: 0.88 >= 0.75, gap: 0.46 >= 0.05)
[INFO] Candidate models: [travel-expert, general-assistant]
[INFO] Selected model: travel-expert (score: 0.90)
[INFO] Similarity routing latency: 18ms

Future Enhancements:

  • Support multiple embedding models (BGE for Chinese, MPNet for higher accuracy)
  • Implement category embedding versioning for model upgrades
  • Add A/B testing framework for similarity threshold tuning
  • Support hierarchical categories with parent-child relationships
  • Integrate with vector databases (Milvus, Qdrant) for large-scale deployments (>1000 categories)
  • Add category similarity visualization for debugging
  • Support category aliases and synonyms

Related Files:

  • src/semantic-router/pkg/config/config.go - Configuration structures
  • src/semantic-router/pkg/extproc/request_handler.go - Request routing logic
  • src/semantic-router/pkg/utils/similarity/matcher.go - Similarity matcher (NEW)
  • candle-binding/semantic-router.go - BERT embedding functions
  • config/config.yaml - Configuration example

Related Issues:

/area core
/milestone v0.1
/priority P0
