@@ -3,315 +3,39 @@ package cache
33import (
44 "encoding/json"
55 "fmt"
6- "log"
7- "sort"
8- "sync"
9- "time"
10-
11- candle_binding "github.com/vllm-project/semantic-router/candle-binding"
126)
137
// CacheEntry is a single cached request/response pair together with the
// query embedding used for similarity matching.
type CacheEntry struct {
	RequestBody  []byte
	ResponseBody []byte
	Model        string
	Query        string
	Embedding    []float32
	Timestamp    time.Time
}

// SemanticCache caches responses keyed by the semantic similarity of
// their queries, using BERT embeddings for the comparison.
type SemanticCache struct {
	entries             []CacheEntry
	mu                  sync.RWMutex
	similarityThreshold float32
	maxEntries          int
	ttlSeconds          int
	enabled             bool
}

// SemanticCacheOptions configures a SemanticCache.
type SemanticCacheOptions struct {
	SimilarityThreshold float32
	MaxEntries          int
	TTLSeconds          int
	Enabled             bool
}

// NewSemanticCache builds a semantic cache from the supplied options.
func NewSemanticCache(options SemanticCacheOptions) *SemanticCache {
	cache := &SemanticCache{
		similarityThreshold: options.SimilarityThreshold,
		maxEntries:          options.MaxEntries,
		ttlSeconds:          options.TTLSeconds,
		enabled:             options.Enabled,
	}
	cache.entries = []CacheEntry{}
	return cache
}

