Skip to content

Commit 5849f71

Browse files
committed
Remove AsyncSearchIndex usage
1 parent 61b5123 commit 5849f71

File tree

8 files changed

+197
-269
lines changed

8 files changed

+197
-269
lines changed

agent_memory_server/extraction.py

Lines changed: 28 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -5,21 +5,21 @@
55
import ulid
66
from bertopic import BERTopic
77
from redis.asyncio.client import Redis
8-
from redisvl.query.filter import Tag
9-
from redisvl.query.query import FilterQuery
108
from tenacity.asyncio import AsyncRetrying
119
from tenacity.stop import stop_after_attempt
1210
from transformers import AutoModelForTokenClassification, AutoTokenizer, pipeline
1311

1412
from agent_memory_server.config import settings
13+
from agent_memory_server.filters import DiscreteMemoryExtracted
1514
from agent_memory_server.llms import (
1615
AnthropicClientWrapper,
1716
OpenAIClientWrapper,
1817
get_model_client,
1918
)
2019
from agent_memory_server.logging import get_logger
2120
from agent_memory_server.models import MemoryRecord
22-
from agent_memory_server.utils.redis import get_redis_conn, get_search_index
21+
from agent_memory_server.utils.keys import Keys
22+
from agent_memory_server.utils.redis import get_redis_conn
2323

2424

2525
logger = get_logger(__name__)
@@ -269,25 +269,32 @@ async def extract_discrete_memories(
269269
"""
270270
redis = await get_redis_conn()
271271
client = await get_model_client(settings.generation_model)
272-
query = FilterQuery(
273-
filter_expression=(Tag("discrete_memory_extracted") == "f")
274-
& (Tag("memory_type") == "message")
275-
)
272+
273+
# Use vectorstore adapter to find messages that need discrete memory extraction
274+
from agent_memory_server.filters import MemoryType
275+
from agent_memory_server.vectorstore_factory import get_vectorstore_adapter
276+
277+
adapter = await get_vectorstore_adapter()
276278
offset = 0
277279

278280
while True:
279-
query.paging(num=25, offset=offset)
280-
search_index = get_search_index(redis=redis)
281-
messages = await search_index.query(query)
281+
# Search for message-type memories that haven't been processed for discrete extraction
282+
search_result = await adapter.search_memories(
283+
query="", # Empty query to get all messages
284+
memory_type=MemoryType(eq="message"),
285+
discrete_memory_extracted=DiscreteMemoryExtracted(ne="t"),
286+
limit=25,
287+
offset=offset,
288+
)
289+
282290
discrete_memories = []
283291

284-
for message in messages:
285-
if not message or not message.get("text"):
292+
for message in search_result.memories:
293+
if not message or not message.text:
286294
logger.info(f"Deleting memory with no text: {message}")
287-
await redis.delete(message["id"])
295+
await adapter.delete_memories([message.id])
288296
continue
289-
id_ = message.get("id_")
290-
if not id_:
297+
if not message.id:
291298
logger.error(f"Skipping memory with no ID: {message}")
292299
continue
293300

@@ -296,7 +303,7 @@ async def extract_discrete_memories(
296303
response = await client.create_chat_completion(
297304
model=settings.generation_model,
298305
prompt=DISCRETE_EXTRACTION_PROMPT.format(
299-
message=message["text"], top_k_topics=settings.top_k_topics
306+
message=message.text, top_k_topics=settings.top_k_topics
300307
),
301308
response_format={"type": "json_object"},
302309
)
@@ -317,13 +324,15 @@ async def extract_discrete_memories(
317324
raise
318325
discrete_memories.extend(new_message["memories"])
319326

327+
# Update the memory to mark it as processed
328+
# For now, we need to use Redis directly as the adapter doesn't have an update method
320329
await redis.hset(
321-
name=message["id"],
330+
name=Keys.memory_key(message.id), # Construct the key
322331
key="discrete_memory_extracted",
323332
value="t",
324333
) # type: ignore
325334

326-
if len(messages) < 25:
335+
if len(search_result.memories) < 25:
327336
break
328337
offset += 25
329338

@@ -333,7 +342,7 @@ async def extract_discrete_memories(
333342
if discrete_memories:
334343
long_term_memories = [
335344
MemoryRecord(
336-
id_=str(ulid.ULID()),
345+
id=str(ulid.ULID()),
337346
text=new_memory["text"],
338347
memory_type=new_memory.get("type", "episodic"),
339348
topics=new_memory.get("topics", []),

agent_memory_server/filters.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -242,3 +242,7 @@ class EventDate(DateTimeFilter):
242242

243243
class MemoryHash(TagFilter):
244244
field: str = "memory_hash"
245+
246+
247+
class DiscreteMemoryExtracted(TagFilter):
248+
field: str = "discrete_memory_extracted"

0 commit comments

Comments
 (0)