diff --git a/website/docs/architecture/router-implementation.md b/website/docs/architecture/router-implementation.md index c5b5b9fd..dc2e4317 100644 --- a/website/docs/architecture/router-implementation.md +++ b/website/docs/architecture/router-implementation.md @@ -103,51 +103,86 @@ func (r *OpenAIRouter) makeRoutingDecision(classification *Classification) *Rout ```go type SemanticCache struct { - embeddings map[string][]float32 // Query embeddings - responses map[string]CachedResponse - similarity SimilarityCalculator - ttl time.Duration - maxEntries int - mutex sync.RWMutex + entries []CacheEntry + mu sync.RWMutex + similarityThreshold float32 + maxEntries int + ttlSeconds int + enabled bool } -type CachedResponse struct { - Response interface{} - Timestamp time.Time - Model string - Embeddings []float32 - HitCount int +type CacheEntry struct { + RequestBody []byte + ResponseBody []byte + Model string + Query string + Embedding []float32 + Timestamp time.Time } -// Cache lookup with semantic similarity -func (sc *SemanticCache) Get(query string) (interface{}, bool) { - sc.mutex.RLock() - defer sc.mutex.RUnlock() - - // Generate query embedding - queryEmbedding := sc.generateEmbedding(query) - - // Find most similar cached query - bestSimilarity := 0.0 - var bestMatch *CachedResponse - - for cachedQuery, embedding := range sc.embeddings { - similarity := sc.similarity.CosineSimilarity(queryEmbedding, embedding) - - if similarity > bestSimilarity && similarity > sc.similarityThreshold { - bestSimilarity = similarity - if response, exists := sc.responses[cachedQuery]; exists { - bestMatch = &response - } +// FindSimilar looks for a similar request in the cache +func (c *SemanticCache) FindSimilar(model string, query string) ([]byte, bool, error) { + if !c.enabled { + return nil, false, nil + } + + // Generate embedding for the query + queryEmbedding, err := candle_binding.GetEmbedding(query, 512) + if err != nil { + return nil, false, fmt.Errorf("failed to generate embedding: %w", err) + } + + c.mu.RLock() + defer c.mu.RUnlock() + + // Cleanup expired entries + c.cleanupExpiredEntriesReadOnly() + + type SimilarityResult struct { + Entry CacheEntry + Similarity float32 + } + + // Only compare with entries that have responses + results := make([]SimilarityResult, 0, len(c.entries)) + for _, entry := range c.entries { + if entry.ResponseBody == nil { + continue // Skip entries without responses } + + // Only compare with entries with the same model + if entry.Model != model { + continue + } + + // Calculate similarity using dot product + var dotProduct float32 + for i := 0; i < len(queryEmbedding) && i < len(entry.Embedding); i++ { + dotProduct += queryEmbedding[i] * entry.Embedding[i] + } + + results = append(results, SimilarityResult{ + Entry: entry, + Similarity: dotProduct, + }) } - - if bestMatch != nil && time.Since(bestMatch.Timestamp) < sc.ttl { - bestMatch.HitCount++ - return bestMatch.Response, true + + // No results found + if len(results) == 0 { + return nil, false, nil } - - return nil, false + + // Sort by similarity (highest first) + sort.Slice(results, func(i, j int) bool { + return results[i].Similarity > results[j].Similarity + }) + + // Check if the best match exceeds the threshold + if results[0].Similarity >= c.similarityThreshold { + return results[0].Entry.ResponseBody, true, nil + } + + return nil, false, nil } ```