Skip to content

Commit 481216f

Browse files
bsbodden and owais committed
feat(cache): add warnings when using sync methods with async-only Redis client (#391)
Co-authored-by: owais <[email protected]>
1 parent 76c74a0 commit 481216f

File tree

8 files changed: +389 additions, -11 deletions.

redisvl/extensions/cache/embeddings/embeddings.py

Lines changed: 42 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -14,6 +14,8 @@
1414
class EmbeddingsCache(BaseCache):
1515
"""Embeddings Cache for storing embedding vectors with exact key matching."""
1616

17+
_warning_shown: bool = False # Class-level flag to prevent warning spam
18+
1719
def __init__(
1820
self,
1921
name: str = "embedcache",
@@ -124,6 +126,14 @@ def _process_cache_data(
124126
cache_hit = CacheEntry(**convert_bytes(data))
125127
return cache_hit.model_dump(exclude_none=True)
126128

129+
def _should_warn_for_async_only(self) -> bool:
130+
"""Check if only async client is available (no sync client).
131+
132+
Returns:
133+
bool: True if only async client is available (no sync client).
134+
"""
135+
return self._owns_redis_client is False and self._redis_client is None
136+
127137
def get(
128138
self,
129139
text: str,
@@ -167,6 +177,14 @@ def get_by_key(self, key: str) -> Optional[Dict[str, Any]]:
167177
168178
embedding_data = cache.get_by_key("embedcache:1234567890abcdef")
169179
"""
180+
if self._should_warn_for_async_only():
181+
if not EmbeddingsCache._warning_shown:
182+
logger.warning(
183+
"EmbeddingsCache initialized with async_redis_client only. "
184+
"Use async methods (aget_by_key) instead of sync methods (get_by_key)."
185+
)
186+
EmbeddingsCache._warning_shown = True
187+
170188
client = self._get_redis_client()
171189

172190
# Get all fields
@@ -202,6 +220,14 @@ def mget_by_keys(self, keys: List[str]) -> List[Optional[Dict[str, Any]]]:
202220
if not keys:
203221
return []
204222

223+
if self._should_warn_for_async_only():
224+
if not EmbeddingsCache._warning_shown:
225+
logger.warning(
226+
"EmbeddingsCache initialized with async_redis_client only. "
227+
"Use async methods (amget_by_keys) instead of sync methods (mget_by_keys)."
228+
)
229+
EmbeddingsCache._warning_shown = True
230+
205231
client = self._get_redis_client()
206232

207233
with client.pipeline(transaction=False) as pipeline:
@@ -283,6 +309,14 @@ def set(
283309
text, model_name, embedding, metadata
284310
)
285311

312+
if self._should_warn_for_async_only():
313+
if not EmbeddingsCache._warning_shown:
314+
logger.warning(
315+
"EmbeddingsCache initialized with async_redis_client only. "
316+
"Use async methods (aset) instead of sync methods (set)."
317+
)
318+
EmbeddingsCache._warning_shown = True
319+
286320
# Store in Redis
287321
client = self._get_redis_client()
288322
client.hset(name=key, mapping=cache_entry) # type: ignore
@@ -333,6 +367,14 @@ def mset(
333367
if not items:
334368
return []
335369

370+
if self._should_warn_for_async_only():
371+
if not EmbeddingsCache._warning_shown:
372+
logger.warning(
373+
"EmbeddingsCache initialized with async_redis_client only. "
374+
"Use async methods (amset) instead of sync methods (mset)."
375+
)
376+
EmbeddingsCache._warning_shown = True
377+
336378
client = self._get_redis_client()
337379
keys = []
338380

redisvl/index/index.py

Lines changed: 5 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -245,8 +245,10 @@ def name(self) -> str:
245245
@property
246246
def prefix(self) -> str:
247247
"""The optional key prefix that comes before a unique key value in
248-
forming a Redis key."""
249-
return self.schema.index.prefix
248+
forming a Redis key. If multiple prefixes are configured, returns the
249+
first one."""
250+
prefix = self.schema.index.prefix
251+
return prefix[0] if isinstance(prefix, list) else prefix
250252

251253
@property
252254
def key_separator(self) -> str:
@@ -329,7 +331,7 @@ def key(self, id: str) -> str:
329331
"""
330332
return self._storage._key(
331333
id=id,
332-
prefix=self.schema.index.prefix,
334+
prefix=self.prefix,
333335
key_separator=self.schema.index.key_separator,
334336
)
335337

redisvl/index/storage.py

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -114,9 +114,13 @@ def _create_key(self, obj: Dict[str, Any], id_field: Optional[str] = None) -> st
114114
except KeyError:
115115
raise ValueError(f"Key field {id_field} not found in record {obj}")
116116

117+
# Normalize prefix: use first prefix if multiple are configured
118+
prefix = self.index_schema.index.prefix
119+
normalized_prefix = prefix[0] if isinstance(prefix, list) else prefix
120+
117121
return self._key(
118122
key_value,
119-
prefix=self.index_schema.index.prefix,
123+
prefix=normalized_prefix,
120124
key_separator=self.index_schema.index.key_separator,
121125
)
122126

redisvl/redis/connection.py

Lines changed: 60 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -133,31 +133,73 @@ def convert_index_info_to_schema(index_info: Dict[str, Any]) -> Dict[str, Any]:
133133
Dict[str, Any]: Schema dictionary.
134134
"""
135135
index_name = index_info["index_name"]
136-
prefixes = index_info["index_definition"][3][0]
136+
prefixes = index_info["index_definition"][3]
137+
# Normalize single-element prefix lists to string for backward compatibility
138+
if isinstance(prefixes, list) and len(prefixes) == 1:
139+
prefixes = prefixes[0]
137140
storage_type = index_info["index_definition"][1].lower()
138141

139142
index_fields = index_info["attributes"]
140143

141144
def parse_vector_attrs(attrs):
142145
# Parse vector attributes from Redis FT.INFO output
143-
# Attributes start at position 6 as key-value pairs
146+
# Format varies significantly between Redis versions:
147+
# - Redis 6.2.6-v9: [... "VECTOR"] - no params returned by FT.INFO
148+
# - Redis 6.2.x: [... "VECTOR", "FLAT", "6", "TYPE", "FLOAT32", "DIM", "3", ...]
149+
# Position 6: algorithm value (e.g., "FLAT" or "HNSW")
150+
# Position 7: param count
151+
# Position 8+: key-value pairs
152+
# - Redis 7.x+: [... "VECTOR", "ALGORITHM", "FLAT", "TYPE", "FLOAT32", "DIM", "3", ...]
153+
# Position 6+: all key-value pairs
154+
155+
# Check if we have any attributes beyond the type declaration
156+
if len(attrs) <= 6:
157+
# Redis 6.2.6-v9 or similar: no vector params in FT.INFO
158+
# Return None to signal we can't parse this field properly
159+
return None
160+
144161
vector_attrs = {}
162+
start_pos = 6
163+
164+
# Detect format: if position 6 looks like an algorithm value (not a key),
165+
# we're dealing with the older format
166+
if len(attrs) > 6:
167+
pos6_str = str(attrs[6]).upper()
168+
# Check if position 6 is an algorithm value (FLAT, HNSW) vs a key (ALGORITHM, TYPE, DIM)
169+
if pos6_str in ("FLAT", "HNSW"):
170+
# Old format (Redis 6.2.x): position 6 is algorithm value, position 7 is param count
171+
# Store the algorithm
172+
vector_attrs["algorithm"] = pos6_str
173+
# Skip to position 8 where key-value pairs start
174+
start_pos = 8
175+
145176
try:
146-
for i in range(6, len(attrs), 2):
177+
for i in range(start_pos, len(attrs), 2):
147178
if i + 1 < len(attrs):
148179
key = str(attrs[i]).lower()
149180
vector_attrs[key] = attrs[i + 1]
150181
except (IndexError, TypeError, ValueError):
182+
# Silently continue - we'll validate required fields below
151183
pass
152184

153185
# Normalize to expected field names
154186
normalized = {}
155187

156-
# Handle dims/dim field
188+
# Handle dims/dim field - REQUIRED for vector fields
157189
if "dim" in vector_attrs:
158190
normalized["dims"] = int(vector_attrs.pop("dim"))
159191
elif "dims" in vector_attrs:
160192
normalized["dims"] = int(vector_attrs["dims"])
193+
else:
194+
# If dims is missing from normal parsing, try scanning the raw attrs
195+
# This handles edge cases where the format is unexpected
196+
for i in range(6, len(attrs) - 1):
197+
if str(attrs[i]).upper() in ("DIM", "DIMS"):
198+
try:
199+
normalized["dims"] = int(attrs[i + 1])
200+
break
201+
except (ValueError, IndexError):
202+
pass
161203

162204
# Handle distance_metric field
163205
if "distance_metric" in vector_attrs:
@@ -178,10 +220,18 @@ def parse_vector_attrs(attrs):
178220
normalized["datatype"] = vector_attrs["data_type"].lower()
179221
elif "datatype" in vector_attrs:
180222
normalized["datatype"] = vector_attrs["datatype"].lower()
223+
elif "type" in vector_attrs:
224+
# Sometimes it's just "type" instead of "data_type"
225+
normalized["datatype"] = vector_attrs["type"].lower()
181226
else:
182227
# Default to float32 if missing
183228
normalized["datatype"] = "float32"
184229

230+
# Validate that we have required dims
231+
if "dims" not in normalized:
232+
# Could not parse dims - this field is not properly supported
233+
return None
234+
185235
return normalized
186236

187237
def parse_attrs(attrs, field_type=None):
@@ -234,7 +284,12 @@ def parse_attrs(attrs, field_type=None):
234284
field["path"] = field_attrs[1]
235285
# parse field attrs
236286
if field_attrs[5] == "VECTOR":
237-
field["attrs"] = parse_vector_attrs(field_attrs)
287+
attrs = parse_vector_attrs(field_attrs)
288+
if attrs is None:
289+
# Vector field attributes cannot be parsed on this Redis version
290+
# Skip this field - it cannot be properly reconstructed
291+
continue
292+
field["attrs"] = attrs
238293
else:
239294
field["attrs"] = parse_attrs(field_attrs, field_type=field_attrs[5])
240295
# append field

redisvl/schema/schema.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -58,8 +58,8 @@ class IndexInfo(BaseModel):
5858

5959
name: str
6060
"""The unique name of the index."""
61-
prefix: str = "rvl"
62-
"""The prefix used for Redis keys associated with this index."""
61+
prefix: Union[str, List[str]] = "rvl"
62+
"""The prefix(es) used for Redis keys associated with this index. Can be a single string or a list of strings."""
6363
key_separator: str = ":"
6464
"""The separator character used in designing Redis keys."""
6565
storage_type: StorageType = StorageType.HASH
Lines changed: 88 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,88 @@
1+
"""Test warning behavior when using sync methods with async-only client."""
2+
3+
import logging
4+
from unittest.mock import patch
5+
6+
import pytest
7+
from redis import Redis
8+
9+
from redisvl.extensions.cache.embeddings import EmbeddingsCache
10+
11+
12+
@pytest.fixture(autouse=True)
13+
def reset_warning_flag():
14+
"""Reset the warning flag before each test to ensure test isolation."""
15+
EmbeddingsCache._warning_shown = False
16+
yield
17+
# Optionally reset after test as well for cleanup
18+
EmbeddingsCache._warning_shown = False
19+
20+
21+
@pytest.mark.asyncio
22+
async def test_sync_methods_warn_with_async_only_client(async_client, caplog):
23+
"""Test that sync methods warn when only async client is provided."""
24+
# Initialize EmbeddingsCache with only async_redis_client
25+
cache = EmbeddingsCache(name="test_cache", async_redis_client=async_client)
26+
27+
# Mock _get_redis_client to prevent actual connection attempt
28+
with patch.object(cache, "_get_redis_client") as mock_get_client:
29+
# Mock the Redis client methods that would be called
30+
mock_client = mock_get_client.return_value
31+
mock_client.hgetall.return_value = {} # Empty result for get_by_key
32+
mock_client.hset.return_value = 1 # Success for set
33+
34+
# Capture log warnings
35+
with caplog.at_level(logging.WARNING):
36+
# First sync method call should warn
37+
_ = cache.get_by_key("test_key")
38+
39+
# Check warning was logged
40+
assert len(caplog.records) == 1
41+
assert (
42+
"initialized with async_redis_client only" in caplog.records[0].message
43+
)
44+
assert "Use async methods" in caplog.records[0].message
45+
46+
# Clear captured logs
47+
caplog.clear()
48+
49+
# Second sync method call should NOT warn (flag prevents spam)
50+
_ = cache.set(text="test", model_name="model", embedding=[0.1, 0.2])
51+
52+
# Should not have logged another warning
53+
assert len(caplog.records) == 0
54+
55+
56+
def test_no_warning_with_sync_client(redis_url):
57+
"""Test that no warning is shown when sync client is provided."""
58+
# Create sync redis client from redis_url
59+
sync_client = Redis.from_url(redis_url)
60+
61+
try:
62+
# Initialize EmbeddingsCache with sync_redis_client
63+
cache = EmbeddingsCache(name="test_cache", redis_client=sync_client)
64+
65+
with patch("redisvl.utils.log.get_logger") as mock_logger:
66+
# Sync methods should not warn
67+
_ = cache.get_by_key("test_key")
68+
_ = cache.set(text="test", model_name="model", embedding=[0.1, 0.2])
69+
70+
# No warnings should have been logged
71+
mock_logger.return_value.warning.assert_not_called()
72+
finally:
73+
sync_client.close()
74+
75+
76+
@pytest.mark.asyncio
77+
async def test_async_methods_no_warning(async_client):
78+
"""Test that async methods don't trigger warnings."""
79+
# Initialize EmbeddingsCache with only async_redis_client
80+
cache = EmbeddingsCache(name="test_cache", async_redis_client=async_client)
81+
82+
with patch("redisvl.utils.log.get_logger") as mock_logger:
83+
# Async methods should not warn
84+
_ = await cache.aget_by_key("test_key")
85+
_ = await cache.aset(text="test", model_name="model", embedding=[0.1, 0.2])
86+
87+
# No warnings should have been logged
88+
mock_logger.return_value.warning.assert_not_called()

0 commit comments

Comments
 (0)