Skip to content

Commit 9a50d15

Browse files
committed
Fix ID vs. name usage
1 parent 26b960d commit 9a50d15

File tree

4 files changed

+266
-73
lines changed

4 files changed

+266
-73
lines changed

redisvl/extensions/cache/llm/langcache_api.py

Lines changed: 55 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1,79 +1,90 @@
1+
import asyncio
12
import json
23
from typing import Any, Dict, List, Optional, Union
34

5+
from langcache import APIError, APIErrorResponse
46
from langcache import LangCache as LangCacheSDK
57
from langcache.models import CacheEntryScope, CacheEntryScopeTypedDict
68

79
from redisvl.extensions.cache.llm.base import BaseLLMCache
10+
from redisvl.extensions.cache.llm.schema import LangCacheOptions
811
from redisvl.query.filter import FilterExpression
9-
from redisvl.utils.utils import current_timestamp
12+
from redisvl.utils.utils import (
13+
current_timestamp,
14+
denorm_cosine_distance,
15+
norm_cosine_distance,
16+
)
1017

1118
Scope = Optional[Union[CacheEntryScope, CacheEntryScopeTypedDict]]
1219

1320

21+
class CacheNotFound(RuntimeError):
22+
"""The specified LangCache cache was not found."""
23+
24+
1425
class LangCache(BaseLLMCache):
1526
"""Redis LangCache Service: API for managing a Redis LangCache"""
1627

1728
def __init__(
1829
self,
19-
redis_client=None,
2030
name: str = "llmcache",
31+
cache_id: Optional[str] = None,
2132
distance_threshold: float = 0.1,
2233
ttl: Optional[int] = None,
23-
redis_url: str = "redis://localhost:6379",
24-
connection_kwargs: Optional[Dict[str, Any]] = None,
2534
overwrite: bool = False,
35+
create_if_missing: bool = False,
2636
entry_scope: Scope = None,
37+
redis_url: str = "redis://localhost:6379",
38+
sdk_options: Optional[LangCacheOptions] = None,
2739
**kwargs,
2840
):
2941
"""Initialize a LangCache client.
3042
43+
TODO: What's the difference between the name and ID of a LangCache cache?
44+
You can't get a cache by name, only ID...
45+
3146
Args:
32-
redis_client: A Redis client instance.
33-
name: Name of the cache.
47+
name: Name of the cache index to use in LangCache.
48+
cache_id: ID of an existing cache to use.
3449
distance_threshold: Threshold for semantic similarity (0.0 to 1.0).
3550
ttl: Time-to-live for cache entries in seconds.
36-
redis_url: URL for Redis connection if no client is provided.
37-
connection_kwargs: Additional Redis connection parameters.
3851
overwrite: Whether to overwrite an existing cache with the same name.
3952
entry_scope: Optional scope for cache entries.
53+
redis_url: URL for the Redis instance.
54+
sdk_options: Optional configuration options for the LangCache SDK.
4055
"""
41-
if connection_kwargs is None:
42-
connection_kwargs = {}
43-
44-
super().__init__(
45-
name=name,
46-
ttl=ttl,
47-
redis_client=redis_client,
48-
redis_url=redis_url,
49-
connection_kwargs=connection_kwargs,
50-
)
56+
super().__init__(name=name, ttl=ttl, **kwargs)
5157

5258
self._name = name
53-
self._redis_client = redis_client
54-
self._redis_url = redis_url
5559
self._distance_threshold = distance_threshold
5660
self._ttl = ttl
57-
self._cache_id = name
5861
self._entry_scope = entry_scope
59-
# Initialize LangCache SDK client
60-
self._api = LangCacheSDK(server_url=redis_url, client=redis_client)
61-
62-
# Create cache if it doesn't exist or if overwrite is True
63-
try:
64-
existing_cache = self._api.cache.get(cache_id=self._cache_id)
65-
if not existing_cache and overwrite:
66-
self._api.cache.create(
67-
index_name=self._name,
68-
redis_urls=[self._redis_url],
69-
)
70-
except Exception:
71-
# If the cache doesn't exist, create it
72-
if overwrite:
73-
self._api.cache.create(
74-
index_name=self._name,
75-
redis_urls=[self._redis_url],
76-
)
62+
63+
if sdk_options is None:
64+
sdk_options = LangCacheOptions()
65+
66+
options_dict = sdk_options.model_dump(exclude_none=True)
67+
self._api = LangCacheSDK(**options_dict)
68+
69+
# Try to find the cache if given an ID. If not found and
70+
# create_if_missing is False, raise an error. Otherwise, try to create a
71+
# new cache.
72+
if cache_id:
73+
try:
74+
self._api.cache.get(cache_id=cache_id)
75+
except (APIError, APIErrorResponse):
76+
if not create_if_missing:
77+
raise CacheNotFound(f"LangCache cache with ID {cache_id} not found")
78+
# We can't pass the cache ID to the create method, so reset it
79+
cache_id = None
80+
if not cache_id:
81+
self._cache_id = self._api.cache.create(
82+
index_name=self._name,
83+
redis_urls=[redis_url],
84+
overwrite_if_exists=overwrite,
85+
).cache_id
86+
assert cache_id
87+
self._cache_id = cache_id
7788

