Skip to content

Commit 3126c08

Browse files
committed
Clean up the vector store init options
1 parent db4ac50 commit 3126c08

File tree

5 files changed

+44
-13
lines changed

5 files changed

+44
-13
lines changed

agent_memory_server/config.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -78,9 +78,7 @@ class Settings(BaseSettings):
7878

7979
# Topic modeling
8080
topic_model_source: Literal["BERTopic", "LLM"] = "LLM"
81-
topic_model: str = (
82-
"MaartenGr/BERTopic_Wikipedia" # Use an LLM model name here if using LLM
83-
)
81+
topic_model: str = "gpt-4o-mini"
8482
enable_topic_extraction: bool = True
8583
top_k_topics: int = 3
8684

@@ -89,9 +87,11 @@ class Settings(BaseSettings):
8987
enable_ner: bool = True
9088

9189
# RedisVL Settings
90+
# TODO: Adapt to vector store settings
9291
redisvl_distance_metric: str = "COSINE"
9392
redisvl_vector_dimensions: str = "1536"
9493
redisvl_index_prefix: str = "memory_idx"
94+
redisvl_indexing_algorithm: str = "HNSW"
9595

9696
# Docket settings
9797
docket_name: str = "memory-server"

agent_memory_server/extraction.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
import json
22
import os
3-
from typing import Any
3+
from typing import TYPE_CHECKING, Any
44

55
import ulid
6-
from bertopic import BERTopic
76
from redis.asyncio.client import Redis
87
from tenacity.asyncio import AsyncRetrying
98
from tenacity.stop import stop_after_attempt
@@ -22,24 +21,30 @@
2221
from agent_memory_server.utils.redis import get_redis_conn
2322

2423

24+
if TYPE_CHECKING:
25+
from bertopic import BERTopic
26+
27+
2528
logger = get_logger(__name__)
2629

2730
# Set tokenizer parallelism environment variable
2831
os.environ["TOKENIZERS_PARALLELISM"] = "false"
2932

3033
# Global model instances
31-
_topic_model: BERTopic | None = None
34+
_topic_model: "BERTopic | None" = None
3235
_ner_model: Any | None = None
3336
_ner_tokenizer: Any | None = None
3437

3538

