Skip to content

Commit a08418e

Browse files
committed
LangCacheWrapper: address PR review feedback\n\n- Broaden attributes type in check/acheck to Optional[Dict[str, Any]]\n- Clarify docstrings for distance_threshold conversion for both scales\n- Factor helpers: _similarity_threshold, _build_search_kwargs, _hits_from_response to reduce duplication
1 parent 8c4c742 commit a08418e

File tree

1 file changed

+68
-63
lines changed

1 file changed

+68
-63
lines changed

redisvl/extensions/cache/llm/langcache.py

Lines changed: 68 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,49 @@ def __init__(
125125
api_key=self._api_key,
126126
)
127127

128+
def _similarity_threshold(
129+
self, distance_threshold: Optional[float]
130+
) -> Optional[float]:
131+
"""Convert a distance threshold to a similarity threshold based on scale.
132+
133+
- If distance_scale == "redis": use norm_cosine_distance (0–2 -> 0–1)
134+
- Otherwise: use (1.0 - distance_threshold) for normalized 0–1 distances
135+
"""
136+
if distance_threshold is None:
137+
return None
138+
if self._distance_scale == "redis":
139+
return norm_cosine_distance(distance_threshold)
140+
return 1.0 - float(distance_threshold)
141+
142+
def _build_search_kwargs(
143+
self,
144+
prompt: str,
145+
similarity_threshold: Optional[float],
146+
attributes: Optional[Dict[str, Any]],
147+
) -> Dict[str, Any]:
148+
kwargs: Dict[str, Any] = {
149+
"prompt": prompt,
150+
"search_strategies": self._search_strategies,
151+
"similarity_threshold": similarity_threshold,
152+
}
153+
if attributes:
154+
kwargs["attributes"] = attributes
155+
return kwargs
156+
157+
def _hits_from_response(
158+
self, response: Any, num_results: int
159+
) -> List[Dict[str, Any]]:
160+
results = response.data if hasattr(response, "data") else []
161+
hits: List[Dict[str, Any]] = []
162+
for result in results[:num_results]:
163+
if hasattr(result, "model_dump"):
164+
result_dict = result.model_dump()
165+
else:
166+
result_dict = dict(result) # type: ignore[arg-type]
167+
hit = self._convert_to_cache_hit(result_dict)
168+
hits.append(hit.to_dict())
169+
return hits
170+
128171
def _convert_to_cache_hit(self, result: Dict[str, Any]) -> CacheHit:
129172
"""Convert a LangCache result to a CacheHit object.
130173
@@ -163,7 +206,7 @@ def check(
163206
return_fields: Optional[List[str]] = None,
164207
filter_expression: Optional[FilterExpression] = None,
165208
distance_threshold: Optional[float] = None,
166-
attributes: Optional[Dict[str, str]] = None,
209+
attributes: Optional[Dict[str, Any]] = None,
167210
) -> List[Dict[str, Any]]:
168211
"""Check the cache for semantically similar prompts.
169212
@@ -174,8 +217,10 @@ def check(
174217
return_fields (Optional[List[str]]): Not used (for compatibility).
175218
filter_expression (Optional[FilterExpression]): Not supported.
176219
distance_threshold (Optional[float]): Maximum distance threshold.
177-
Converted to similarity_threshold (1.0 - distance_threshold).
178-
attributes (Optional[Dict[str, str]]): LangCache attributes to filter by.
220+
Converted to similarity_threshold according to distance_scale:
221+
- If "redis": uses norm_cosine_distance(distance_threshold) ([0,2] → [0,1])
222+
- If "normalized": uses (1.0 - distance_threshold) ([0,1] → [0,1])
223+
attributes (Optional[Dict[str, Any]]): LangCache attributes to filter by.
179224
Note: Attributes must be pre-configured in your LangCache instance.
180225
181226
Returns:
@@ -196,42 +241,20 @@ def check(
196241
# Convert distance threshold to similarity threshold according to configured scale
197242
similarity_threshold = None
198243
if distance_threshold is not None:
199-
if self._distance_scale == "redis":
200-
similarity_threshold = norm_cosine_distance(
201-
distance_threshold
202-
) # [0,2] -> [0,1]
203-
else:
204-
similarity_threshold = 1.0 - float(distance_threshold) # [0,1] -> [0,1]
244+
similarity_threshold = self._similarity_threshold(distance_threshold)
205245

206246
# Search using the LangCache client
207247
# The client itself is the context manager
208-
search_kwargs: Dict[str, Any] = {
209-
"prompt": prompt,
210-
"search_strategies": self._search_strategies,
211-
"similarity_threshold": similarity_threshold,
212-
}
213-
214-
# Add attributes if provided
215-
if attributes:
216-
search_kwargs["attributes"] = attributes
248+
search_kwargs = self._build_search_kwargs(
249+
prompt=prompt,
250+
similarity_threshold=similarity_threshold,
251+
attributes=attributes,
252+
)
217253

218254
response = self._client.search(**search_kwargs)
219255

220256
# Convert results to cache hits
221-
# Response is a SearchResponse Pydantic model with a 'data' attribute
222-
results = response.data if hasattr(response, "data") else []
223-
cache_hits = []
224-
for result in results[:num_results]:
225-
# Convert CacheEntry to dict
226-
result_dict: Dict[str, Any]
227-
if hasattr(result, "model_dump"):
228-
result_dict = result.model_dump()
229-
else:
230-
result_dict = dict(result) # type: ignore[arg-type]
231-
hit = self._convert_to_cache_hit(result_dict)
232-
cache_hits.append(hit.to_dict())
233-
234-
return cache_hits
257+
return self._hits_from_response(response, num_results)
235258

236259
async def acheck(
237260
self,
@@ -241,7 +264,7 @@ async def acheck(
241264
return_fields: Optional[List[str]] = None,
242265
filter_expression: Optional[FilterExpression] = None,
243266
distance_threshold: Optional[float] = None,
244-
attributes: Optional[Dict[str, str]] = None,
267+
attributes: Optional[Dict[str, Any]] = None,
245268
) -> List[Dict[str, Any]]:
246269
"""Async check the cache for semantically similar prompts.
247270
@@ -252,8 +275,10 @@ async def acheck(
252275
return_fields (Optional[List[str]]): Not used (for compatibility).
253276
filter_expression (Optional[FilterExpression]): Not supported.
254277
distance_threshold (Optional[float]): Maximum distance threshold.
255-
Converted to similarity_threshold (1.0 - distance_threshold).
256-
attributes (Optional[Dict[str, str]]): LangCache attributes to filter by.
278+
Converted to similarity_threshold according to distance_scale:
279+
- If "redis": uses norm_cosine_distance(distance_threshold) ([0,2] -> [0,1])
280+
- If "normalized": uses (1.0 - distance_threshold) ([0,1] -> [0,1])
281+
attributes (Optional[Dict[str, Any]]): LangCache attributes to filter by.
257282
Note: Attributes must be pre-configured in your LangCache instance.
258283
259284
Returns:
@@ -274,42 +299,22 @@ async def acheck(
274299
# Convert distance threshold to similarity threshold according to configured scale
275300
similarity_threshold = None
276301
if distance_threshold is not None:
277-
if self._distance_scale == "redis":
278-
similarity_threshold = norm_cosine_distance(
279-
distance_threshold
280-
) # [0,2] -> [0,1]
281-
else:
282-
similarity_threshold = 1.0 - float(distance_threshold) # [0,1] -> [0,1]
302+
similarity_threshold = self._similarity_threshold(distance_threshold)
283303

284304
# Search using the LangCache client (async)
285305
# The client itself is the context manager
286-
search_kwargs: Dict[str, Any] = {
287-
"prompt": prompt,
288-
"search_strategies": self._search_strategies,
289-
"similarity_threshold": similarity_threshold,
290-
}
306+
search_kwargs = self._build_search_kwargs(
307+
prompt=prompt,
308+
similarity_threshold=similarity_threshold,
309+
attributes=attributes,
310+
)
291311

292-
# Add attributes if provided
293-
if attributes:
294-
search_kwargs["attributes"] = attributes
312+
# Add attributes if provided (already handled by builder)
295313

296314
response = await self._client.search_async(**search_kwargs)
297315

298316
# Convert results to cache hits
299-
# Response is a SearchResponse Pydantic model with a 'data' attribute
300-
results = response.data if hasattr(response, "data") else []
301-
cache_hits = []
302-
for result in results[:num_results]:
303-
# Convert CacheEntry to dict
304-
result_dict: Dict[str, Any]
305-
if hasattr(result, "model_dump"):
306-
result_dict = result.model_dump()
307-
else:
308-
result_dict = dict(result) # type: ignore[arg-type]
309-
hit = self._convert_to_cache_hit(result_dict)
310-
cache_hits.append(hit.to_dict())
311-
312-
return cache_hits
317+
return self._hits_from_response(response, num_results)
313318

314319
def store(
315320
self,

0 commit comments

Comments
 (0)