13 | 13 | from langchain_core.embeddings import Embeddings
14 | 14 | from langchain_core.vectorstores import VectorStore
15 | 15 | from langchain_redis.vectorstores import RedisVectorStore
16 | | -from redisvl.query import RangeQuery, VectorQuery
17 | 16 |
18 | 17 | from agent_memory_server.filters import (
19 | 18 |     CreatedAt,
@@ -837,6 +836,127 @@ async def update_memories(self, memories: list[MemoryRecord]) -> int:
837 | 836 |         added = await self.add_memories(memories)
838 | 837 |         return len(added)
839 | 838 |
| 839 | +    def _get_vectorstore_index(self):
| 840 | +        """Safely access the underlying RedisVL index from the vectorstore.
| 841 | +
| 842 | +        Returns:
| 843 | +            RedisVL SearchIndex or None if not available
| 844 | +        """
| 845 | +        return getattr(self.vectorstore, "_index", None)
| 846 | +
| 847 | +    async def _search_with_redis_aggregation(
| 848 | +        self,
| 849 | +        query: str,
| 850 | +        redis_filter,
| 851 | +        limit: int,
| 852 | +        offset: int,
| 853 | +        distance_threshold: float | None,
| 854 | +        recency_params: dict | None,
| 855 | +    ) -> MemoryRecordResults:
| 856 | +        """Perform server-side Redis aggregation search with recency scoring.
| 857 | +
| 858 | +        Args:
| 859 | +            query: Search query text
| 860 | +            redis_filter: Redis filter expression
| 861 | +            limit: Maximum number of results
| 862 | +            offset: Offset for pagination
| 863 | +            distance_threshold: Distance threshold for range queries
| 864 | +            recency_params: Parameters for recency scoring
| 865 | +
| 866 | +        Returns:
| 867 | +            MemoryRecordResults with server-side scored results
| 868 | +
| 869 | +        Raises:
| 870 | +            Exception: If Redis aggregation fails (caller should handle fallback)
| 871 | +        """
| 872 | +        from datetime import UTC as _UTC, datetime as _dt
| 873 | +
| 874 | +        from langchain_core.documents import Document
| 875 | +        from redisvl.query import RangeQuery, VectorQuery
| 876 | +
| 877 | +        from agent_memory_server.utils.redis_query import RecencyAggregationQuery
| 878 | +
| 879 | +        index = self._get_vectorstore_index()
| 880 | +        if index is None:
| 881 | +            raise Exception("RedisVL index not available")
| 882 | +
| 883 | +        # Embed the query text to vector
| 884 | +        embedding_vector = self.embeddings.embed_query(query)
| 885 | +
| 886 | +        # Build base KNN query (hybrid)
| 887 | +        if distance_threshold is not None:
| 888 | +            knn = RangeQuery(
| 889 | +                vector=embedding_vector,
| 890 | +                vector_field_name="vector",
| 891 | +                filter_expression=redis_filter,
| 892 | +                distance_threshold=float(distance_threshold),
| 893 | +                num_results=limit,
| 894 | +            )
| 895 | +        else:
| 896 | +            knn = VectorQuery(
| 897 | +                vector=embedding_vector,
| 898 | +                vector_field_name="vector",
| 899 | +                filter_expression=redis_filter,
| 900 | +                num_results=limit,
| 901 | +            )
| 902 | +
| 903 | +        # Aggregate with APPLY/SORTBY boosted score via helper
| 904 | +        now_ts = int(_dt.now(_UTC).timestamp())
| 905 | +        agg = (
| 906 | +            RecencyAggregationQuery.from_vector_query(
| 907 | +                knn, filter_expression=redis_filter
| 908 | +            )
| 909 | +            .load_default_fields()
| 910 | +            .apply_recency(now_ts=now_ts, params=recency_params or {})
| 911 | +            .sort_by_boosted_desc()
| 912 | +            .paginate(offset, limit)
| 913 | +        )
| 914 | +
| 915 | +        raw = (
| 916 | +            await index.aaggregate(agg)
| 917 | +            if hasattr(index, "aaggregate")
| 918 | +            else index.aggregate(agg)  # type: ignore
| 919 | +        )
| 920 | +
| 921 | +        rows = getattr(raw, "rows", raw) or []
| 922 | +        memory_results: list[MemoryRecordResult] = []
| 923 | +        for row in rows:
| 924 | +            fields = getattr(row, "__dict__", None) or row
| 925 | +            metadata = {
| 926 | +                k: fields.get(k)
| 927 | +                for k in [
| 928 | +                    "id_",
| 929 | +                    "session_id",
| 930 | +                    "user_id",
| 931 | +                    "namespace",
| 932 | +                    "created_at",
| 933 | +                    "last_accessed",
| 934 | +                    "updated_at",
| 935 | +                    "pinned",
| 936 | +                    "access_count",
| 937 | +                    "topics",
| 938 | +                    "entities",
| 939 | +                    "memory_hash",
| 940 | +                    "discrete_memory_extracted",
| 941 | +                    "memory_type",
| 942 | +                    "persisted_at",
| 943 | +                    "extracted_from",
| 944 | +                    "event_date",
| 945 | +                ]
| 946 | +                if k in fields
| 947 | +            }
| 948 | +            text_val = fields.get("text", "")
| 949 | +            score = fields.get("__vector_score", 1.0) or 1.0
| 950 | +            doc_obj = Document(page_content=text_val, metadata=metadata)
| 951 | +            memory_results.append(self.document_to_memory(doc_obj, float(score)))
| 952 | +
| 953 | +        next_offset = offset + limit if len(memory_results) == limit else None
| 954 | +        return MemoryRecordResults(
| 955 | +            memories=memory_results[:limit],
| 956 | +            total=offset + len(memory_results),
| 957 | +            next_offset=next_offset,
| 958 | +        )
| 959 | +
840 | 960 |     async def search_memories(
841 | 961 |         self,
842 | 962 |         query: str,
@@ -900,94 +1020,14 @@ async def search_memories(
900 | 1020 |         # If server-side recency is requested, attempt RedisVL query first (DB-level path)
901 | 1021 |         if server_side_recency:
902 | 1022 |             try:
903 | | -                index = getattr(self.vectorstore, "_index", None)
904 | | -                if index is not None:
905 | | -                    # Embed the query text to vector
906 | | -                    embedding_vector = self.embeddings.embed_query(query)
907 | | -
908 | | -                    # Build base KNN query (hybrid)
909 | | -                    if distance_threshold is not None:
910 | | -                        knn = RangeQuery(
911 | | -                            vector=embedding_vector,
912 | | -                            vector_field_name="vector",
913 | | -                            filter_expression=redis_filter,
914 | | -                            distance_threshold=float(distance_threshold),
915 | | -                            num_results=limit,
916 | | -                        )
917 | | -                    else:
918 | | -                        knn = VectorQuery(
919 | | -                            vector=embedding_vector,
920 | | -                            vector_field_name="vector",
921 | | -                            filter_expression=redis_filter,
922 | | -                            num_results=limit,
923 | | -                        )
924 | | -
925 | | -                    # Aggregate with APPLY/SORTBY boosted score via helper
926 | | -                    from datetime import UTC as _UTC, datetime as _dt
927 | | -
928 | | -                    from agent_memory_server.utils.redis_query import (
929 | | -                        RecencyAggregationQuery,
930 | | -                    )
931 | | -
932 | | -                    now_ts = int(_dt.now(_UTC).timestamp())
933 | | -                    agg = (
934 | | -                        RecencyAggregationQuery.from_vector_query(
935 | | -                            knn, filter_expression=redis_filter
936 | | -                        )
937 | | -                        .load_default_fields()
938 | | -                        .apply_recency(now_ts=now_ts, params=recency_params or {})
939 | | -                        .sort_by_boosted_desc()
940 | | -                        .paginate(offset, limit)
941 | | -                    )
942 | | -
943 | | -                    raw = (
944 | | -                        await index.aaggregate(agg)
945 | | -                        if hasattr(index, "aaggregate")
946 | | -                        else index.aggregate(agg)  # type: ignore
947 | | -                    )
948 | | -
949 | | -                    rows = getattr(raw, "rows", raw) or []
950 | | -                    memory_results: list[MemoryRecordResult] = []
951 | | -                    for row in rows:
952 | | -                        fields = getattr(row, "__dict__", None) or row
953 | | -                        metadata = {
954 | | -                            k: fields.get(k)
955 | | -                            for k in [
956 | | -                                "id_",
957 | | -                                "session_id",
958 | | -                                "user_id",
959 | | -                                "namespace",
960 | | -                                "created_at",
961 | | -                                "last_accessed",
962 | | -                                "updated_at",
963 | | -                                "pinned",
964 | | -                                "access_count",
965 | | -                                "topics",
966 | | -                                "entities",
967 | | -                                "memory_hash",
968 | | -                                "discrete_memory_extracted",
969 | | -                                "memory_type",
970 | | -                                "persisted_at",
971 | | -                                "extracted_from",
972 | | -                                "event_date",
973 | | -                            ]
974 | | -                            if k in fields
975 | | -                        }
976 | | -                        text_val = fields.get("text", "")
977 | | -                        score = fields.get("__vector_score", 1.0) or 1.0
978 | | -                        doc_obj = Document(page_content=text_val, metadata=metadata)
979 | | -                        memory_results.append(
980 | | -                            self.document_to_memory(doc_obj, float(score))
981 | | -                        )
982 | | -
983 | | -                    next_offset = (
984 | | -                        offset + limit if len(memory_results) == limit else None
985 | | -                    )
986 | | -                    return MemoryRecordResults(
987 | | -                        memories=memory_results[:limit],
988 | | -                        total=offset + len(memory_results),
989 | | -                        next_offset=next_offset,
990 | | -                    )
| 1023 | +                return await self._search_with_redis_aggregation(
| 1024 | +                    query=query,
| 1025 | +                    redis_filter=redis_filter,
| 1026 | +                    limit=limit,
| 1027 | +                    offset=offset,
| 1028 | +                    distance_threshold=distance_threshold,
| 1029 | +                    recency_params=recency_params,
| 1030 | +                )
991 | 1031 |             except Exception as e:
992 | 1032 |                 logger.warning(
993 | 1033 |                     f"RedisVL DB-level recency search failed; falling back to client-side path: {e}"
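Pulling the aggregation path into `_search_with_redis_aggregation`, gated on `_get_vectorstore_index`, makes the guard clause easy to exercise in isolation. A minimal test sketch, assuming pytest-asyncio is available and that `adapter` is a fixture yielding the vectorstore adapter instance that defines these methods (the fixture name is an assumption, not part of this diff):

```python
# Hypothetical sketch: the `adapter` fixture and pytest-asyncio usage are assumptions.
from unittest.mock import patch

import pytest


@pytest.mark.asyncio
async def test_aggregation_requires_redisvl_index(adapter):
    # Stub the accessor added in this diff so the helper sees no RedisVL index.
    with patch.object(adapter, "_get_vectorstore_index", return_value=None):
        with pytest.raises(Exception, match="RedisVL index not available"):
            await adapter._search_with_redis_aggregation(
                query="anything",
                redis_filter=None,
                limit=10,
                offset=0,
                distance_threshold=None,
                recency_params=None,
            )
```

Because the helper raises instead of returning silently, `search_memories` still reaches its client-side fallback whenever the index is missing or the aggregation fails.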
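The helper's return shape (`memories`, `total`, `next_offset`, with `next_offset` set only when a full page came back) also suggests a simple pagination loop on the caller side. A sketch under the assumption that `search_memories` accepts `server_side_recency`, `limit`, and `offset` as keyword arguments; only `query: str` is visible in the hunk header, the rest are inferred from the locals used above:

```python
# Hypothetical usage sketch: keyword names other than `query` are inferred, not confirmed.
async def collect_recent_memories(adapter, text: str, page_size: int = 20):
    collected = []
    offset = 0
    while True:
        page = await adapter.search_memories(
            query=text,
            server_side_recency=True,  # prefer the DB-level aggregation path
            limit=page_size,
            offset=offset,
        )
        collected.extend(page.memories)
        if page.next_offset is None:  # short page: nothing further to fetch
            break
        offset = page.next_offset
    return collected
```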