Skip to content

Commit 762fff7

Browse files
committed
Fix deduplication for "message" type long-term memories
The memory server indexes any messages from working memory into "message" type memory records in long-term memory. However, we were not attempting to deduplicate these memory records -- only the higher-level long-term memory types, like semantic and episodic memory. This change refactors the logic we use to store messages long-term so that messages are handled the same way as the long-term memories a client adds to working memory. When a client adds a message to working memory, the message is given an ID (supplied by the client or auto-generated) and a blank `persisted_at` timestamp. After setting working memory, we kick off a background task to persist any promoted long-term memories and messages to long-term memory. When we copy the messages into long-term memory, we now set their `persisted_at` timestamp. If we see that working memory again and try to persist any new memories or messages, we'll skip any that have already been persisted.
1 parent 2647e0a commit 762fff7

File tree

9 files changed

+176
-59
lines changed

9 files changed

+176
-59
lines changed

agent-memory-client/agent_memory_client/client.py

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
AckResponse,
3232
ClientMemoryRecord,
3333
HealthCheckResponse,
34+
MemoryMessage,
3435
MemoryRecord,
3536
MemoryRecordResults,
3637
MemoryTypeEnum,
@@ -2055,7 +2056,7 @@ async def update_working_memory_data(
20552056
async def append_messages_to_working_memory(
20562057
self,
20572058
session_id: str,
2058-
messages: list[dict[str, Any]], # Expect proper message dicts
2059+
messages: list[dict[str, Any] | MemoryMessage],
20592060
namespace: str | None = None,
20602061
model_name: str | None = None,
20612062
context_window_max: int | None = None,
@@ -2068,7 +2069,7 @@ async def append_messages_to_working_memory(
20682069
20692070
Args:
20702071
session_id: Target session
2071-
messages: List of message dictionaries with 'role' and 'content' keys
2072+
messages: List of message dictionaries or MemoryMessage objects
20722073
namespace: Optional namespace
20732074
model_name: Optional model name for token-based summarization
20742075
context_window_max: Optional direct specification of context window max tokens
@@ -2081,19 +2082,36 @@ async def append_messages_to_working_memory(
20812082
session_id=session_id, namespace=namespace, user_id=user_id
20822083
)
20832084

2084-
# Validate new messages have required structure
2085+
# Convert messages to MemoryMessage objects
2086+
converted_messages = []
20852087
for msg in messages:
2086-
if not isinstance(msg, dict) or "role" not in msg or "content" not in msg:
2088+
if isinstance(msg, MemoryMessage):
2089+
converted_messages.append(msg)
2090+
elif isinstance(msg, dict):
2091+
if "role" not in msg or "content" not in msg:
2092+
raise ValueError("All messages must have 'role' and 'content' keys")
2093+
# Build message kwargs, only including non-None values
2094+
message_kwargs = {
2095+
"role": msg["role"],
2096+
"content": msg["content"],
2097+
}
2098+
if msg.get("id") is not None:
2099+
message_kwargs["id"] = msg["id"]
2100+
if msg.get("persisted_at") is not None:
2101+
message_kwargs["persisted_at"] = msg["persisted_at"]
2102+
2103+
converted_messages.append(MemoryMessage(**message_kwargs))
2104+
else:
20872105
raise ValueError(
2088-
"All messages must be dictionaries with 'role' and 'content' keys"
2106+
"All messages must be dictionaries or MemoryMessage objects"
20892107
)
20902108

2091-
# Get existing messages (already in proper dict format from get_working_memory)
2109+
# Get existing messages
20922110
existing_messages = []
20932111
if existing_memory and existing_memory.messages:
20942112
existing_messages = existing_memory.messages
20952113

2096-
final_messages = existing_messages + messages
2114+
final_messages = existing_messages + converted_messages
20972115

20982116
# Create updated working memory
20992117
working_memory = WorkingMemory(

agent-memory-client/agent_memory_client/models.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
from datetime import datetime, timezone
99
from enum import Enum
10-
from typing import Any, Literal, TypedDict
10+
from typing import Any, Literal
1111

1212
from pydantic import BaseModel, Field
1313
from ulid import ULID
@@ -48,11 +48,19 @@ class MemoryTypeEnum(str, Enum):
4848
MESSAGE = "message"
4949

5050

51-
class MemoryMessage(TypedDict):
51+
class MemoryMessage(BaseModel):
5252
"""A message in the memory system"""
5353

5454
role: str
5555
content: str
56+
id: str = Field(
57+
default_factory=lambda: str(ULID()),
58+
description="Unique identifier for the message (auto-generated)",
59+
)
60+
persisted_at: datetime | None = Field(
61+
default=None,
62+
description="Server-assigned timestamp when message was persisted to long-term storage",
63+
)
5664

5765

5866
class MemoryRecord(BaseModel):
@@ -134,9 +142,9 @@ class WorkingMemory(BaseModel):
134142
"""Working memory for a session - contains both messages and structured memory records"""
135143

136144
# Support both message-based memory (conversation) and structured memory records
137-
messages: list[dict[str, Any]] = Field(
145+
messages: list[MemoryMessage] = Field(
138146
default_factory=list,
139-
description="Conversation messages (role/content pairs)",
147+
description="Conversation messages with tracking fields",
140148
)
141149
memories: list[MemoryRecord | ClientMemoryRecord] = Field(
142150
default_factory=list,

agent_memory_server/api.py

Lines changed: 6 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
from fastapi import APIRouter, Depends, HTTPException, Query
33
from mcp.server.fastmcp.prompts import base
44
from mcp.types import TextContent
5-
from ulid import ULID
65

76
from agent_memory_server import long_term_memory, working_memory
87
from agent_memory_server.auth import UserInfo, get_current_user
@@ -18,7 +17,6 @@
1817
MemoryPromptRequest,
1918
MemoryPromptResponse,
2019
MemoryRecordResultsResponse,
21-
MemoryTypeEnum,
2220
ModelNameLiteral,
2321
SearchRequest,
2422
SessionListResponse,
@@ -329,35 +327,13 @@ async def put_working_memory(
329327
)
330328

331329
# Background tasks for long-term memory promotion and indexing (if enabled)
332-
if settings.long_term_memory:
330+
if settings.long_term_memory and updated_memory.memories:
333331
# Promote structured memories from working memory to long-term storage
334-
if updated_memory.memories:
335-
await background_tasks.add_task(
336-
long_term_memory.promote_working_memory_to_long_term,
337-
session_id,
338-
updated_memory.namespace,
339-
)
340-
341-
# Index message-based memories
342-
if updated_memory.messages:
343-
from agent_memory_server.models import MemoryRecord
344-
345-
memories = [
346-
MemoryRecord(
347-
id=str(ULID()),
348-
session_id=session_id,
349-
text=f"{msg.role}: {msg.content}",
350-
namespace=updated_memory.namespace,
351-
user_id=updated_memory.user_id,
352-
memory_type=MemoryTypeEnum.MESSAGE,
353-
)
354-
for msg in updated_memory.messages
355-
]
356-
357-
await background_tasks.add_task(
358-
long_term_memory.index_long_term_memories,
359-
memories,
360-
)
332+
await background_tasks.add_task(
333+
long_term_memory.promote_working_memory_to_long_term,
334+
session_id,
335+
updated_memory.namespace,
336+
)
361337

362338
return updated_memory
363339

agent_memory_server/long_term_memory.py

Lines changed: 58 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1247,12 +1247,19 @@ async def promote_working_memory_to_long_term(
12471247
if memory.persisted_at is None
12481248
]
12491249

1250-
if not unpersisted_memories:
1251-
logger.debug(f"No unpersisted memories found in session {session_id}")
1250+
# Find unpersisted messages (similar to unpersisted memories)
1251+
unpersisted_messages = [
1252+
msg for msg in current_working_memory.messages if msg.persisted_at is None
1253+
]
1254+
1255+
if not unpersisted_memories and not unpersisted_messages:
1256+
logger.debug(
1257+
f"No unpersisted memories or messages found in session {session_id}"
1258+
)
12521259
return 0
12531260

12541261
logger.info(
1255-
f"Promoting {len(unpersisted_memories)} memories from session {session_id}"
1262+
f"Promoting {len(unpersisted_memories)} memories and {len(unpersisted_messages)} messages from session {session_id}"
12561263
)
12571264

12581265
promoted_count = 0
@@ -1317,10 +1324,57 @@ async def promote_working_memory_to_long_term(
13171324
)
13181325
updated_memories.extend(extracted_memories)
13191326

1327+
# Process unpersisted messages
1328+
updated_messages = []
1329+
for msg in current_working_memory.messages:
1330+
if msg.persisted_at is None:
1331+
# Generate ID if not present (backward compatibility)
1332+
if not msg.id:
1333+
msg.id = str(ULID())
1334+
1335+
memory_record = MemoryRecord(
1336+
id=msg.id,
1337+
session_id=session_id,
1338+
text=f"{msg.role}: {msg.content}",
1339+
namespace=namespace,
1340+
user_id=current_working_memory.user_id,
1341+
memory_type=MemoryTypeEnum.MESSAGE,
1342+
persisted_at=None,
1343+
)
1344+
1345+
# Apply same deduplication logic as structured memories
1346+
deduped_memory, was_overwrite = await deduplicate_by_id(
1347+
memory=memory_record,
1348+
redis_client=redis,
1349+
)
1350+
1351+
# Set persisted_at timestamp
1352+
current_memory = deduped_memory or memory_record
1353+
current_memory.persisted_at = datetime.now(UTC)
1354+
1355+
# Index in long-term storage
1356+
await index_long_term_memories(
1357+
[current_memory],
1358+
redis_client=redis,
1359+
deduplicate=False, # Already deduplicated by ID
1360+
)
1361+
1362+
# Update message with persisted_at timestamp
1363+
msg.persisted_at = current_memory.persisted_at
1364+
promoted_count += 1
1365+
1366+
if was_overwrite:
1367+
logger.info(f"Overwrote existing message with id {msg.id}")
1368+
else:
1369+
logger.info(f"Promoted new message with id {msg.id}")
1370+
1371+
updated_messages.append(msg)
1372+
13201373
# Update working memory with the new persisted_at timestamps and extracted memories
13211374
if promoted_count > 0 or extracted_memories:
13221375
updated_working_memory = current_working_memory.model_copy()
13231376
updated_working_memory.memories = updated_memories
1377+
updated_working_memory.messages = updated_messages
13241378
updated_working_memory.updated_at = datetime.now(UTC)
13251379

13261380
await working_memory.set_working_memory(
@@ -1329,7 +1383,7 @@ async def promote_working_memory_to_long_term(
13291383
)
13301384

13311385
logger.info(
1332-
f"Successfully promoted {promoted_count} memories to long-term storage"
1386+
f"Successfully promoted {promoted_count} memories and messages to long-term storage"
13331387
+ (
13341388
f" and extracted {len(extracted_memories)} new memories"
13351389
if extracted_memories

agent_memory_server/mcp.py

Lines changed: 44 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -659,7 +659,25 @@ async def set_working_memory(
659659
)
660660
```
661661
662-
4. Replace entire working memory state:
662+
4. Store conversation messages:
663+
```python
664+
set_working_memory(
665+
session_id="current_session",
666+
messages=[
667+
{
668+
"role": "user",
669+
"content": "What is the weather like?",
670+
"id": "msg_001" # Optional - auto-generated if not provided
671+
},
672+
{
673+
"role": "assistant",
674+
"content": "I'll check the weather for you."
675+
}
676+
]
677+
)
678+
```
679+
680+
5. Replace entire working memory state:
663681
```python
664682
set_working_memory(
665683
session_id="current_session",
@@ -673,7 +691,7 @@ async def set_working_memory(
673691
Args:
674692
session_id: The session ID to set memory for (required)
675693
memories: List of structured memory records (semantic, episodic, message types)
676-
messages: List of conversation messages (role/content pairs)
694+
messages: List of conversation messages (role/content pairs with optional id/persisted_at)
677695
context: Optional summary/context text
678696
data: Optional dictionary for storing arbitrary JSON data
679697
namespace: Optional namespace for scoping
@@ -712,12 +730,35 @@ async def set_working_memory(
712730

713731
processed_memories.append(processed_memory)
714732

733+
# Process messages to ensure proper format
734+
processed_messages = []
735+
if messages:
736+
for message in messages:
737+
# Handle both MemoryMessage objects and dict inputs
738+
if isinstance(message, MemoryMessage):
739+
# Already a MemoryMessage object, ensure persisted_at is None for new messages
740+
processed_message = message.model_copy(
741+
update={
742+
"persisted_at": None, # Mark as pending promotion
743+
}
744+
)
745+
else:
746+
# Dictionary input, convert to MemoryMessage
747+
message_dict = dict(message)
748+
# Remove id=None to allow auto-generation
749+
if message_dict.get("id") is None:
750+
message_dict.pop("id", None)
751+
message_dict["persisted_at"] = None
752+
processed_message = MemoryMessage(**message_dict)
753+
754+
processed_messages.append(processed_message)
755+
715756
# Create the working memory object
716757
working_memory_obj = WorkingMemory(
717758
session_id=session_id,
718759
namespace=memory_namespace,
719760
memories=processed_memories,
720-
messages=messages or [],
761+
messages=processed_messages,
721762
context=context,
722763
data=data or {},
723764
user_id=user_id,

agent_memory_server/models.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,14 @@ class MemoryMessage(BaseModel):
6767

6868
role: str
6969
content: str
70+
id: str = Field(
71+
default_factory=lambda: str(ULID()),
72+
description="Unique identifier for the message (auto-generated if not provided)",
73+
)
74+
persisted_at: datetime | None = Field(
75+
default=None,
76+
description="Server-assigned timestamp when message was persisted to long-term storage",
77+
)
7078

7179

7280
class SessionListResponse(BaseModel):

tests/test_api.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -89,10 +89,22 @@ async def test_get_memory(self, client, session):
8989

9090
data = response.json()
9191
response = WorkingMemoryResponse(**data)
92-
assert response.messages == [
93-
MemoryMessage(role="user", content="Hello"),
94-
MemoryMessage(role="assistant", content="Hi there"),
95-
]
92+
93+
# Check that we have 2 messages with correct roles and content
94+
assert len(response.messages) == 2
95+
96+
# Check message content and roles (IDs are auto-generated so we can't compare directly)
97+
message_contents = [msg.content for msg in response.messages]
98+
message_roles = [msg.role for msg in response.messages]
99+
assert "Hello" in message_contents
100+
assert "Hi there" in message_contents
101+
assert "user" in message_roles
102+
assert "assistant" in message_roles
103+
104+
# Check that all messages have IDs (auto-generated)
105+
for msg in response.messages:
106+
assert msg.id is not None
107+
assert len(msg.id) > 0
96108

97109
roles = [msg["role"] for msg in data["messages"]]
98110
contents = [msg["content"] for msg in data["messages"]]

tests/test_client_api.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -116,8 +116,8 @@ async def test_session_lifecycle(memory_test_client: MemoryAPIClient):
116116

117117
# Step 1: Create new session memory
118118
response = await memory_test_client.put_working_memory(session_id, memory)
119-
assert response.messages[0]["content"] == "Hello from the client!"
120-
assert response.messages[1]["content"] == "Hi there, I'm the memory server!"
119+
assert response.messages[0].content == "Hello from the client!"
120+
assert response.messages[1].content == "Hi there, I'm the memory server!"
121121
assert response.context == "This is a test session created by the API client."
122122

123123
# Next, mock GET response for retrieving session memory
@@ -132,8 +132,8 @@ async def test_session_lifecycle(memory_test_client: MemoryAPIClient):
132132
# Step 2: Retrieve the session memory
133133
session = await memory_test_client.get_working_memory(session_id)
134134
assert len(session.messages) == 2
135-
assert session.messages[0]["content"] == "Hello from the client!"
136-
assert session.messages[1]["content"] == "Hi there, I'm the memory server!"
135+
assert session.messages[0].content == "Hello from the client!"
136+
assert session.messages[1].content == "Hi there, I'm the memory server!"
137137
assert session.context == "This is a test session created by the API client."
138138

139139
# Mock list sessions

0 commit comments

Comments
 (0)