Skip to content

Commit 159f7d4

Browse files
committed
Better summarization, memory extraction, etc.
1 parent 8168233 commit 159f7d4

File tree

11 files changed

+432
-294
lines changed

11 files changed

+432
-294
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -221,3 +221,4 @@ libs/redis/docs/.Trash*
221221
.python-version
222222
.idea/*
223223
.vscode/settings.json
224+
.cursor

agent_memory_server/cli.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,11 +79,11 @@ async def run_migrations():
7979
@click.option("--reload", is_flag=True, help="Enable auto-reload")
8080
def api(port: int, host: str, reload: bool):
8181
"""Run the REST API server."""
82-
from agent_memory_server.main import app, on_start_logger
82+
from agent_memory_server.main import on_start_logger
8383

8484
on_start_logger(port)
8585
uvicorn.run(
86-
app,
86+
"agent_memory_server.main:app",
8787
host=host,
8888
port=port,
8989
reload=reload,

agent_memory_server/client/api.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
CreatedAt,
1414
Entities,
1515
LastAccessed,
16+
MemoryType,
1617
Namespace,
1718
SessionId,
1819
Topics,
@@ -273,6 +274,7 @@ async def search_long_term_memory(
273274
last_accessed: LastAccessed | dict[str, Any] | None = None,
274275
user_id: UserId | dict[str, Any] | None = None,
275276
distance_threshold: float | None = None,
277+
memory_type: MemoryType | dict[str, Any] | None = None,
276278
limit: int = 10,
277279
offset: int = 0,
278280
) -> LongTermMemoryResults:
@@ -313,6 +315,8 @@ async def search_long_term_memory(
313315
last_accessed = LastAccessed(**last_accessed)
314316
if isinstance(user_id, dict):
315317
user_id = UserId(**user_id)
318+
if isinstance(memory_type, dict):
319+
memory_type = MemoryType(**memory_type)
316320

317321
# Apply default namespace if needed and no namespace filter specified
318322
if namespace is None and self.config.default_namespace is not None:
@@ -328,6 +332,7 @@ async def search_long_term_memory(
328332
last_accessed=last_accessed,
329333
user_id=user_id,
330334
distance_threshold=distance_threshold,
335+
memory_type=memory_type,
331336
limit=limit,
332337
offset=offset,
333338
)

agent_memory_server/extraction.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -343,4 +343,7 @@ async def extract_discrete_memories(redis: Redis | None = None):
343343
for new_memory in discrete_memories
344344
]
345345

346-
await index_long_term_memories(long_term_memories)
346+
await index_long_term_memories(
347+
long_term_memories,
348+
deduplicate=True,
349+
)

agent_memory_server/long_term_memory.py

Lines changed: 398 additions & 272 deletions
Large diffs are not rendered by default.

agent_memory_server/messages.py

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -66,28 +66,24 @@ async def get_session_memory(
6666
if not session_exists:
6767
return None
6868

69-
# Retrieve messages and metadata
7069
async with redis.pipeline() as pipe:
7170
pipe.lrange(messages_key, -window_size, -1) # Get the most recent messages
7271
pipe.hgetall(metadata_key)
7372
messages_data, metadata = await pipe.execute()
7473

75-
# Parse messages
7674
messages = []
7775
for msg_data in messages_data:
7876
if isinstance(msg_data, bytes):
7977
msg_data = msg_data.decode("utf-8")
8078
msg = json.loads(msg_data)
8179
messages.append(MemoryMessage(**msg))
8280

83-
# Parse metadata
8481
metadata_dict = {}
8582
for k, v in metadata.items():
8683
key = k.decode("utf-8") if isinstance(k, bytes) else k
8784
value = v.decode("utf-8") if isinstance(v, bytes) else v
8885
metadata_dict[key] = value
8986

90-
# Create SessionMemory object
9187
return SessionMemory(messages=messages, **metadata_dict)
9288

9389

@@ -112,6 +108,7 @@ async def set_session_memory(
112108
messages_json = [json.dumps(msg.model_dump()) for msg in memory.messages]
113109
metadata = memory.model_dump(
114110
exclude_none=True,
111+
exclude_unset=True,
115112
exclude={"messages"},
116113
)
117114

@@ -138,11 +135,7 @@ async def set_session_memory(
138135

139136
# Check if window size is exceeded
140137
current_size = await redis.llen(messages_key) # type: ignore
141-
print(
142-
f"Current size: {current_size}", "Current window size: ", settings.window_size
143-
)
144138
if current_size > settings.window_size:
145-
print("Queuing summarizing session task")
146139
# Add summarization task
147140
await background_tasks.add_task(
148141
summarize_session,
@@ -152,7 +145,6 @@ async def set_session_memory(
152145
)
153146

154147
# If long-term memory is enabled, index messages
155-
print("Long-term memory is enabled: ", settings.long_term_memory)
156148
if settings.long_term_memory:
157149
memories = [
158150
LongTermMemory(
@@ -164,7 +156,6 @@ async def set_session_memory(
164156
for msg in memory.messages
165157
]
166158

167-
print("Adding a task")
168159
await background_tasks.add_task(
169160
index_long_term_memories,
170161
memories,

agent_memory_server/models.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
CreatedAt,
99
Entities,
1010
LastAccessed,
11+
MemoryType,
1112
Namespace,
1213
SessionId,
1314
Topics,
@@ -213,6 +214,10 @@ class SearchPayload(BaseModel):
213214
default=None,
214215
description="Optional distance threshold to filter by",
215216
)
217+
memory_type: MemoryType | None = Field(
218+
default=None,
219+
description="Optional memory type to filter by",
220+
)
216221
limit: int = Field(
217222
default=10,
218223
ge=1,
@@ -250,4 +255,7 @@ def get_filters(self):
250255
if self.last_accessed is not None:
251256
filters["last_accessed"] = self.last_accessed
252257

258+
if self.memory_type is not None:
259+
filters["memory_type"] = self.memory_type
260+
253261
return filters

agent_memory_server/summarization.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ async def _incremental_summary(
9898
response = await client.create_chat_completion(model, progressive_prompt)
9999

100100
# Extract completion text
101-
completion = response.choices[0]["message"]["content"]
101+
completion = response.choices[0].message.content
102102

103103
# Get token usage
104104
tokens_used = response.total_tokens
@@ -219,6 +219,7 @@ async def summarize_session(
219219
metadata["tokens"] = str(total_tokens)
220220

221221
pipe.hmset(metadata_key, mapping=metadata)
222+
print("Metadata: ", metadata_key, metadata)
222223

223224
# Messages that were summarized
224225
num_summarized = len(messages_to_summarize)

tests/test_long_term_memory.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -147,12 +147,11 @@ def __init__(self, docs):
147147

148148
assert mock_index.query.call_count == 1
149149

150-
assert len(results.memories) == 2
150+
assert len(results.memories) == 1
151151
assert isinstance(results.memories[0], LongTermMemoryResult)
152152
assert results.memories[0].text == "Hello, world!"
153153
assert results.memories[0].dist == 0.25
154-
assert results.memories[1].text == "Hi there!"
155-
assert results.memories[1].dist == 0.75
154+
assert results.memories[0].memory_type == "message"
156155

157156

158157
@pytest.mark.requires_api_keys

tests/test_memory_compaction.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -70,9 +70,7 @@ async def test_merge_memories_with_llm(mock_openai_client, monkeypatch):
7070
},
7171
]
7272

73-
merged = await merge_memories_with_llm(
74-
memories, "hash", llm_client=mock_openai_client
75-
)
73+
merged = await merge_memories_with_llm(memories, llm_client=mock_openai_client)
7674
assert merged["text"] == "Merged content"
7775
assert merged["created_at"] == memories[1]["created_at"]
7876
assert merged["last_accessed"] == memories[1]["last_accessed"]

0 commit comments

Comments
 (0)