Commit e447288

feat(redis): AggregateQuery with KNN + APPLY + SORTBY for server_side_recency; formatting fixes
1 parent a6d1961 commit e447288

File tree

1 file changed: +85 additions, -99 deletions

agent_memory_server/vectorstore_adapter.py

Lines changed: 85 additions & 99 deletions
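For orientation before the diff: the APPLY chain added below computes a recency-boosted score entirely inside Redis. Per-document age in days is turned into exponentially decaying freshness and novelty terms, blended into a recency value, and mixed with the vector similarity. The following is a minimal client-side sketch of the same arithmetic, not code from the repository; the standalone function name `boosted_score` is illustrative, and the default weights mirror the `recency_params` fallbacks used in the diff (w_sem=0.8, w_recency=0.2, wf=0.6, wa=0.4, half-lives of 7 and 30 days).

```python
import time


def boosted_score(
    vector_distance: float,   # value Redis returns as __vector_score (0..2 for cosine)
    last_accessed_ts: float,  # unix seconds, as indexed in last_accessed
    created_at_ts: float,     # unix seconds, as indexed in created_at
    *,
    w_sem: float = 0.8,
    w_recency: float = 0.2,
    wf: float = 0.6,
    wa: float = 0.4,
    half_life_last_access_days: float = 7.0,
    half_life_created_days: float = 30.0,
    now_ts: float | None = None,
) -> float:
    """Client-side equivalent (sketch) of the server-side APPLY expressions."""
    now_ts = time.time() if now_ts is None else now_ts
    days_since_access = max(0.0, (now_ts - last_accessed_ts) / 86400.0)
    days_since_created = max(0.0, (now_ts - created_at_ts) / 86400.0)
    freshness = 2.0 ** (-days_since_access / half_life_last_access_days)
    novelty = 2.0 ** (-days_since_created / half_life_created_days)
    recency = wf * freshness + wa * novelty
    sim = 1.0 - (vector_distance / 2.0)  # same mapping as "1-(@__vector_score/2)"
    return w_sem * sim + w_recency * recency
```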
@@ -885,76 +885,109 @@ async def search_memories(
         # If server-side recency is requested, attempt RedisVL query first (DB-level path)
         if server_side_recency:
             try:
-                from redisvl.query import RangeQuery, VectorQuery
+                from datetime import UTC as _UTC, datetime as _dt
+
+                from redisvl.query import AggregateQuery, RangeQuery, VectorQuery
 
                 index = getattr(self.vectorstore, "_index", None)
                 if index is not None:
                     # Embed the query text to vector
                     embedding_vector = self.embeddings.embed_query(query)
 
-                    # Collect fields we need back from Redis
-                    return_fields = [
-                        "id_",
-                        "session_id",
-                        "user_id",
-                        "namespace",
-                        "created_at",
-                        "last_accessed",
-                        "updated_at",
-                        "pinned",
-                        "access_count",
-                        "topics",
-                        "entities",
-                        "memory_hash",
-                        "discrete_memory_extracted",
-                        "memory_type",
-                        "persisted_at",
-                        "extracted_from",
-                        "event_date",
-                        "text",
-                    ]
-
+                    # Build base KNN query (hybrid)
                     if distance_threshold is not None:
-                        vq = RangeQuery(
+                        knn = RangeQuery(
                             vector=embedding_vector,
                             vector_field_name="vector",
-                            return_fields=return_fields,
                             filter_expression=redis_filter,
                             distance_threshold=float(distance_threshold),
                             k=limit,
                         )
                     else:
-                        vq = VectorQuery(
+                        knn = VectorQuery(
                             vector=embedding_vector,
                             vector_field_name="vector",
-                            return_fields=return_fields,
                             filter_expression=redis_filter,
                             k=limit,
                         )
 
-                    # Apply RedisVL paging instead of manual slicing
-                    from contextlib import suppress
-
-                    with suppress(Exception):
-                        vq.paging(offset, limit)
+                    # Aggregate with APPLY/SORTBY boosted score
+                    agg = AggregateQuery(knn.query, filter_expression=redis_filter)
+                    agg.load(
+                        [
+                            "id_",
+                            "session_id",
+                            "user_id",
+                            "namespace",
+                            "created_at",
+                            "last_accessed",
+                            "updated_at",
+                            "pinned",
+                            "access_count",
+                            "topics",
+                            "entities",
+                            "memory_hash",
+                            "discrete_memory_extracted",
+                            "memory_type",
+                            "persisted_at",
+                            "extracted_from",
+                            "event_date",
+                            "text",
+                            "__vector_score",
+                        ]
+                    )
 
-                    # Execute via AsyncSearchIndex if available
-                    if hasattr(index, "asearch"):
-                        raw = await index.asearch(vq)
-                    else:
-                        raw = index.search(vq)  # type: ignore
+                    now_ts = int(_dt.now(_UTC).timestamp())
+                    w_sem = (
+                        float(recency_params.get("w_sem", 0.8))
+                        if recency_params
+                        else 0.8
+                    )
+                    w_rec = (
+                        float(recency_params.get("w_recency", 0.2))
+                        if recency_params
+                        else 0.2
+                    )
+                    wf = float(recency_params.get("wf", 0.6)) if recency_params else 0.6
+                    wa = float(recency_params.get("wa", 0.4)) if recency_params else 0.4
+                    hl_la = (
+                        float(recency_params.get("half_life_last_access_days", 7.0))
+                        if recency_params
+                        else 7.0
+                    )
+                    hl_cr = (
+                        float(recency_params.get("half_life_created_days", 30.0))
+                        if recency_params
+                        else 30.0
+                    )
 
-                    # raw.docs is a list of documents with .fields; handle both dict and attrs
-                    docs = getattr(raw, "docs", raw) or []
+                    agg.apply(
+                        f"max(0, ({now_ts} - @last_accessed)/86400.0)",
+                        AS="days_since_access",
+                    ).apply(
+                        f"max(0, ({now_ts} - @created_at)/86400.0)",
+                        AS="days_since_created",
+                    ).apply(
+                        f"pow(2, -@days_since_access/{hl_la})", AS="freshness"
+                    ).apply(
+                        f"pow(2, -@days_since_created/{hl_cr})", AS="novelty"
+                    ).apply(f"{wf}*@freshness+{wa}*@novelty", AS="recency").apply(
+                        "1-(@__vector_score/2)", AS="sim"
+                    ).apply(f"{w_sem}*@sim+{w_rec}*@recency", AS="boosted_score")
+
+                    agg.sort_by([("boosted_score", "DESC")])
+                    agg.limit(offset, limit)
+
+                    raw = (
+                        await index.aaggregate(agg)
+                        if hasattr(index, "aaggregate")
+                        else index.aggregate(agg)  # type: ignore
+                    )
 
+                    rows = getattr(raw, "rows", raw) or []
                     memory_results: list[MemoryRecordResult] = []
-                    for doc in docs:
-                        fields = (
-                            getattr(doc, "fields", None)
-                            or getattr(doc, "__dict__", {})
-                            or doc
-                        )
-                        # Build a Document-like structure
+                    for row in rows:
+                        fields = getattr(row, "__dict__", None) or row
                         metadata = {
                             k: fields.get(k)
                             for k in [
@@ -979,65 +1012,18 @@ async def search_memories(
                             if k in fields
                         }
                         text_val = fields.get("text", "")
-                        score = fields.get("__vector_score", None)
-                        if score is None:
-                            # Fallback: assume perfect relevance if score missing
-                            score = 1.0
-                        # Convert to Document and then to MemoryRecordResult using helper
+                        score = fields.get("__vector_score", 1.0) or 1.0
                         doc_obj = Document(page_content=text_val, metadata=metadata)
                         memory_results.append(
                             self.document_to_memory(doc_obj, float(score))
                         )
-                        if len(memory_results) >= limit:
-                            break
-
-                    # Adapter-level recency rerank for consistency
-                    if memory_results:
-                        try:
-                            from datetime import UTC as _UTC, datetime as _dt
-
-                            from agent_memory_server.long_term_memory import (
-                                rerank_with_recency,
-                            )
-
-                            now = _dt.now(_UTC)
-                            params = {
-                                "w_sem": float(recency_params.get("w_sem", 0.8))
-                                if recency_params
-                                else 0.8,
-                                "w_recency": float(recency_params.get("w_recency", 0.2))
-                                if recency_params
-                                else 0.2,
-                                "wf": float(recency_params.get("wf", 0.6))
-                                if recency_params
-                                else 0.6,
-                                "wa": float(recency_params.get("wa", 0.4))
-                                if recency_params
-                                else 0.4,
-                                "half_life_last_access_days": float(
-                                    recency_params.get(
-                                        "half_life_last_access_days", 7.0
-                                    )
-                                )
-                                if recency_params
-                                else 7.0,
-                                "half_life_created_days": float(
-                                    recency_params.get("half_life_created_days", 30.0)
-                                )
-                                if recency_params
-                                else 30.0,
-                            }
-                            memory_results = rerank_with_recency(
-                                memory_results, now=now, params=params
-                            )
-                        except Exception:
-                            pass
-
-                    total_docs = len(docs) if docs else 0
-                    next_offset = offset + limit if total_docs == limit else None
+
+                    next_offset = (
+                        offset + limit if len(memory_results) == limit else None
+                    )
                     return MemoryRecordResults(
                         memories=memory_results[:limit],
-                        total=offset + total_docs,
+                        total=offset + len(memory_results),
                         next_offset=next_offset,
                     )
             except Exception as e:
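As a quick sanity check on the default weights, here is a hypothetical use of the `boosted_score` sketch shown before the diff; the timestamps and distance are made up for illustration.

```python
import time

now = time.time()
score = boosted_score(
    vector_distance=0.4,                # sim = 1 - 0.4/2 = 0.8
    last_accessed_ts=now - 7 * 86400,   # freshness = 2 ** (-7/7) = 0.5
    created_at_ts=now - 30 * 86400,     # novelty = 2 ** (-30/30) = 0.5
    now_ts=now,
)
# recency = 0.6 * 0.5 + 0.4 * 0.5 = 0.5
# boosted = 0.8 * 0.8 + 0.2 * 0.5 = 0.74
assert abs(score - 0.74) < 1e-9
```

Under these defaults, recency contributes at most 0.2 to the final score, so semantic similarity still dominates the ordering; sorting on boosted_score with SORTBY and paging with LIMIT inside Redis is what lets this path drop the Python-side rerank_with_recency pass removed in the diff.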
