Skip to content

Commit eac6268

Browse files
committed
refactor(redis): integrate RecencyAggregationQuery; fix AggregationQuery usage; clean imports
1 parent e447288 commit eac6268

File tree

2 files changed

+95
-66
lines changed

2 files changed

+95
-66
lines changed
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
from __future__ import annotations
2+
3+
from typing import Any
4+
5+
from redisvl.query import AggregationQuery, RangeQuery, VectorQuery
6+
7+
8+
class RecencyAggregationQuery(AggregationQuery):
9+
"""AggregationQuery helper for KNN + recency boosting with APPLY/SORTBY and paging.
10+
11+
Usage:
12+
- Build a VectorQuery or RangeQuery (hybrid filter expression allowed)
13+
- Call RecencyAggregationQuery.from_vector_query(...)
14+
- Chain .load_default_fields().apply_recency(params).sort_by_boosted_desc().paginate(offset, limit)
15+
"""
16+
17+
DEFAULT_RETURN_FIELDS = [
18+
"id_",
19+
"session_id",
20+
"user_id",
21+
"namespace",
22+
"created_at",
23+
"last_accessed",
24+
"updated_at",
25+
"pinned",
26+
"access_count",
27+
"topics",
28+
"entities",
29+
"memory_hash",
30+
"discrete_memory_extracted",
31+
"memory_type",
32+
"persisted_at",
33+
"extracted_from",
34+
"event_date",
35+
"text",
36+
"__vector_score",
37+
]
38+
39+
@classmethod
40+
def from_vector_query(
41+
cls,
42+
vq: VectorQuery | RangeQuery,
43+
*,
44+
filter_expression: Any | None = None,
45+
) -> RecencyAggregationQuery:
46+
return cls(vq.query, filter_expression=filter_expression)
47+
48+
def load_default_fields(self) -> RecencyAggregationQuery:
49+
self.load(self.DEFAULT_RETURN_FIELDS)
50+
return self
51+
52+
def apply_recency(
53+
self, *, now_ts: int, params: dict[str, Any] | None = None
54+
) -> RecencyAggregationQuery:
55+
params = params or {}
56+
w_sem = float(params.get("w_sem", 0.8))
57+
w_rec = float(params.get("w_recency", 0.2))
58+
wf = float(params.get("wf", 0.6))
59+
wa = float(params.get("wa", 0.4))
60+
hl_la = float(params.get("half_life_last_access_days", 7.0))
61+
hl_cr = float(params.get("half_life_created_days", 30.0))
62+
63+
self.apply(
64+
f"max(0, ({now_ts} - @last_accessed)/86400.0)", AS="days_since_access"
65+
).apply(
66+
f"max(0, ({now_ts} - @created_at)/86400.0)", AS="days_since_created"
67+
).apply(f"pow(2, -@days_since_access/{hl_la})", AS="freshness").apply(
68+
f"pow(2, -@days_since_created/{hl_cr})", AS="novelty"
69+
).apply(f"{wf}*@freshness+{wa}*@novelty", AS="recency").apply(
70+
"1-(@__vector_score/2)", AS="sim"
71+
).apply(f"{w_sem}*@sim+{w_rec}*@recency", AS="boosted_score")
72+
73+
return self
74+
75+
def sort_by_boosted_desc(self) -> RecencyAggregationQuery:
76+
self.sort_by([("boosted_score", "DESC")])
77+
return self
78+
79+
def paginate(self, offset: int, limit: int) -> RecencyAggregationQuery:
80+
self.limit(offset, limit)
81+
return self

agent_memory_server/vectorstore_adapter.py

Lines changed: 14 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from langchain_core.embeddings import Embeddings
1414
from langchain_core.vectorstores import VectorStore
1515
from langchain_redis.vectorstores import RedisVectorStore
16+
from redisvl.query import RangeQuery, VectorQuery
1617

1718
from agent_memory_server.filters import (
1819
CreatedAt,
@@ -885,10 +886,6 @@ async def search_memories(
885886
# If server-side recency is requested, attempt RedisVL query first (DB-level path)
886887
if server_side_recency:
887888
try:
888-
from datetime import UTC as _UTC, datetime as _dt
889-
890-
from redisvl.query import AggregateQuery, RangeQuery, VectorQuery
891-
892889
index = getattr(self.vectorstore, "_index", None)
893890
if index is not None:
894891
# Embed the query text to vector
@@ -911,73 +908,24 @@ async def search_memories(
911908
k=limit,
912909
)
913910

914-
# Aggregate with APPLY/SORTBY boosted score
915-
agg = AggregateQuery(knn.query, filter_expression=redis_filter)
916-
agg.load(
917-
[
918-
"id_",
919-
"session_id",
920-
"user_id",
921-
"namespace",
922-
"created_at",
923-
"last_accessed",
924-
"updated_at",
925-
"pinned",
926-
"access_count",
927-
"topics",
928-
"entities",
929-
"memory_hash",
930-
"discrete_memory_extracted",
931-
"memory_type",
932-
"persisted_at",
933-
"extracted_from",
934-
"event_date",
935-
"text",
936-
"__vector_score",
937-
]
911+
# Aggregate with APPLY/SORTBY boosted score via helper
912+
from datetime import UTC as _UTC, datetime as _dt
913+
914+
from agent_memory_server.utils.redis_query import (
915+
RecencyAggregationQuery,
938916
)
939917

940918
now_ts = int(_dt.now(_UTC).timestamp())
941-
w_sem = (
942-
float(recency_params.get("w_sem", 0.8))
943-
if recency_params
944-
else 0.8
945-
)
946-
w_rec = (
947-
float(recency_params.get("w_recency", 0.2))
948-
if recency_params
949-
else 0.2
950-
)
951-
wf = float(recency_params.get("wf", 0.6)) if recency_params else 0.6
952-
wa = float(recency_params.get("wa", 0.4)) if recency_params else 0.4
953-
hl_la = (
954-
float(recency_params.get("half_life_last_access_days", 7.0))
955-
if recency_params
956-
else 7.0
957-
)
958-
hl_cr = (
959-
float(recency_params.get("half_life_created_days", 30.0))
960-
if recency_params
961-
else 30.0
919+
agg = (
920+
RecencyAggregationQuery.from_vector_query(
921+
knn, filter_expression=redis_filter
922+
)
923+
.load_default_fields()
924+
.apply_recency(now_ts=now_ts, params=recency_params or {})
925+
.sort_by_boosted_desc()
926+
.paginate(offset, limit)
962927
)
963928

964-
agg.apply(
965-
f"max(0, ({now_ts} - @last_accessed)/86400.0)",
966-
AS="days_since_access",
967-
).apply(
968-
f"max(0, ({now_ts} - @created_at)/86400.0)",
969-
AS="days_since_created",
970-
).apply(
971-
f"pow(2, -@days_since_access/{hl_la})", AS="freshness"
972-
).apply(
973-
f"pow(2, -@days_since_created/{hl_cr})", AS="novelty"
974-
).apply(f"{wf}*@freshness+{wa}*@novelty", AS="recency").apply(
975-
"1-(@__vector_score/2)", AS="sim"
976-
).apply(f"{w_sem}*@sim+{w_rec}*@recency", AS="boosted_score")
977-
978-
agg.sort_by([("boosted_score", "DESC")])
979-
agg.limit(offset, limit)
980-
981929
raw = (
982930
await index.aaggregate(agg)
983931
if hasattr(index, "aaggregate")

0 commit comments

Comments
 (0)