1+ import asyncio
12import json
23from typing import Any , Dict , List , Optional , Union
34
5+ from langcache import APIError , APIErrorResponse
46from langcache import LangCache as LangCacheSDK
57from langcache .models import CacheEntryScope , CacheEntryScopeTypedDict
68
79from redisvl .extensions .cache .llm .base import BaseLLMCache
10+ from redisvl .extensions .cache .llm .schema import LangCacheOptions
811from redisvl .query .filter import FilterExpression
9- from redisvl .utils .utils import current_timestamp
12+ from redisvl .utils .utils import (
13+ current_timestamp ,
14+ denorm_cosine_distance ,
15+ norm_cosine_distance ,
16+ )
1017
1118Scope = Optional [Union [CacheEntryScope , CacheEntryScopeTypedDict ]]
1219
1320
21+ class CacheNotFound (RuntimeError ):
22+ """The specified LangCache cache was not found."""
23+
24+
1425class LangCache (BaseLLMCache ):
1526 """Redis LangCache Service: API for managing a Redis LangCache"""
1627
1728 def __init__ (
1829 self ,
19- redis_client = None ,
2030 name : str = "llmcache" ,
31+ cache_id : Optional [str ] = None ,
2132 distance_threshold : float = 0.1 ,
2233 ttl : Optional [int ] = None ,
23- redis_url : str = "redis://localhost:6379" ,
24- connection_kwargs : Optional [Dict [str , Any ]] = None ,
2534 overwrite : bool = False ,
35+ create_if_missing : bool = False ,
2636 entry_scope : Scope = None ,
37+ redis_url : str = "redis://localhost:6379" ,
38+ sdk_options : Optional [LangCacheOptions ] = None ,
2739 ** kwargs ,
2840 ):
2941 """Initialize a LangCache client.
3042
43+ TODO: What's the difference between the name and ID of a LangCache cache?
44+ You can't get a cache by name, only ID...
45+
3146 Args:
32- redis_client: A Redis client instance .
33- name: Name of the cache.
47+ name: Name of the cache index to use in LangCache .
48+ cache_id: ID of an existing cache to use .
3449 distance_threshold: Threshold for semantic similarity (0.0 to 1.0).
3550 ttl: Time-to-live for cache entries in seconds.
36- redis_url: URL for Redis connection if no client is provided.
37- connection_kwargs: Additional Redis connection parameters.
3851 overwrite: Whether to overwrite an existing cache with the same name.
3952 entry_scope: Optional scope for cache entries.
53+ redis_url: URL for the Redis instance.
54+ sdk_options: Optional configuration options for the LangCache SDK.
4055 """
41- if connection_kwargs is None :
42- connection_kwargs = {}
43-
44- super ().__init__ (
45- name = name ,
46- ttl = ttl ,
47- redis_client = redis_client ,
48- redis_url = redis_url ,
49- connection_kwargs = connection_kwargs ,
50- )
56+ super ().__init__ (name = name , ttl = ttl , ** kwargs )
5157
5258 self ._name = name
53- self ._redis_client = redis_client
54- self ._redis_url = redis_url
5559 self ._distance_threshold = distance_threshold
5660 self ._ttl = ttl
57- self ._cache_id = name
5861 self ._entry_scope = entry_scope
59- # Initialize LangCache SDK client
60- self ._api = LangCacheSDK (server_url = redis_url , client = redis_client )
61-
62- # Create cache if it doesn't exist or if overwrite is True
63- try :
64- existing_cache = self ._api .cache .get (cache_id = self ._cache_id )
65- if not existing_cache and overwrite :
66- self ._api .cache .create (
67- index_name = self ._name ,
68- redis_urls = [self ._redis_url ],
69- )
70- except Exception :
71- # If the cache doesn't exist, create it
72- if overwrite :
73- self ._api .cache .create (
74- index_name = self ._name ,
75- redis_urls = [self ._redis_url ],
76- )
62+
63+ if sdk_options is None :
64+ sdk_options = LangCacheOptions ()
65+
66+ options_dict = sdk_options .model_dump (exclude_none = True )
67+ self ._api = LangCacheSDK (** options_dict )
68+
69+ # Try to find the cache if given an ID. If not found and
70+ # create_if_missing is False, raise an error. Otherwise, try to create a
71+ # new cache.
72+ if cache_id :
73+ try :
74+ self ._api .cache .get (cache_id = cache_id )
75+ except (APIError , APIErrorResponse ):
76+ if not create_if_missing :
77+ raise CacheNotFound (f"LangCache cache with ID { cache_id } not found" )
78+ # We can't pass the cache ID to the create method, so reset it
79+ cache_id = None
80+ if not cache_id :
81+ self ._cache_id = self ._api .cache .create (
82+ index_name = self ._name ,
83+ redis_urls = [redis_url ],
84+ overwrite_if_exists = overwrite ,
85+ ).cache_id
86+ assert cache_id
87+ self ._cache_id = cache_id
7788
7889 @property
7990 def distance_threshold (self ) -> float :
@@ -91,7 +102,7 @@ def set_threshold(self, distance_threshold: float) -> None:
91102 """
92103 if not 0 <= float (distance_threshold ) <= 2 :
93104 raise ValueError ("Distance threshold must be between 0 and 2" )
94- self ._distance_threshold = float (distance_threshold )
105+ self ._distance_threshold = float (norm_cosine_distance ( distance_threshold ) )
95106
96107 @property
97108 def ttl (self ) -> Optional [int ]:
@@ -238,7 +249,7 @@ def check(
238249 "entry_id" : entry .id ,
239250 "prompt" : entry .prompt ,
240251 "response" : entry .response ,
241- "vector_distance" : entry .similarity ,
252+ "vector_distance" : denorm_cosine_distance ( entry .similarity ) ,
242253 }
243254
244255 # Add metadata if available
@@ -334,7 +345,7 @@ async def acheck(
334345 "entry_id" : entry .id ,
335346 "prompt" : entry .prompt ,
336347 "response" : entry .response ,
337- "vector_distance" : entry .similarity ,
348+ "vector_distance" : denorm_cosine_distance ( entry .similarity ) ,
338349 }
339350
340351 # Add metadata if available
@@ -408,6 +419,7 @@ def store(
408419 scope = entry_scope if entry_scope is not None else self ._entry_scope
409420
410421 # Convert TTL from seconds to milliseconds for LangCache SDK
422+ ttl = ttl or self ._ttl
411423 ttl_millis = ttl * 1000 if ttl is not None else None
412424
413425 # Store the entry
@@ -576,7 +588,7 @@ async def aupdate(self, key: str, **kwargs) -> None:
576588 )
577589
578590 def disconnect (self ) -> None :
579- """Disconnect from Redis ."""
591+ """Disconnect from LangCache ."""
580592 if (
581593 hasattr (self ._api .sdk_configuration , "client" )
582594 and self ._api .sdk_configuration .client
@@ -585,11 +597,6 @@ def disconnect(self) -> None:
585597
586598 async def adisconnect (self ) -> None :
587599 """Async disconnect from Redis."""
588- if (
589- hasattr (self ._api .sdk_configuration , "client" )
590- and self ._api .sdk_configuration .client
591- ):
592- self ._api .sdk_configuration .client .close ()
593600 if (
594601 hasattr (self ._api .sdk_configuration , "async_client" )
595602 and self ._api .sdk_configuration .async_client
0 commit comments