7889
@property
7990
def distance_threshold(self) -> float:
@@ -91,7 +102,7 @@ def set_threshold(self, distance_threshold: float) -> None:
91102
"""
92103
if not 0 <= float(distance_threshold) <= 2:
93104
raise ValueError("Distance threshold must be between 0 and 2")
94-
self._distance_threshold = float(distance_threshold)
105+
self._distance_threshold = float(norm_cosine_distance(distance_threshold))
95106

96107
@property
97108
def ttl(self) -> Optional[int]:
@@ -238,7 +249,7 @@ def check(
238249
"entry_id": entry.id,
239250
"prompt": entry.prompt,
240251
"response": entry.response,
241-
"vector_distance": entry.similarity,
252+
"vector_distance": denorm_cosine_distance(entry.similarity),
242253
}
243254

244255
# Add metadata if available
@@ -334,7 +345,7 @@ async def acheck(
334345
"entry_id": entry.id,
335346
"prompt": entry.prompt,
336347
"response": entry.response,
337-
"vector_distance": entry.similarity,
348+
"vector_distance": denorm_cosine_distance(entry.similarity),
338349
}
339350

340351
# Add metadata if available
@@ -408,6 +419,7 @@ def store(
408419
scope = entry_scope if entry_scope is not None else self._entry_scope
409420

410421
# Convert TTL from seconds to milliseconds for LangCache SDK
422+
ttl = ttl or self._ttl
411423
ttl_millis = ttl * 1000 if ttl is not None else None
412424

413425
# Store the entry
@@ -576,7 +588,7 @@ async def aupdate(self, key: str, **kwargs) -> None:
576588
)
577589

578590
def disconnect(self) -> None:
579-
"""Disconnect from Redis."""
591+
"""Disconnect from LangCache."""
580592
if (
581593
hasattr(self._api.sdk_configuration, "client")
582594
and self._api.sdk_configuration.client
@@ -585,11 +597,6 @@ def disconnect(self) -> None:
585597

586598
async def adisconnect(self) -> None:
587599
"""Async disconnect from Redis."""
588-
if (
589-
hasattr(self._api.sdk_configuration, "client")
590-
and self._api.sdk_configuration.client
591-
):
592-
self._api.sdk_configuration.client.close()
593600
if (
594601
hasattr(self._api.sdk_configuration, "async_client")
595602
and self._api.sdk_configuration.async_client

redisvl/extensions/cache/llm/schema.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
from typing import Any, Dict, List, Optional
22

3+
from langcache.httpclient import AsyncHttpClient, HttpClient
4+
from langcache.types import UNSET, OptionalNullable
5+
from langcache.utils.logger import Logger
6+
from langcache.utils.retries import RetryConfig
37
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
48

59
from redisvl.extensions.constants import (
@@ -110,6 +114,29 @@ def to_dict(self) -> Dict[str, Any]:
110114
return data
111115

112116

117+
class LangCacheOptions(BaseModel):
118+
"""Options for the LangCache SDK client."""
119+
120+
server_idx: Optional[int] = None
121+
"""The index of the server to use for all methods."""
122+
server_url: Optional[str] = None
123+
"""The server URL to use for all methods."""
124+
url_params: Optional[Dict[str, str]] = None
125+
"""Parameters to optionally template the server URL with."""
126+
client: Optional[HttpClient] = None
127+
"""The HTTP client to use for all synchronous methods."""
128+
async_client: Optional[AsyncHttpClient] = None
129+
"""The Async HTTP client to use for all asynchronous methods."""
130+
retry_config: OptionalNullable[RetryConfig] = UNSET
131+
"""The retry configuration to use for all supported methods."""
132+
timeout_ms: Optional[int] = None
133+
"""Optional request timeout applied to each operation in milliseconds."""
134+
debug_logger: Optional[Any] = None
135+
"""Optional logger for debugging."""
136+
137+
model_config = ConfigDict(arbitrary_types_allowed=True)
138+
139+
113140
class SemanticCacheIndexSchema(IndexSchema):
114141

115142
@classmethod

0 commit comments

Comments
 (0)