Commit 83d5abb

abrookins and claude committed
refactor: improve code quality and remove duplication in vectorstore adapter
- Remove duplicate client-side reranking logic by extracting a shared helper method
- Use SECONDS_PER_DAY constant instead of magic number 86400.0 in redis_query.py
- Add type annotations and improve docstrings for helper methods
- Remove stale TODO comments and improve code documentation
- Remove duplicate _parse_list_field method in RedisVectorStoreAdapter
- Clean up comment formatting and remove unnecessary complexity

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <[email protected]>
1 parent 6c88daf commit 83d5abb

File tree

2 files changed: 76 additions & 103 deletions


agent_memory_server/utils/redis_query.py

Lines changed: 9 additions & 2 deletions
@@ -4,6 +4,9 @@
 
 from redisvl.query import AggregationQuery, RangeQuery, VectorQuery
 
+# Import constants from long_term_memory module
+from agent_memory_server.long_term_memory import SECONDS_PER_DAY
+
 
 class RecencyAggregationQuery(AggregationQuery):
     """AggregationQuery helper for KNN + recency boosting with APPLY/SORTBY and paging.
@@ -64,8 +67,12 @@ def apply_recency(
         half_life_access = float(params.get("half_life_last_access_days", 7.0))
         half_life_created = float(params.get("half_life_created_days", 30.0))
 
-        self.apply(days_since_access=f"max(0, ({now_ts} - @last_accessed)/86400.0)")
-        self.apply(days_since_created=f"max(0, ({now_ts} - @created_at)/86400.0)")
+        self.apply(
+            days_since_access=f"max(0, ({now_ts} - @last_accessed)/{SECONDS_PER_DAY})"
+        )
+        self.apply(
+            days_since_created=f"max(0, ({now_ts} - @created_at)/{SECONDS_PER_DAY})"
+        )
         self.apply(freshness=f"pow(2, -@days_since_access/{half_life_access})")
         self.apply(novelty=f"pow(2, -@days_since_created/{half_life_created})")
         self.apply(recency=f"{freshness_weight}*@freshness+{novelty_weight}*@novelty")
agent_memory_server/vectorstore_adapter.py

Lines changed: 67 additions & 101 deletions
@@ -131,7 +131,6 @@ def convert_filters_to_backend_format(
         """Convert filter objects to backend format for LangChain vectorstores."""
         filter_dict: dict[str, Any] = {}
 
-        # TODO: Seems like we could take *args filters and decide what to do based on type.
         # Apply tag/string filters using the helper function
         self.process_tag_filter(session_id, "session_id", filter_dict)
         self.process_tag_filter(user_id, "user_id", filter_dict)
@@ -260,11 +259,17 @@ async def count_memories(
         """
         pass
 
-    def _parse_list_field(self, field_value):
+    def _parse_list_field(self, field_value: Any) -> list[str]:
         """Parse a field that might be a list, comma-separated string, or None.
 
         Centralized here so both LangChain and Redis adapters can normalize
         metadata fields like topics/entities/extracted_from.
+
+        Args:
+            field_value: Value that may be a list, string, or None
+
+        Returns:
+            List of strings, empty list if field_value is falsy
         """
         if not field_value:
             return []
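
For quick reference, the consolidated helper's contract (its body matches the duplicate removed from RedisVectorStoreAdapter at the bottom of this diff) can be restated as a standalone function:

```python
# Standalone re-statement of the adapter helper's behavior; the real method
# lives on the adapter base class, not in this snippet.
from typing import Any


def parse_list_field(field_value: Any) -> list[str]:
    """Normalize a value that may be a list, a comma-separated string, or None."""
    if not field_value:
        return []
    if isinstance(field_value, list):
        return field_value
    if isinstance(field_value, str):
        return field_value.split(",") if field_value else []
    return []


assert parse_list_field("python,redis") == ["python", "redis"]
assert parse_list_field(["python", "redis"]) == ["python", "redis"]
assert parse_list_field(None) == []
assert parse_list_field(42) == []  # unrecognized types fall through to an empty list
```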
@@ -414,6 +419,56 @@ def generate_memory_hash(self, memory: MemoryRecord) -> str:
 
         return generate_memory_hash(memory)
 
+    def _apply_client_side_recency_reranking(
+        self, memory_results: list[MemoryRecordResult], recency_params: dict | None
+    ) -> list[MemoryRecordResult]:
+        """Apply client-side recency reranking as a fallback when server-side is not available.
+
+        Args:
+            memory_results: List of memory results to rerank
+            recency_params: Parameters for recency scoring
+
+        Returns:
+            Reranked list of memory results
+        """
+        if not memory_results:
+            return memory_results
+
+        try:
+            from datetime import UTC as _UTC, datetime as _dt
+
+            from agent_memory_server.long_term_memory import rerank_with_recency
+
+            now = _dt.now(_UTC)
+            params = {
+                "semantic_weight": float(recency_params.get("semantic_weight", 0.8))
+                if recency_params
+                else 0.8,
+                "recency_weight": float(recency_params.get("recency_weight", 0.2))
+                if recency_params
+                else 0.2,
+                "freshness_weight": float(recency_params.get("freshness_weight", 0.6))
+                if recency_params
+                else 0.6,
+                "novelty_weight": float(recency_params.get("novelty_weight", 0.4))
+                if recency_params
+                else 0.4,
+                "half_life_last_access_days": float(
+                    recency_params.get("half_life_last_access_days", 7.0)
+                )
+                if recency_params
+                else 7.0,
+                "half_life_created_days": float(
+                    recency_params.get("half_life_created_days", 30.0)
+                )
+                if recency_params
+                else 30.0,
+            }
+            return rerank_with_recency(memory_results, now=now, params=params)
+        except Exception as e:
+            logger.warning(f"Client-side recency reranking failed: {e}")
+            return memory_results
+
     def _convert_filters_to_backend_format(
         self,
         session_id: SessionId | None = None,
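
The parameter names above suggest that `rerank_with_recency` combines each result's semantic score with a recency score using `semantic_weight` and `recency_weight`; the exact logic is not shown in this diff. A rough, hypothetical sketch of that kind of blend:

```python
# Hypothetical blend implied by the parameter names; NOT the actual
# rerank_with_recency implementation, which this diff does not show.
def blended_score(semantic_score: float, recency: float, params: dict) -> float:
    semantic_weight = params.get("semantic_weight", 0.8)
    recency_weight = params.get("recency_weight", 0.2)
    return semantic_weight * semantic_score + recency_weight * recency


# A very similar but stale memory vs. a slightly less similar, recent one:
print(blended_score(0.95, 0.10, {}))  # 0.78
print(blended_score(0.80, 0.90, {}))  # 0.82 -- recency can change the ordering
```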
@@ -445,7 +500,6 @@ def _convert_filters_to_backend_format(
             Dictionary filter in format: {"field": {"$eq": "value"}} or None
         """
         processor = LangChainFilterProcessor(self.vectorstore)
-        # TODO: Seems like we could take *args and pass them to the processor
         filter_dict = processor.convert_filters_to_backend_format(
             session_id=session_id,
             user_id=user_id,
@@ -585,50 +639,10 @@ async def search_memories(
             memory_results.append(memory_result)
 
         # If recency requested but backend does not support DB-level, rerank here as a fallback
-        if server_side_recency and memory_results:
-            try:
-                from datetime import UTC as _UTC, datetime as _dt
-
-                from agent_memory_server.long_term_memory import rerank_with_recency
-
-                now = _dt.now(_UTC)
-                params = {
-                    "semantic_weight": float(
-                        recency_params.get("semantic_weight", 0.8)
-                    )
-                    if recency_params
-                    else 0.8,
-                    "recency_weight": float(
-                        recency_params.get("recency_weight", 0.2)
-                    )
-                    if recency_params
-                    else 0.2,
-                    "freshness_weight": float(
-                        recency_params.get("freshness_weight", 0.6)
-                    )
-                    if recency_params
-                    else 0.6,
-                    "novelty_weight": float(
-                        recency_params.get("novelty_weight", 0.4)
-                    )
-                    if recency_params
-                    else 0.4,
-                    "half_life_last_access_days": float(
-                        recency_params.get("half_life_last_access_days", 7.0)
-                    )
-                    if recency_params
-                    else 7.0,
-                    "half_life_created_days": float(
-                        recency_params.get("half_life_created_days", 30.0)
-                    )
-                    if recency_params
-                    else 30.0,
-                }
-                memory_results = rerank_with_recency(
-                    memory_results, now=now, params=params
-                )
-            except Exception:
-                pass
+        if server_side_recency:
+            memory_results = self._apply_client_side_recency_reranking(
+                memory_results, recency_params
+            )
 
         # Calculate next offset
         next_offset = offset + limit if len(docs_with_scores) > limit else None
@@ -844,7 +858,7 @@ async def update_memories(self, memories: list[MemoryRecord]) -> int:
         added = await self.add_memories(memories)
         return len(added)
 
-    def _get_vectorstore_index(self):
+    def _get_vectorstore_index(self) -> Any | None:
         """Safely access the underlying RedisVL index from the vectorstore.
 
         Returns:
@@ -1066,8 +1080,7 @@ async def search_memories(
         # Convert results to MemoryRecordResult objects
         memory_results = []
         for i, (doc, score) in enumerate(search_results):
-            # Apply offset - VectorStore doesn't support pagination...
-            # TODO: Implement pagination in RedisVectorStore as a kwarg.
+            # Apply offset - VectorStore doesn't support native pagination
             if i < offset:
                 continue
 
@@ -1120,48 +1133,11 @@ def parse_timestamp_to_datetime(timestamp_val):
             if len(memory_results) >= limit:
                 break
 
-        # Optional server-side recency-aware rerank (adapter-level fallback)
-        # If requested, re-rank using the same logic as server API's local reranking.
+        # Optional client-side recency-aware rerank (adapter-level fallback)
         if server_side_recency:
-            try:
-                from datetime import UTC as _UTC, datetime as _dt
-
-                from agent_memory_server.long_term_memory import rerank_with_recency
-
-                now = _dt.now(_UTC)
-                params = {
-                    "semantic_weight": float(recency_params.get("semantic_weight", 0.8))
-                    if recency_params
-                    else 0.8,
-                    "recency_weight": float(recency_params.get("recency_weight", 0.2))
-                    if recency_params
-                    else 0.2,
-                    "freshness_weight": float(
-                        recency_params.get("freshness_weight", 0.6)
-                    )
-                    if recency_params
-                    else 0.6,
-                    "novelty_weight": float(recency_params.get("novelty_weight", 0.4))
-                    if recency_params
-                    else 0.4,
-                    "half_life_last_access_days": float(
-                        recency_params.get("half_life_last_access_days", 7.0)
-                    )
-                    if recency_params
-                    else 7.0,
-                    "half_life_created_days": float(
-                        recency_params.get("half_life_created_days", 30.0)
-                    )
-                    if recency_params
-                    else 30.0,
-                }
-                memory_results = rerank_with_recency(
-                    memory_results, now=now, params=params
-                )
-            except Exception as e:
-                logger.warning(
-                    f"server_side_recency fallback rerank failed, returning base order: {e}"
-                )
+            memory_results = self._apply_client_side_recency_reranking(
+                memory_results, recency_params
+            )
 
         next_offset = offset + limit if len(search_results) > offset + limit else None
 
@@ -1171,16 +1147,6 @@ def parse_timestamp_to_datetime(timestamp_val):
             next_offset=next_offset,
         )
 
-    def _parse_list_field(self, field_value):
-        """Parse a field that might be a list, comma-separated string, or None."""
-        if not field_value:
-            return []
-        if isinstance(field_value, list):
-            return field_value
-        if isinstance(field_value, str):
-            return field_value.split(",") if field_value else []
-        return []
-
     async def delete_memories(self, memory_ids: list[str]) -> int:
         """Delete memories by their IDs using LangChain's RedisVectorStore."""
         if not memory_ids:
