
Commit 0941db0

feat: add pluggable interface for semantic cache backends
- Create CacheBackend interface with pluggable architecture
- Refactor existing in-memory cache to implement new interface
- Add cache factory pattern for backend selection
- Support configurable similarity thresholds and TTL
- Add comprehensive cache metrics and observability

Addresses #94

Signed-off-by: Huamin Chen <[email protected]>
1 parent 464ed6c commit 0941db0

12 files changed (+1553, −854 lines)

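Only two of the twelve changed files appear in this excerpt; the new CacheBackend interface and the Milvus backend live in files not shown here. Judging from the methods removed from the old SemanticCache (see cache.go below), the interface plausibly looks something like the following sketch; the method set is an inference, not the commit's verbatim definition:

```go
// CacheBackend is a sketch of the pluggable backend interface named in the
// commit message, inferred from the old SemanticCache's method set.
// The interface actually introduced by this commit may differ.
type CacheBackend interface {
	// IsEnabled reports whether caching is active.
	IsEnabled() bool
	// AddPendingRequest records a request whose response has not arrived yet.
	AddPendingRequest(model string, query string, requestBody []byte) (string, error)
	// UpdateWithResponse attaches a response to a pending request.
	UpdateWithResponse(query string, responseBody []byte) error
	// AddEntry stores a complete request/response pair.
	AddEntry(model string, query string, requestBody, responseBody []byte) error
	// FindSimilar returns a cached response whose query embedding is
	// at least similarity_threshold close to the given query.
	FindSimilar(model string, query string) ([]byte, bool, error)
}
```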

config/config.yaml

Lines changed: 17 additions & 1 deletion
```diff
@@ -4,9 +4,25 @@ bert_model:
   use_cpu: true
 semantic_cache:
   enabled: true
+  backend_type: "memory" # Options: "memory" or "milvus"
   similarity_threshold: 0.8
-  max_entries: 1000
+  max_entries: 1000 # Only applies to memory backend
   ttl_seconds: 3600
+
+  # For production environments, use Milvus for scalable caching:
+  # backend_type: "milvus"
+  # backend_config_path: "config/cache/milvus.yaml"
+
+  # Development/Testing: Use in-memory cache (current configuration)
+  # - Fast startup and no external dependencies
+  # - Limited to single instance scaling
+  # - Data lost on restart
+
+  # Production: Use Milvus vector database
+  # - Horizontally scalable and persistent
+  # - Supports distributed deployments
+  # - Requires Milvus cluster setup
+  # - To enable: uncomment the lines above and install Milvus dependencies
 tools:
   enabled: true # Set to true to enable automatic tool selection
   top_k: 3 # Number of most relevant tools to select
```
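The commit message also names a cache factory for backend selection, which sits in a file outside this excerpt. The sketch below shows how the backend_type setting above could plausibly drive that selection; the CacheConfig fields and both constructor names are hypothetical, and the function is assumed to live in the cache package where fmt is already imported:

```go
// CacheConfig mirrors the semantic_cache section of config.yaml.
// Field names here are illustrative assumptions, not the commit's types.
type CacheConfig struct {
	Enabled             bool    `yaml:"enabled"`
	BackendType         string  `yaml:"backend_type"`
	SimilarityThreshold float32 `yaml:"similarity_threshold"`
	MaxEntries          int     `yaml:"max_entries"`
	TTLSeconds          int     `yaml:"ttl_seconds"`
	BackendConfigPath   string  `yaml:"backend_config_path"`
}

// NewCacheBackend selects a backend from backend_type. The constructor
// names below are hypothetical stand-ins for the commit's actual API.
func NewCacheBackend(cfg CacheConfig) (CacheBackend, error) {
	switch cfg.BackendType {
	case "", "memory":
		// Default: in-memory cache bounded by max_entries;
		// fast startup, no external dependencies, data lost on restart.
		return NewInMemoryCacheBackend(cfg), nil
	case "milvus":
		// Milvus-backed cache configured from a separate YAML file;
		// persistent and horizontally scalable.
		return NewMilvusCacheBackend(cfg.BackendConfigPath)
	default:
		return nil, fmt.Errorf("unknown semantic cache backend: %q", cfg.BackendType)
	}
}
```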

src/semantic-router/pkg/cache/cache.go

Lines changed: 6 additions & 282 deletions
```diff
@@ -3,315 +3,39 @@ package cache

 import (
 	"encoding/json"
 	"fmt"
-	"log"
-	"sort"
-	"sync"
-	"time"
-
-	candle_binding "github.com/vllm-project/semantic-router/candle-binding"
 )

-// CacheEntry represents a cached request-response pair
-type CacheEntry struct {
-	RequestBody  []byte
-	ResponseBody []byte
-	Model        string
-	Query        string
-	Embedding    []float32
-	Timestamp    time.Time
-}
-
-// SemanticCache implements a semantic cache using BERT embeddings
-type SemanticCache struct {
-	entries             []CacheEntry
-	mu                  sync.RWMutex
-	similarityThreshold float32
-	maxEntries          int
-	ttlSeconds          int
-	enabled             bool
-}
-
-// SemanticCacheOptions holds options for creating a new semantic cache
-type SemanticCacheOptions struct {
-	SimilarityThreshold float32
-	MaxEntries          int
-	TTLSeconds          int
-	Enabled             bool
-}
-
-// NewSemanticCache creates a new semantic cache with the given options
-func NewSemanticCache(options SemanticCacheOptions) *SemanticCache {
-	return &SemanticCache{
-		entries:             []CacheEntry{},
-		similarityThreshold: options.SimilarityThreshold,
-		maxEntries:          options.MaxEntries,
-		ttlSeconds:          options.TTLSeconds,
-		enabled:             options.Enabled,
-	}
-}
-
-// IsEnabled returns whether the cache is enabled
-func (c *SemanticCache) IsEnabled() bool {
-	return c.enabled
-}
-
-// AddPendingRequest adds a pending request to the cache (without response yet)
-func (c *SemanticCache) AddPendingRequest(model string, query string, requestBody []byte) (string, error) {
-	if !c.enabled {
-		return query, nil
-	}
-
-	// Generate embedding for the query
-	embedding, err := candle_binding.GetEmbedding(query, 512)
-	if err != nil {
-		return "", fmt.Errorf("failed to generate embedding: %w", err)
-	}
-
-	c.mu.Lock()
-	defer c.mu.Unlock()
-
-	// Cleanup expired entries if TTL is set
-	c.cleanupExpiredEntries()
-
-	// Create a new entry with the pending request
-	entry := CacheEntry{
-		RequestBody: requestBody,
-		Model:       model,
-		Query:       query,
-		Embedding:   embedding,
-		Timestamp:   time.Now(),
-	}
-
-	c.entries = append(c.entries, entry)
-	// log.Printf("Added pending cache entry for: %s", query)
-
-	// Enforce max entries limit if set
-	if c.maxEntries > 0 && len(c.entries) > c.maxEntries {
-		// Sort by timestamp (oldest first)
-		sort.Slice(c.entries, func(i, j int) bool {
-			return c.entries[i].Timestamp.Before(c.entries[j].Timestamp)
-		})
-		// Remove oldest entries
-		c.entries = c.entries[len(c.entries)-c.maxEntries:]
-		log.Printf("Trimmed cache to %d entries", c.maxEntries)
-	}
-
-	return query, nil
-}
-
-// UpdateWithResponse updates a pending request with its response
-func (c *SemanticCache) UpdateWithResponse(query string, responseBody []byte) error {
-	if !c.enabled {
-		return nil
-	}
-
-	c.mu.Lock()
-	defer c.mu.Unlock()
-
-	// Cleanup expired entries while we have the write lock
-	c.cleanupExpiredEntries()
-
-	// Find the pending request by query
-	for i, entry := range c.entries {
-		if entry.Query == query && entry.ResponseBody == nil {
-			// Update with response
-			c.entries[i].ResponseBody = responseBody
-			c.entries[i].Timestamp = time.Now()
-			// log.Printf("Cache entry updated: %s", query)
-			return nil
-		}
-	}
-
-	return fmt.Errorf("no pending request found for query: %s", query)
-}
-
-// AddEntry adds a complete entry to the cache
-func (c *SemanticCache) AddEntry(model string, query string, requestBody, responseBody []byte) error {
-	if !c.enabled {
-		return nil
-	}
-
-	// Generate embedding for the query
-	embedding, err := candle_binding.GetEmbedding(query, 512)
-	if err != nil {
-		return fmt.Errorf("failed to generate embedding: %w", err)
-	}
-
-	entry := CacheEntry{
-		RequestBody:  requestBody,
-		ResponseBody: responseBody,
-		Model:        model,
-		Query:        query,
-		Embedding:    embedding,
-		Timestamp:    time.Now(),
-	}
-
-	c.mu.Lock()
-	defer c.mu.Unlock()
-
-	// Cleanup expired entries
-	c.cleanupExpiredEntries()
-
-	c.entries = append(c.entries, entry)
-	log.Printf("Added cache entry: %s", query)
-
-	// Enforce max entries limit
-	if c.maxEntries > 0 && len(c.entries) > c.maxEntries {
-		// Sort by timestamp (oldest first)
-		sort.Slice(c.entries, func(i, j int) bool {
-			return c.entries[i].Timestamp.Before(c.entries[j].Timestamp)
-		})
-		// Remove oldest entries
-		c.entries = c.entries[len(c.entries)-c.maxEntries:]
-	}
-
-	return nil
-}
-
-// FindSimilar looks for a similar request in the cache
-func (c *SemanticCache) FindSimilar(model string, query string) ([]byte, bool, error) {
-	if !c.enabled {
-		return nil, false, nil
-	}
-
-	// Generate embedding for the query
-	queryEmbedding, err := candle_binding.GetEmbedding(query, 512)
-	if err != nil {
-		return nil, false, fmt.Errorf("failed to generate embedding: %w", err)
-	}
-
-	c.mu.RLock()
-	defer c.mu.RUnlock()
-
-	// Cleanup expired entries
-	c.cleanupExpiredEntriesReadOnly()
-
-	type SimilarityResult struct {
-		Entry      CacheEntry
-		Similarity float32
-	}
-
-	// Only compare with entries that have responses
-	results := make([]SimilarityResult, 0, len(c.entries))
-	for _, entry := range c.entries {
-		if entry.ResponseBody == nil {
-			continue // Skip entries without responses
-		}
-
-		// Only compare with entries with the same model
-		if entry.Model != model {
-			continue
-		}
-
-		// Calculate similarity
-		var dotProduct float32
-		for i := 0; i < len(queryEmbedding) && i < len(entry.Embedding); i++ {
-			dotProduct += queryEmbedding[i] * entry.Embedding[i]
-		}
-
-		results = append(results, SimilarityResult{
-			Entry:      entry,
-			Similarity: dotProduct,
-		})
-	}
-
-	// No results found
-	if len(results) == 0 {
-		return nil, false, nil
-	}
-
-	// Sort by similarity (highest first)
-	sort.Slice(results, func(i, j int) bool {
-		return results[i].Similarity > results[j].Similarity
-	})
-
-	// Check if the best match exceeds the threshold
-	if results[0].Similarity >= c.similarityThreshold {
-		log.Printf("Cache hit: similarity=%.4f, threshold=%.4f",
-			results[0].Similarity, c.similarityThreshold)
-		return results[0].Entry.ResponseBody, true, nil
-	}
-
-	log.Printf("Cache miss: best similarity=%.4f, threshold=%.4f",
-		results[0].Similarity, c.similarityThreshold)
-	return nil, false, nil
-}
-
-// cleanupExpiredEntries removes expired entries from the cache
-// Assumes the caller holds a write lock
-func (c *SemanticCache) cleanupExpiredEntries() {
-	if c.ttlSeconds <= 0 {
-		return
-	}
-
-	now := time.Now()
-	validEntries := make([]CacheEntry, 0, len(c.entries))
-
-	for _, entry := range c.entries {
-		// Keep entries that haven't expired
-		if now.Sub(entry.Timestamp).Seconds() < float64(c.ttlSeconds) {
-			validEntries = append(validEntries, entry)
-		}
-	}
-
-	if len(validEntries) < len(c.entries) {
-		log.Printf("Removed %d expired cache entries", len(c.entries)-len(validEntries))
-		c.entries = validEntries
-	}
-}
-
-// cleanupExpiredEntriesReadOnly checks for expired entries but doesn't modify the cache
-// Used during read operations where we only have a read lock
-func (c *SemanticCache) cleanupExpiredEntriesReadOnly() {
-	if c.ttlSeconds <= 0 {
-		return
-	}
-
-	now := time.Now()
-	expiredCount := 0
-
-	for _, entry := range c.entries {
-		if now.Sub(entry.Timestamp).Seconds() >= float64(c.ttlSeconds) {
-			expiredCount++
-		}
-	}
-
-	if expiredCount > 0 {
-		log.Printf("Found %d expired cache entries during read operation", expiredCount)
-	}
-}
-
-// ChatMessage represents a message in the OpenAI chat format
+// ChatMessage represents a message in the OpenAI chat format with role and content
 type ChatMessage struct {
 	Role    string `json:"role"`
 	Content string `json:"content"`
 }

-// OpenAIRequest represents an OpenAI API request
+// OpenAIRequest represents the structure of an OpenAI API request
 type OpenAIRequest struct {
 	Model    string        `json:"model"`
 	Messages []ChatMessage `json:"messages"`
 }

-// ExtractQueryFromOpenAIRequest extracts the user query from an OpenAI request
+// ExtractQueryFromOpenAIRequest parses an OpenAI request and extracts the user query
 func ExtractQueryFromOpenAIRequest(requestBody []byte) (string, string, error) {
 	var req OpenAIRequest
 	if err := json.Unmarshal(requestBody, &req); err != nil {
 		return "", "", fmt.Errorf("invalid request body: %w", err)
 	}

-	// Extract user messages
+	// Find user messages in the conversation
 	var userMessages []string
 	for _, msg := range req.Messages {
 		if msg.Role == "user" {
 			userMessages = append(userMessages, msg.Content)
 		}
 	}

-	// Join all user messages
+	// Use the most recent user message as the query
 	query := ""
 	if len(userMessages) > 0 {
-		query = userMessages[len(userMessages)-1] // Use the last user message
+		query = userMessages[len(userMessages)-1]
 	}

 	return req.Model, query, nil
```
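After the refactor, cache.go keeps only the OpenAI request-parsing helpers. ExtractQueryFromOpenAIRequest returns the request's model and the most recent user message (not a join of all user messages, despite the old comment). A small usage sketch follows; the import path is a guess from the repository layout, and the model name is just an example:

```go
package main

import (
	"fmt"

	// Import path assumed from the repository layout shown above.
	"github.com/vllm-project/semantic-router/src/semantic-router/pkg/cache"
)

func main() {
	body := []byte(`{
		"model": "gpt-4o",
		"messages": [
			{"role": "system", "content": "You are a helpful assistant."},
			{"role": "user", "content": "What is a semantic cache?"},
			{"role": "assistant", "content": "A cache keyed by embedding similarity."},
			{"role": "user", "content": "How do I enable the Milvus backend?"}
		]
	}`)

	model, query, err := cache.ExtractQueryFromOpenAIRequest(body)
	if err != nil {
		panic(err)
	}

	// Only the most recent user message is returned; it becomes the
	// lookup key for the semantic cache.
	fmt.Printf("model=%s query=%q\n", model, query)
	// Prints: model=gpt-4o query="How do I enable the Milvus backend?"
}
```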
