Skip to content

Commit e9e4fad

Browse files
committed
Checkpoint
1 parent 5021a42 commit e9e4fad

File tree

4 files changed

+442
-83
lines changed

4 files changed

+442
-83
lines changed

redisvl/extensions/llmcache/langcache_api.py

Lines changed: 146 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
11
import json
2-
from typing import Any, Dict, List, Optional
2+
from typing import Any, Dict, List, Optional, Union
33

44
from langcache import LangCache as LangCacheSDK
5+
from langcache.models import CacheEntryScope, CacheEntryScopeTypedDict
56

67
from redisvl.extensions.llmcache.base import BaseLLMCache
78
from redisvl.query.filter import FilterExpression
8-
from redisvl.utils.utils import current_timestamp, hashify
9+
from redisvl.utils.utils import current_timestamp
10+
11+
Scope = Optional[Union[CacheEntryScope, CacheEntryScopeTypedDict]]
912

1013

1114
class LangCache(BaseLLMCache):
@@ -20,6 +23,7 @@ def __init__(
2023
redis_url: str = "redis://localhost:6379",
2124
connection_kwargs: Dict[str, Any] = {},
2225
overwrite: bool = False,
26+
entry_scope: Scope = None,
2327
**kwargs,
2428
):
2529
"""Initialize a LangCache client.
@@ -32,6 +36,7 @@ def __init__(
3236
redis_url: URL for Redis connection if no client is provided.
3337
connection_kwargs: Additional Redis connection parameters.
3438
overwrite: Whether to overwrite an existing cache with the same name.
39+
entry_scope: Optional scope for cache entries.
3540
"""
3641
# Initialize the base class
3742
super().__init__(ttl)
@@ -43,7 +48,7 @@ def __init__(
4348
self._distance_threshold = distance_threshold
4449
self._ttl = ttl
4550
self._cache_id = name
46-
51+
self._entry_scope = entry_scope
4752
# Initialize LangCache SDK client
4853
self._api = LangCacheSDK(server_url=redis_url, client=redis_client)
4954

@@ -101,12 +106,28 @@ def set_ttl(self, ttl: Optional[int] = None) -> None:
101106

102107
def clear(self) -> None:
    """Remove all entries from the cache, leaving its configuration intact."""
    # Use the instance-level entry scope when one was configured; otherwise
    # an empty CacheEntryScope, which addresses every entry in the cache.
    scope = self._entry_scope if self._entry_scope is not None else CacheEntryScope()
    self._api.entries.delete_all(cache_id=self._cache_id, attributes={}, scope=scope)
105118

106119
async def aclear(self) -> None:
    """Asynchronously remove all entries from the cache."""
    # Mirrors clear(): an unset instance scope falls back to an empty
    # CacheEntryScope so that every entry is targeted.
    scope = self._entry_scope if self._entry_scope is not None else CacheEntryScope()
    await self._api.entries.delete_all_async(
        cache_id=self._cache_id, attributes={}, scope=scope
    )
110131

111132
def delete(self) -> None:
112133
"""Delete the cache and all its entries."""
@@ -115,8 +136,17 @@ def delete(self) -> None:
115136

116137
async def adelete(self) -> None:
    """Asynchronously delete the cache and all of its entries."""
    # Purge the entries first, then remove the cache object itself.
    scope = self._entry_scope if self._entry_scope is not None else CacheEntryScope()
    await self._api.entries.delete_all_async(
        cache_id=self._cache_id, attributes={}, scope=scope
    )
    await self._api.cache.delete_async(cache_id=self._cache_id)
120150

121151
def drop(
122152
self, ids: Optional[List[str]] = None, keys: Optional[List[str]] = None
@@ -134,14 +164,13 @@ def drop(
134164
async def adrop(
    self, ids: Optional[List[str]] = None, keys: Optional[List[str]] = None
) -> None:
    """Asynchronously remove specific entries from the cache.

    Args:
        ids: List of entry IDs to remove.
        keys: List of entry keys to remove. For LangCache a key and an
            entry ID are the same value (search hits report the same
            ``result.id`` under both ``"key"`` and ``"entry_id"``), so
            keys are deleted the same way as IDs.
    """
    # Bug fix: the previous implementation silently ignored `keys`.
    # Combine both parameters and delete each identifier via the SDK.
    entry_ids: List[str] = [*(ids or []), *(keys or [])]
    for entry_id in entry_ids:
        await self._api.entries.delete_async(
            entry_id=entry_id, cache_id=self._cache_id
        )
145174

146175
def check(
147176
self,
@@ -151,6 +180,7 @@ def check(
151180
return_fields: Optional[List[str]] = None,
152181
filter_expression: Optional[FilterExpression] = None,
153182
distance_threshold: Optional[float] = None,
183+
entry_scope: Optional[Dict[str, Any]] = None,
154184
) -> List[Dict[str, Any]]:
155185
"""Check the cache for semantically similar entries.
156186
@@ -161,7 +191,7 @@ def check(
161191
return_fields: Fields to include in the response.
162192
filter_expression: Optional filter for the search.
163193
distance_threshold: Override the default distance threshold.
164-
194+
entry_scope: Optional scope for cache entries.
165195
Returns:
166196
List of matching cache entries.
167197
@@ -172,18 +202,22 @@ def check(
172202
if not any([prompt, vector]):
173203
raise ValueError("Either prompt or vector must be provided")
174204

205+
_scope = entry_scope or self._entry_scope
206+
175207
if return_fields and not isinstance(return_fields, list):
176208
raise TypeError("return_fields must be a list")
177209

178210
# Use provided threshold or default
179211
threshold = distance_threshold or self._distance_threshold
180212

181-
# Search the cache - note we don't use scope since FilterExpression conversion would be complex
182-
# and require proper implementation for CacheEntryScope format
213+
# Search the cache - note we don't use scope since FilterExpression conversion
214+
# would be complex (impossible?)
183215
results = self._api.entries.search(
184216
cache_id=self._cache_id,
185217
prompt=prompt or "", # Ensure prompt is never None
186218
similarity_threshold=threshold,
219+
# Type-cast is necessary to handle the scope type correctly
220+
scope=_scope, # type: ignore[arg-type]
187221
)
188222

189223
# If we need to limit results and have more than requested, slice the list
@@ -235,18 +269,61 @@ async def acheck(
235269
return_fields: Optional[List[str]] = None,
236270
filter_expression: Optional[FilterExpression] = None,
237271
distance_threshold: Optional[float] = None,
272+
entry_scope: Scope = None,
238273
) -> List[Dict[str, Any]]:
239274
"""Asynchronously check the cache for semantically similar entries."""
240-
# Currently using synchronous implementation since langcache doesn't have async API
241-
return self.check(
242-
prompt,
243-
vector,
244-
num_results,
245-
return_fields,
246-
filter_expression,
247-
distance_threshold,
275+
# Validate inputs
276+
if not any([prompt, vector]):
277+
raise ValueError("Either prompt or vector must be provided")
278+
if return_fields and not isinstance(return_fields, list):
279+
raise TypeError("return_fields must be a list")
280+
281+
# Determine scope to use
282+
_scope = entry_scope or self._entry_scope
283+
284+
# Determine threshold
285+
threshold = distance_threshold or self._distance_threshold
286+
287+
# Perform async search
288+
results = await self._api.entries.search_async(
289+
cache_id=self._cache_id,
290+
prompt=prompt or "",
291+
similarity_threshold=threshold,
292+
# Type-cast is necessary to handle the scope type correctly
293+
scope=_scope, # type: ignore[arg-type]
248294
)
249295

296+
# Limit results
297+
if num_results < len(results):
298+
results = results[:num_results]
299+
300+
# Format hits
301+
cache_hits: List[Dict[str, Any]] = []
302+
for result in results:
303+
hit = {
304+
"key": result.id,
305+
"entry_id": result.id,
306+
"prompt": result.prompt,
307+
"response": result.response,
308+
"vector_distance": result.similarity,
309+
}
310+
if hasattr(result, "metadata") and result.metadata:
311+
try:
312+
metadata_dict = {}
313+
if hasattr(result.metadata, "__dict__"):
314+
metadata_dict = {
315+
k: v
316+
for k, v in result.metadata.__dict__.items()
317+
if not k.startswith("_")
318+
}
319+
hit["metadata"] = metadata_dict
320+
except Exception:
321+
hit["metadata"] = {}
322+
if return_fields:
323+
hit = {k: v for k, v in hit.items() if k in return_fields or k == "key"}
324+
cache_hits.append(hit)
325+
return cache_hits
326+
250327
def store(
251328
self,
252329
prompt: str,
@@ -261,7 +338,7 @@ def store(
261338
Args:
262339
prompt: The prompt text.
263340
response: The response text.
264-
vector: Optional vector representation of the prompt.
341+
vector: Unused. LangCache manages vectorization internally.
265342
metadata: Optional metadata to store with the entry.
266343
filters: Optional filters to associate with the entry.
267344
ttl: Optional custom TTL for this entry.
@@ -312,84 +389,74 @@ async def astore(
312389
ttl: Optional[int] = None,
313390
) -> str:
314391
"""Asynchronously store a new entry in the cache."""
315-
# Currently using synchronous implementation since langcache doesn't have async API
316-
return self.store(prompt, response, vector, metadata, filters, ttl)
317-
318-
def update(self, key: str, **kwargs) -> None:
319-
"""Update an existing cache entry.
320-
321-
Args:
322-
key: The entry ID to update.
323-
**kwargs: Fields to update (prompt, response, metadata, etc.)
324-
"""
325-
# Find the entry to update
326-
existing_entries = self._api.entries.search(
327-
cache_id=self._cache_id,
328-
prompt="", # Required parameter but we're searching by ID
329-
attributes={"id": key}, # Search by ID as an attribute
330-
similarity_threshold=1.0, # We're not doing semantic search
331-
)
332-
333-
if not existing_entries:
334-
return
392+
# Validate metadata
393+
if metadata is not None and not isinstance(metadata, dict):
394+
raise ValueError("Metadata must be a dictionary")
335395

336-
existing_entry = existing_entries[0]
396+
# Create entry with optional TTL
397+
entry_ttl = ttl if ttl is not None else self._ttl
337398

338-
# Prepare updated values
339-
# CacheEntry objects are Pydantic models, access their attributes directly
340-
prompt = kwargs.get(
341-
"prompt", existing_entry.prompt if hasattr(existing_entry, "prompt") else ""
342-
)
343-
response = kwargs.get(
344-
"response",
345-
existing_entry.response if hasattr(existing_entry, "response") else "",
346-
)
399+
# Convert ttl to ttl_millis (milliseconds) if provided
400+
ttl_millis = entry_ttl * 1000 if entry_ttl is not None else None
347401

348-
# Prepare attributes for update
402+
# Process additional attributes from filters
349403
attributes = {}
350-
if "metadata" in kwargs:
404+
if filters:
405+
attributes.update(filters)
406+
407+
# Add metadata to attributes if provided
408+
if metadata:
351409
attributes["metadata"] = (
352-
json.dumps(kwargs["metadata"])
353-
if isinstance(kwargs["metadata"], dict)
354-
else kwargs["metadata"]
410+
json.dumps(metadata) if isinstance(metadata, dict) else metadata
355411
)
356412

357-
# Convert TTL to milliseconds if provided
358-
ttl = kwargs.get("ttl", None)
359-
ttl_millis = ttl * 1000 if ttl is not None else None
360-
361-
# Re-create the entry with updated values
362-
self._api.entries.create(
413+
# Store the entry and get the response
414+
create_response = await self._api.entries.create_async(
363415
cache_id=self._cache_id,
364416
prompt=prompt,
365417
response=response,
366418
attributes=attributes,
367419
ttl_millis=ttl_millis,
368420
)
369421

422+
# Return the entry ID from the response
423+
return create_response.entry_id
424+
425+
def update(self, key: str, **kwargs) -> None:
    """Update an entry in place.

    Args:
        key: ID of the entry that would be updated.
        **kwargs: Replacement field values (prompt, response, metadata, ...).

    Raises:
        NotImplementedError: Always; the LangCache SDK offers no in-place
            update operation.
    """
    raise NotImplementedError("LangCache SDK does not support update in place")
433+
370434
async def aupdate(self, key: str, **kwargs) -> None:
    """Asynchronously update an entry in place.

    Raises:
        NotImplementedError: Always; the LangCache SDK offers no in-place
            update operation.
    """
    raise NotImplementedError("LangCache SDK does not support update in place")
374437

375438
def disconnect(self) -> None:
    """Close the SDK's underlying sync HTTP client, if one exists."""
    client = self._api.sdk_configuration.client
    if client is not None:
        client.close()
380442

381443
async def adisconnect(self) -> None:
    """Asynchronously close the SDK's async and sync HTTP clients."""
    cfg = self._api.sdk_configuration
    if cfg.async_client is not None:
        await cfg.async_client.aclose()
    if cfg.client is not None:
        cfg.client.close()
384449

385450
def __enter__(self):
    # Delegate setup to the SDK's own context manager, then hand back
    # this cache wrapper (not the SDK object) to the caller.
    self._api.__enter__()
    return self
387453

388454
def __exit__(self, exc_type, exc_val, exc_tb):
    # Forward exception details to the SDK's teardown. Returns None, so
    # exceptions raised inside the `with` body always propagate.
    self._api.__exit__(exc_type, exc_val, exc_tb)
390456

391457
async def __aenter__(self):
    # Delegate async setup to the SDK, then yield this wrapper.
    await self._api.__aenter__()
    return self
393460

394461
async def __aexit__(self, exc_type, exc_val, exc_tb):
    # Forward exception details to the SDK's async teardown. Returns None,
    # so exceptions inside the `async with` body always propagate.
    await self._api.__aexit__(exc_type, exc_val, exc_tb)

0 commit comments

Comments
 (0)