
Commit a8fb65c

abrookins and claude committed
fix: address PR feedback - improve type checking, extract complex logic, update docs
- Add numbers.Number-based type checking with _is_numeric() helper
- Extract Redis aggregation logic into separate _search_with_redis_aggregation() method
- Add safe _get_vectorstore_index() method to avoid direct _index access
- Document hard_age_multiplier parameter in select_ids_for_forgetting docstring
- Remove stale TDD comment from test file

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <[email protected]>
1 parent 4c6b1c1 · commit a8fb65c

File tree: 3 files changed, +139 −97 lines

agent_memory_server/long_term_memory.py

Lines changed: 10 additions & 7 deletions
@@ -1,6 +1,7 @@
 import hashlib
 import json
 import logging
+import numbers
 import time
 from collections.abc import Iterable
 from datetime import UTC, datetime, timedelta
@@ -1428,6 +1429,11 @@ def combined_score(mem: MemoryRecordResult) -> float:
     return sorted(results, key=combined_score, reverse=True)
 
 
+def _is_numeric(value: Any) -> bool:
+    """Check if a value is numeric (int, float, or other number type)."""
+    return isinstance(value, numbers.Number)
+
+
 def select_ids_for_forgetting(
     results: Iterable[MemoryRecordResult],
     *,
@@ -1442,6 +1448,7 @@ def select_ids_for_forgetting(
     - max_inactive_days: float | None
     - budget: int | None (keep top N by recency score)
     - memory_type_allowlist: set[str] | list[str] | None (only consider these types for deletion)
+    - hard_age_multiplier: float (default 12.0) - multiplier for max_age_days to determine extremely old items
     """
     pinned_ids = pinned_ids or set()
     max_age_days = policy.get("max_age_days")
@@ -1476,9 +1483,7 @@
         # - If both thresholds are set, prefer not to delete recently accessed
         #   items unless they are extremely old.
         # - Extremely old: age > max_age_days * hard_age_multiplier (default 12x)
-        if isinstance(max_age_days, int | float) and isinstance(
-            max_inactive_days, int | float
-        ):
+        if _is_numeric(max_age_days) and _is_numeric(max_inactive_days):
             if age_days > float(max_age_days) * hard_age_multiplier:
                 to_delete.add(mem.id)
                 continue
@@ -1488,10 +1493,8 @@
                 to_delete.add(mem.id)
                 continue
         else:
-            ttl_hit = isinstance(max_age_days, int | float) and age_days > float(
-                max_age_days
-            )
-            inactivity_hit = isinstance(max_inactive_days, int | float) and (
+            ttl_hit = _is_numeric(max_age_days) and age_days > float(max_age_days)
+            inactivity_hit = _is_numeric(max_inactive_days) and (
                 inactive_days > float(max_inactive_days)
             )
             if ttl_hit or inactivity_hit:
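For reference, the new helper is a thin wrapper around the standard-library `numbers` ABC, so it matches any registered numeric type rather than only the built-in `int | float` union the old checks used. Below is a minimal sketch of that difference, assuming only the helper as defined in the diff; the example values are illustrative, not taken from the codebase.

```python
import numbers
from decimal import Decimal
from fractions import Fraction
from typing import Any


def _is_numeric(value: Any) -> bool:
    """Check if a value is numeric (int, float, or other number type)."""
    return isinstance(value, numbers.Number)


# Values a forgetting policy might carry (illustrative only).
for value in (30, 7.5, Fraction(3, 2), Decimal("1.5"), None, "30"):
    old_check = isinstance(value, int | float)
    print(f"{value!r}: _is_numeric={_is_numeric(value)}, old int|float check={old_check}")

# Fraction and Decimal are registered numbers.Number subclasses, so they pass
# _is_numeric() but failed the previous isinstance(value, int | float) checks.
```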

agent_memory_server/vectorstore_adapter.py

Lines changed: 129 additions & 89 deletions
@@ -13,7 +13,6 @@
 from langchain_core.embeddings import Embeddings
 from langchain_core.vectorstores import VectorStore
 from langchain_redis.vectorstores import RedisVectorStore
-from redisvl.query import RangeQuery, VectorQuery
 
 from agent_memory_server.filters import (
     CreatedAt,
@@ -837,6 +836,127 @@ async def update_memories(self, memories: list[MemoryRecord]) -> int:
         added = await self.add_memories(memories)
         return len(added)
 
+    def _get_vectorstore_index(self):
+        """Safely access the underlying RedisVL index from the vectorstore.
+
+        Returns:
+            RedisVL SearchIndex or None if not available
+        """
+        return getattr(self.vectorstore, "_index", None)
+
+    async def _search_with_redis_aggregation(
+        self,
+        query: str,
+        redis_filter,
+        limit: int,
+        offset: int,
+        distance_threshold: float | None,
+        recency_params: dict | None,
+    ) -> MemoryRecordResults:
+        """Perform server-side Redis aggregation search with recency scoring.
+
+        Args:
+            query: Search query text
+            redis_filter: Redis filter expression
+            limit: Maximum number of results
+            offset: Offset for pagination
+            distance_threshold: Distance threshold for range queries
+            recency_params: Parameters for recency scoring
+
+        Returns:
+            MemoryRecordResults with server-side scored results
+
+        Raises:
+            Exception: If Redis aggregation fails (caller should handle fallback)
+        """
+        from datetime import UTC as _UTC, datetime as _dt
+
+        from langchain_core.documents import Document
+        from redisvl.query import RangeQuery, VectorQuery
+
+        from agent_memory_server.utils.redis_query import RecencyAggregationQuery
+
+        index = self._get_vectorstore_index()
+        if index is None:
+            raise Exception("RedisVL index not available")
+
+        # Embed the query text to vector
+        embedding_vector = self.embeddings.embed_query(query)
+
+        # Build base KNN query (hybrid)
+        if distance_threshold is not None:
+            knn = RangeQuery(
+                vector=embedding_vector,
+                vector_field_name="vector",
+                filter_expression=redis_filter,
+                distance_threshold=float(distance_threshold),
+                num_results=limit,
+            )
+        else:
+            knn = VectorQuery(
+                vector=embedding_vector,
+                vector_field_name="vector",
+                filter_expression=redis_filter,
+                num_results=limit,
+            )
+
+        # Aggregate with APPLY/SORTBY boosted score via helper
+        now_ts = int(_dt.now(_UTC).timestamp())
+        agg = (
+            RecencyAggregationQuery.from_vector_query(
+                knn, filter_expression=redis_filter
+            )
+            .load_default_fields()
+            .apply_recency(now_ts=now_ts, params=recency_params or {})
+            .sort_by_boosted_desc()
+            .paginate(offset, limit)
+        )
+
+        raw = (
+            await index.aaggregate(agg)
+            if hasattr(index, "aaggregate")
+            else index.aggregate(agg)  # type: ignore
+        )
+
+        rows = getattr(raw, "rows", raw) or []
+        memory_results: list[MemoryRecordResult] = []
+        for row in rows:
+            fields = getattr(row, "__dict__", None) or row
+            metadata = {
+                k: fields.get(k)
+                for k in [
+                    "id_",
+                    "session_id",
+                    "user_id",
+                    "namespace",
+                    "created_at",
+                    "last_accessed",
+                    "updated_at",
+                    "pinned",
+                    "access_count",
+                    "topics",
+                    "entities",
+                    "memory_hash",
+                    "discrete_memory_extracted",
+                    "memory_type",
+                    "persisted_at",
+                    "extracted_from",
+                    "event_date",
+                ]
+                if k in fields
+            }
+            text_val = fields.get("text", "")
+            score = fields.get("__vector_score", 1.0) or 1.0
+            doc_obj = Document(page_content=text_val, metadata=metadata)
+            memory_results.append(self.document_to_memory(doc_obj, float(score)))
+
+        next_offset = offset + limit if len(memory_results) == limit else None
+        return MemoryRecordResults(
+            memories=memory_results[:limit],
+            total=offset + len(memory_results),
+            next_offset=next_offset,
+        )
+
     async def search_memories(
         self,
         query: str,
@@ -900,94 +1020,14 @@ async def search_memories(
         # If server-side recency is requested, attempt RedisVL query first (DB-level path)
         if server_side_recency:
             try:
-                index = getattr(self.vectorstore, "_index", None)
-                if index is not None:
-                    # Embed the query text to vector
-                    embedding_vector = self.embeddings.embed_query(query)
-
-                    # Build base KNN query (hybrid)
-                    if distance_threshold is not None:
-                        knn = RangeQuery(
-                            vector=embedding_vector,
-                            vector_field_name="vector",
-                            filter_expression=redis_filter,
-                            distance_threshold=float(distance_threshold),
-                            num_results=limit,
-                        )
-                    else:
-                        knn = VectorQuery(
-                            vector=embedding_vector,
-                            vector_field_name="vector",
-                            filter_expression=redis_filter,
-                            num_results=limit,
-                        )
-
-                    # Aggregate with APPLY/SORTBY boosted score via helper
-                    from datetime import UTC as _UTC, datetime as _dt
-
-                    from agent_memory_server.utils.redis_query import (
-                        RecencyAggregationQuery,
-                    )
-
-                    now_ts = int(_dt.now(_UTC).timestamp())
-                    agg = (
-                        RecencyAggregationQuery.from_vector_query(
-                            knn, filter_expression=redis_filter
-                        )
-                        .load_default_fields()
-                        .apply_recency(now_ts=now_ts, params=recency_params or {})
-                        .sort_by_boosted_desc()
-                        .paginate(offset, limit)
-                    )
-
-                    raw = (
-                        await index.aaggregate(agg)
-                        if hasattr(index, "aaggregate")
-                        else index.aggregate(agg)  # type: ignore
-                    )
-
-                    rows = getattr(raw, "rows", raw) or []
-                    memory_results: list[MemoryRecordResult] = []
-                    for row in rows:
-                        fields = getattr(row, "__dict__", None) or row
-                        metadata = {
-                            k: fields.get(k)
-                            for k in [
-                                "id_",
-                                "session_id",
-                                "user_id",
-                                "namespace",
-                                "created_at",
-                                "last_accessed",
-                                "updated_at",
-                                "pinned",
-                                "access_count",
-                                "topics",
-                                "entities",
-                                "memory_hash",
-                                "discrete_memory_extracted",
-                                "memory_type",
-                                "persisted_at",
-                                "extracted_from",
-                                "event_date",
-                            ]
-                            if k in fields
-                        }
-                        text_val = fields.get("text", "")
-                        score = fields.get("__vector_score", 1.0) or 1.0
-                        doc_obj = Document(page_content=text_val, metadata=metadata)
-                        memory_results.append(
-                            self.document_to_memory(doc_obj, float(score))
-                        )
-
-                    next_offset = (
-                        offset + limit if len(memory_results) == limit else None
-                    )
-                    return MemoryRecordResults(
-                        memories=memory_results[:limit],
-                        total=offset + len(memory_results),
-                        next_offset=next_offset,
-                    )
+                return await self._search_with_redis_aggregation(
+                    query=query,
+                    redis_filter=redis_filter,
+                    limit=limit,
+                    offset=offset,
+                    distance_threshold=distance_threshold,
+                    recency_params=recency_params,
+                )
             except Exception as e:
                 logger.warning(
                     f"RedisVL DB-level recency search failed; falling back to client-side path: {e}"

tests/test_forgetting.py

Lines changed: 0 additions & 1 deletion
@@ -1,6 +1,5 @@
 from datetime import UTC, datetime, timedelta
 
-# TDD: These helpers/functions will be implemented in agent_memory_server.long_term_memory
 from agent_memory_server.long_term_memory import (
     rerank_with_recency,  # new: pure function
     score_recency,  # new: pure function