// IsEnabled reports whether caching is active.
func (c *SemanticCache) IsEnabled() bool {
	return c.enabled
}
57-
58- // AddPendingRequest adds a pending request to the cache (without response yet)
59- func (c * SemanticCache ) AddPendingRequest (model string , query string , requestBody []byte ) (string , error ) {
60- if ! c .enabled {
61- return query , nil
62- }
63-
64- // Generate embedding for the query
65- embedding , err := candle_binding .GetEmbedding (query , 512 )
66- if err != nil {
67- return "" , fmt .Errorf ("failed to generate embedding: %w" , err )
68- }
69-
70- c .mu .Lock ()
71- defer c .mu .Unlock ()
72-
73- // Cleanup expired entries if TTL is set
74- c .cleanupExpiredEntries ()
75-
76- // Create a new entry with the pending request
77- entry := CacheEntry {
78- RequestBody : requestBody ,
79- Model : model ,
80- Query : query ,
81- Embedding : embedding ,
82- Timestamp : time .Now (),
83- }
84-
85- c .entries = append (c .entries , entry )
86- // log.Printf("Added pending cache entry for: %s", query)
87-
88- // Enforce max entries limit if set
89- if c .maxEntries > 0 && len (c .entries ) > c .maxEntries {
90- // Sort by timestamp (oldest first)
91- sort .Slice (c .entries , func (i , j int ) bool {
92- return c .entries [i ].Timestamp .Before (c .entries [j ].Timestamp )
93- })
94- // Remove oldest entries
95- c .entries = c .entries [len (c .entries )- c .maxEntries :]
96- log .Printf ("Trimmed cache to %d entries" , c .maxEntries )
97- }
98-
99- return query , nil
100- }
101-
102- // UpdateWithResponse updates a pending request with its response
103- func (c * SemanticCache ) UpdateWithResponse (query string , responseBody []byte ) error {
104- if ! c .enabled {
105- return nil
106- }
107-
108- c .mu .Lock ()
109- defer c .mu .Unlock ()
110-
111- // Cleanup expired entries while we have the write lock
112- c .cleanupExpiredEntries ()
113-
114- // Find the pending request by query
115- for i , entry := range c .entries {
116- if entry .Query == query && entry .ResponseBody == nil {
117- // Update with response
118- c .entries [i ].ResponseBody = responseBody
119- c .entries [i ].Timestamp = time .Now ()
120- // log.Printf("Cache entry updated: %s", query)
121- return nil
122- }
123- }
124-
125- return fmt .Errorf ("no pending request found for query: %s" , query )
126- }
127-
128- // AddEntry adds a complete entry to the cache
129- func (c * SemanticCache ) AddEntry (model string , query string , requestBody , responseBody []byte ) error {
130- if ! c .enabled {
131- return nil
132- }
133-
134- // Generate embedding for the query
135- embedding , err := candle_binding .GetEmbedding (query , 512 )
136- if err != nil {
137- return fmt .Errorf ("failed to generate embedding: %w" , err )
138- }
139-
140- entry := CacheEntry {
141- RequestBody : requestBody ,
142- ResponseBody : responseBody ,
143- Model : model ,
144- Query : query ,
145- Embedding : embedding ,
146- Timestamp : time .Now (),
147- }
148-
149- c .mu .Lock ()
150- defer c .mu .Unlock ()
151-
152- // Cleanup expired entries
153- c .cleanupExpiredEntries ()
154-
155- c .entries = append (c .entries , entry )
156- log .Printf ("Added cache entry: %s" , query )
157-
158- // Enforce max entries limit
159- if c .maxEntries > 0 && len (c .entries ) > c .maxEntries {
160- // Sort by timestamp (oldest first)
161- sort .Slice (c .entries , func (i , j int ) bool {
162- return c .entries [i ].Timestamp .Before (c .entries [j ].Timestamp )
163- })
164- // Remove oldest entries
165- c .entries = c .entries [len (c .entries )- c .maxEntries :]
166- }
167-
168- return nil
169- }
170-
171- // FindSimilar looks for a similar request in the cache
172- func (c * SemanticCache ) FindSimilar (model string , query string ) ([]byte , bool , error ) {
173- if ! c .enabled {
174- return nil , false , nil
175- }
176-
177- // Generate embedding for the query
178- queryEmbedding , err := candle_binding .GetEmbedding (query , 512 )
179- if err != nil {
180- return nil , false , fmt .Errorf ("failed to generate embedding: %w" , err )
181- }
182-
183- c .mu .RLock ()
184- defer c .mu .RUnlock ()
185-
186- // Cleanup expired entries
187- c .cleanupExpiredEntriesReadOnly ()
188-
189- type SimilarityResult struct {
190- Entry CacheEntry
191- Similarity float32
192- }
193-
194- // Only compare with entries that have responses
195- results := make ([]SimilarityResult , 0 , len (c .entries ))
196- for _ , entry := range c .entries {
197- if entry .ResponseBody == nil {
198- continue // Skip entries without responses
199- }
200-
201- // Only compare with entries with the same model
202- if entry .Model != model {
203- continue
204- }
205-
206- // Calculate similarity
207- var dotProduct float32
208- for i := 0 ; i < len (queryEmbedding ) && i < len (entry .Embedding ); i ++ {
209- dotProduct += queryEmbedding [i ] * entry .Embedding [i ]
210- }
211-
212- results = append (results , SimilarityResult {
213- Entry : entry ,
214- Similarity : dotProduct ,
215- })
216- }
217-
218- // No results found
219- if len (results ) == 0 {
220- return nil , false , nil
221- }
222-
223- // Sort by similarity (highest first)
224- sort .Slice (results , func (i , j int ) bool {
225- return results [i ].Similarity > results [j ].Similarity
226- })
227-
228- // Check if the best match exceeds the threshold
229- if results [0 ].Similarity >= c .similarityThreshold {
230- log .Printf ("Cache hit: similarity=%.4f, threshold=%.4f" ,
231- results [0 ].Similarity , c .similarityThreshold )
232- return results [0 ].Entry .ResponseBody , true , nil
233- }
234-
235- log .Printf ("Cache miss: best similarity=%.4f, threshold=%.4f" ,
236- results [0 ].Similarity , c .similarityThreshold )
237- return nil , false , nil
238- }
239-
240- // cleanupExpiredEntries removes expired entries from the cache
241- // Assumes the caller holds a write lock
242- func (c * SemanticCache ) cleanupExpiredEntries () {
243- if c .ttlSeconds <= 0 {
244- return
245- }
246-
247- now := time .Now ()
248- validEntries := make ([]CacheEntry , 0 , len (c .entries ))
249-
250- for _ , entry := range c .entries {
251- // Keep entries that haven't expired
252- if now .Sub (entry .Timestamp ).Seconds () < float64 (c .ttlSeconds ) {
253- validEntries = append (validEntries , entry )
254- }
255- }
256-
257- if len (validEntries ) < len (c .entries ) {
258- log .Printf ("Removed %d expired cache entries" , len (c .entries )- len (validEntries ))
259- c .entries = validEntries
260- }
261- }
262-
263- // cleanupExpiredEntriesReadOnly checks for expired entries but doesn't modify the cache
264- // Used during read operations where we only have a read lock
265- func (c * SemanticCache ) cleanupExpiredEntriesReadOnly () {
266- if c .ttlSeconds <= 0 {
267- return
268- }
269-
270- now := time .Now ()
271- expiredCount := 0
272-
273- for _ , entry := range c .entries {
274- if now .Sub (entry .Timestamp ).Seconds () >= float64 (c .ttlSeconds ) {
275- expiredCount ++
276- }
277- }
278-
279- if expiredCount > 0 {
280- log .Printf ("Found %d expired cache entries during read operation" , expiredCount )
281- }
282- }
283-
// ChatMessage represents a message in the OpenAI chat format with role and content.
type ChatMessage struct {
	Role    string `json:"role"`
	Content string `json:"content"`
}

// OpenAIRequest represents the structure of an OpenAI API request.
type OpenAIRequest struct {
	Model    string        `json:"model"`
	Messages []ChatMessage `json:"messages"`
}

// ExtractQueryFromOpenAIRequest parses an OpenAI request body and returns the
// model name and the most recent user message (the empty string when the
// request contains no user messages). It returns an error when the body is
// not valid JSON for this schema.
func ExtractQueryFromOpenAIRequest(requestBody []byte) (string, string, error) {
	var req OpenAIRequest
	if err := json.Unmarshal(requestBody, &req); err != nil {
		return "", "", fmt.Errorf("invalid request body: %w", err)
	}

	// Track only the latest user message instead of accumulating all of
	// them into a slice — the earlier ones were never used.
	query := ""
	for _, msg := range req.Messages {
		if msg.Role == "user" {
			query = msg.Content
		}
	}

	return req.Model, query, nil
}
0 commit comments