Skip to content

Commit 67f1ee3

Browse files
committed
Fix count_memories method and test fallback behavior
- Improve count_memories to use Redis FT.SEARCH for efficiency - Add fallback method for counting when direct search fails - Fix test to account for search optimization fallback behavior - Use proper MemoryRecordResult with required dist field in test
1 parent 5bb4915 commit 67f1ee3

File tree

2 files changed

+79
-5
lines changed

2 files changed

+79
-5
lines changed

agent_memory_server/vectorstore_adapter.py

Lines changed: 67 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1163,7 +1163,71 @@ async def count_memories(
11631163
user_id: str | None = None,
11641164
session_id: str | None = None,
11651165
) -> int:
1166-
"""Count memories using the same approach as search_memories for consistency."""
1166+
"""Count memories using Redis FT.SEARCH for efficiency."""
1167+
try:
1168+
# Get the RedisVL index for direct Redis operations
1169+
index = self._get_vectorstore_index()
1170+
if index is None:
1171+
logger.warning(
1172+
"RedisVL index not available, falling back to vector search"
1173+
)
1174+
# Fallback to vector search approach
1175+
return await self._count_memories_fallback(
1176+
namespace, user_id, session_id
1177+
)
1178+
1179+
# Build filter expression
1180+
filters = []
1181+
if namespace:
1182+
namespace_filter = Namespace(eq=namespace).to_filter()
1183+
filters.append(namespace_filter)
1184+
if user_id:
1185+
user_filter = UserId(eq=user_id).to_filter()
1186+
filters.append(user_filter)
1187+
if session_id:
1188+
session_filter = SessionId(eq=session_id).to_filter()
1189+
filters.append(session_filter)
1190+
1191+
# Combine filters with AND logic
1192+
redis_filter = None
1193+
if filters:
1194+
if len(filters) == 1:
1195+
redis_filter = filters[0]
1196+
else:
1197+
redis_filter = reduce(lambda x, y: x & y, filters)
1198+
1199+
# Use Redis FT.SEARCH with LIMIT 0 0 to get count only
1200+
from redisvl.query import FilterQuery
1201+
1202+
if redis_filter is not None:
1203+
# Use FilterQuery for non-vector search
1204+
query = FilterQuery(filter_expression=redis_filter, num_results=0)
1205+
else:
1206+
# Match all documents
1207+
query = FilterQuery(filter_expression="*", num_results=0)
1208+
1209+
# Execute the query to get count
1210+
if hasattr(index, "asearch"):
1211+
results = await index.asearch(query)
1212+
else:
1213+
results = index.search(query)
1214+
1215+
return results.total
1216+
1217+
except Exception as e:
1218+
logger.warning(
1219+
f"Error counting memories with Redis search, falling back: {e}"
1220+
)
1221+
# Fallback to vector search approach
1222+
return await self._count_memories_fallback(namespace, user_id, session_id)
1223+
1224+
async def _count_memories_fallback(
1225+
self,
1226+
namespace: str | None = None,
1227+
user_id: str | None = None,
1228+
session_id: str | None = None,
1229+
) -> int:
1230+
"""Fallback method for counting memories using vector search."""
11671231
try:
11681232
# Use the same filter approach as search_memories
11691233
filters = []
@@ -1186,10 +1250,9 @@ async def count_memories(
11861250
else:
11871251
redis_filter = reduce(lambda x, y: x & y, filters)
11881252

1189-
# Use the same search method as search_memories but for counting
1190-
# We use a generic query to match all indexed content
1253+
# Use a simple text query that should match most content
11911254
search_results = await self.vectorstore.asimilarity_search(
1192-
query="", # Empty query to match all content
1255+
query="memory", # Simple query that should match content
11931256
filter=redis_filter,
11941257
k=10000, # Large number to get all results
11951258
)

tests/test_long_term_memory.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1031,8 +1031,19 @@ async def test_search_passes_all_parameters_correctly(
10311031
"""Test that all search parameters are passed correctly to the adapter."""
10321032
# Mock the vectorstore adapter
10331033
mock_adapter = AsyncMock()
1034+
# Return some results to avoid fallback behavior when distance_threshold is set
10341035
mock_adapter.search_memories.return_value = MemoryRecordResults(
1035-
total=0, memories=[]
1036+
total=1,
1037+
memories=[
1038+
MemoryRecordResult(
1039+
id="test-id",
1040+
text="test memory",
1041+
session_id="test-session",
1042+
user_id="test-user",
1043+
namespace="test-namespace",
1044+
dist=0.1, # Required field for MemoryRecordResult
1045+
)
1046+
],
10361047
)
10371048
mock_get_adapter.return_value = mock_adapter
10381049

0 commit comments

Comments
 (0)