Skip to content

Commit 4b657a0

Browse files
committed
Batch indexing, fix tests
1 parent dc805ec commit 4b657a0

File tree

4 files changed

+24
-14
lines changed

4 files changed

+24
-14
lines changed

agent-memory-client/tests/test_client.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -546,9 +546,9 @@ async def test_append_messages_to_working_memory(self, enhanced_test_client):
546546
# Check that messages were appended
547547
working_memory_arg = mock_put.call_args[0][1]
548548
assert len(working_memory_arg.messages) == 3
549-
assert working_memory_arg.messages[0]["content"] == "First message"
550-
assert working_memory_arg.messages[1]["content"] == "Second message"
551-
assert working_memory_arg.messages[2]["content"] == "Third message"
549+
assert working_memory_arg.messages[0].content == "First message"
550+
assert working_memory_arg.messages[1].content == "Second message"
551+
assert working_memory_arg.messages[2].content == "Third message"
552552

553553
def test_deep_merge_dicts(self, enhanced_test_client):
554554
"""Test the deep merge dictionary utility method."""

agent_memory_server/api.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -327,7 +327,9 @@ async def put_working_memory(
327327
)
328328

329329
# Background tasks for long-term memory promotion and indexing (if enabled)
330-
if settings.long_term_memory and updated_memory.memories:
330+
if settings.long_term_memory and (
331+
updated_memory.memories or updated_memory.messages
332+
):
331333
# Promote structured memories from working memory to long-term storage
332334
await background_tasks.add_task(
333335
long_term_memory.promote_working_memory_to_long_term,

agent_memory_server/long_term_memory.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1326,6 +1326,7 @@ async def promote_working_memory_to_long_term(
13261326

13271327
# Process unpersisted messages
13281328
updated_messages = []
1329+
memory_records_to_index = []
13291330
for msg in current_working_memory.messages:
13301331
if msg.persisted_at is None:
13311332
# Generate ID if not present (backward compatibility)
@@ -1352,12 +1353,8 @@ async def promote_working_memory_to_long_term(
13521353
current_memory = deduped_memory or memory_record
13531354
current_memory.persisted_at = datetime.now(UTC)
13541355

1355-
# Index in long-term storage
1356-
await index_long_term_memories(
1357-
[current_memory],
1358-
redis_client=redis,
1359-
deduplicate=False, # Already deduplicated by ID
1360-
)
1356+
# Collect memory record for batch indexing
1357+
memory_records_to_index.append(current_memory)
13611358

13621359
# Update message with persisted_at timestamp
13631360
msg.persisted_at = current_memory.persisted_at
@@ -1370,6 +1367,14 @@ async def promote_working_memory_to_long_term(
13701367

13711368
updated_messages.append(msg)
13721369

1370+
# Batch index all new memory records for messages
1371+
if memory_records_to_index:
1372+
await index_long_term_memories(
1373+
memory_records_to_index,
1374+
redis_client=redis,
1375+
deduplicate=False, # Already deduplicated by ID
1376+
)
1377+
13731378
# Update working memory with the new persisted_at timestamps and extracted memories
13741379
if promoted_count > 0 or extracted_memories:
13751380
updated_working_memory = current_working_memory.model_copy()

tests/test_api.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66

77
from agent_memory_server.config import Settings
88
from agent_memory_server.long_term_memory import (
9-
index_long_term_memories,
109
promote_working_memory_to_long_term,
1110
)
1211
from agent_memory_server.models import (
@@ -153,7 +152,11 @@ async def test_put_memory(self, client):
153152
"/v1/working-memory/test-session?namespace=test-namespace"
154153
)
155154
assert updated_session.status_code == 200
156-
assert updated_session.json()["messages"] == payload["messages"]
155+
retrieved_messages = updated_session.json()["messages"]
156+
assert len(retrieved_messages) == len(payload["messages"])
157+
for i, msg in enumerate(retrieved_messages):
158+
assert msg["role"] == payload["messages"][i]["role"]
159+
assert msg["content"] == payload["messages"][i]["content"]
157160

158161
@pytest.mark.requires_api_keys
159162
@pytest.mark.asyncio
@@ -188,10 +191,10 @@ async def test_put_memory_stores_messages_in_long_term_memory(
188191
# Check that background tasks were called
189192
assert mock_background_tasks.add_task.call_count == 1
190193

191-
# Check that the last call was for long-term memory indexing
194+
# Check that the last call was for long-term memory promotion
192195
assert (
193196
mock_background_tasks.add_task.call_args_list[-1][0][0]
194-
== index_long_term_memories
197+
== promote_working_memory_to_long_term
195198
)
196199

197200
@pytest.mark.requires_api_keys

0 commit comments

Comments
 (0)