36-
def get_topic_model() -> BERTopic:
39+
def get_topic_model() -> "BERTopic":
3740
"""
3841
Get or initialize the BERTopic model.
3942
4043
Returns:
4144
The BERTopic model instance
4245
"""
46+
from bertopic import BERTopic
47+
4348
global _topic_model
4449
if _topic_model is None:
4550
# TODO: Expose this as a config option
@@ -112,7 +117,7 @@ async def extract_topics_llm(
112117
"""
113118
Extract topics from text using the LLM model.
114119
"""
115-
_client = client or await get_model_client(settings.generation_model)
120+
_client = client or await get_model_client(settings.topic_model)
116121
_num_topics = num_topics if num_topics is not None else settings.top_k_topics
117122

118123
prompt = f"""

agent_memory_server/long_term_memory.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -208,11 +208,19 @@ async def merge_memories_with_llm(memories: list[dict], llm_client: Any = None)
208208
# Fallback if the structure is different
209209
merged_text = str(response.choices[0])
210210

211+
def float_or_datetime(m: dict, key: str) -> float:
212+
val = m.get(key, time.time())
213+
if val is None:
214+
return time.time()
215+
if isinstance(val, datetime):
216+
return int(val.timestamp())
217+
return float(val)
218+
211219
# Use the earliest creation timestamp
212-
created_at = min(int(m.get("created_at", int(time.time()))) for m in memories)
220+
created_at = min(float_or_datetime(m, "created_at") for m in memories)
213221

214222
# Use the most recent last_accessed timestamp
215-
last_accessed = max(int(m.get("last_accessed", int(time.time()))) for m in memories)
223+
last_accessed = max(float_or_datetime(m, "last_accessed") for m in memories)
216224

217225
# Prefer non-empty namespace, user_id, session_id from memories
218226
namespace = next((m["namespace"] for m in memories if m.get("namespace")), None)
@@ -616,6 +624,7 @@ async def index_long_term_memories(
616624

617625
# Add the memory to be indexed if not a pure duplicate
618626
if not was_deduplicated:
627+
current_memory.discrete_memory_extracted = "t"
619628
processed_memories.append(current_memory)
620629
else:
621630
processed_memories = memories

agent_memory_server/vectorstore_adapter.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,10 +123,12 @@ def convert_filters_to_backend_format(
123123
last_accessed: LastAccessed | None = None,
124124
event_date: EventDate | None = None,
125125
memory_hash: MemoryHash | None = None,
126+
discrete_memory_extracted: DiscreteMemoryExtracted | None = None,
126127
) -> dict[str, Any] | None:
127128
"""Convert filter objects to backend format for LangChain vectorstores."""
128129
filter_dict: dict[str, Any] = {}
129130

131+
# TODO: Seems like we could take *args filters and decide what to do based on type.
130132
# Apply tag/string filters using the helper function
131133
self.process_tag_filter(session_id, "session_id", filter_dict)
132134
self.process_tag_filter(user_id, "user_id", filter_dict)
@@ -135,6 +137,9 @@ def convert_filters_to_backend_format(
135137
self.process_tag_filter(topics, "topics", filter_dict)
136138
self.process_tag_filter(entities, "entities", filter_dict)
137139
self.process_tag_filter(memory_hash, "memory_hash", filter_dict)
140+
self.process_tag_filter(
141+
discrete_memory_extracted, "discrete_memory_extracted", filter_dict
142+
)
138143

139144
# Apply datetime filters using the helper function (uses instance method for backend-specific formatting)
140145
self.process_datetime_filter(created_at, "created_at", filter_dict)
@@ -374,6 +379,7 @@ def _convert_filters_to_backend_format(
374379
last_accessed: LastAccessed | None = None,
375380
event_date: EventDate | None = None,
376381
memory_hash: MemoryHash | None = None,
382+
discrete_memory_extracted: DiscreteMemoryExtracted | None = None,
377383
) -> dict[str, Any] | None:
378384
"""Convert filter objects to standard LangChain dictionary format.
379385
@@ -391,6 +397,7 @@ def _convert_filters_to_backend_format(
391397
Dictionary filter in format: {"field": {"$eq": "value"}} or None
392398
"""
393399
processor = LangChainFilterProcessor(self.vectorstore)
400+
# TODO: Seems like we could take *args and pass them to the processor
394401
filter_dict = processor.convert_filters_to_backend_format(
395402
session_id=session_id,
396403
user_id=user_id,
@@ -489,6 +496,7 @@ async def search_memories(
489496
last_accessed=last_accessed,
490497
event_date=event_date,
491498
memory_hash=memory_hash,
499+
discrete_memory_extracted=discrete_memory_extracted,
492500
)
493501

494502
# Use LangChain's similarity search with filters
@@ -497,6 +505,8 @@ async def search_memories(
497505
search_kwargs["filter"] = filter_dict
498506

499507
# Perform similarity search
508+
logger.info(f"Searching for memories with filters: {search_kwargs}")
509+
500510
docs_with_scores = (
501511
await self.vectorstore.asimilarity_search_with_relevance_scores(
502512
query, **search_kwargs

agent_memory_server/vectorstore_factory.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323

2424
from langchain_core.embeddings import Embeddings
2525
from langchain_core.vectorstores import VectorStore
26+
from langchain_redis.config import RedisConfig
2627
from pydantic.types import SecretStr
2728

2829

@@ -207,9 +208,15 @@ def create_redis_vectorstore(embeddings: Embeddings) -> VectorStore:
207208
# Always use MemoryRedisVectorStore for consistency and to fix relevance score issues
208209
return MemoryRedisVectorStore(
209210
embeddings=embeddings,
210-
redis_url=settings.redis_url,
211-
index_name=settings.redisvl_index_name,
212-
metadata_schema=metadata_schema,
211+
config=RedisConfig(
212+
redis_url=settings.redis_url,
213+
key_prefix=settings.redisvl_index_prefix,
214+
indexing_algorithm=settings.redisvl_indexing_algorithm,
215+
index_name=settings.redisvl_index_name,
216+
metadata_schema=metadata_schema,
217+
distance_metric=settings.redisvl_distance_metric,
218+
embedding_dimensions=int(settings.redisvl_vector_dimensions),
219+
),
213220
)
214221
except ImportError:
215222
logger.error(

0 commit comments

Comments
 (0)