Skip to content

Commit 80751a4

Browse files
committed
fix: memory type & proactive example
1 parent 200f47a commit 80751a4

5 files changed

Lines changed: 439 additions & 20 deletions

File tree

examples/proactive/memory/local/memorize.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import json
22
import os
3+
from collections.abc import Awaitable
34
from pathlib import Path
45
from typing import Any
56

@@ -12,7 +13,7 @@
1213
USER_ID = "claude_user"
1314

1415

15-
async def dump_conversation_resource(
16+
def dump_conversation_resource(
1617
conversation_messages: list[dict[str, Any]],
1718
) -> str:
1819
resource_data = {
@@ -32,7 +33,7 @@ async def dump_conversation_resource(
3233
return resource_url.as_posix()
3334

3435

35-
async def memorize(conversation_messages: list[dict[str, Any]]) -> str | None:
36+
def memorize(conversation_messages: list[dict[str, Any]]) -> Awaitable[dict[str, Any]]:
3637
api_key = os.getenv("OPENAI_API_KEY")
3738
if not api_key:
3839
msg = "Please set OPENAI_API_KEY environment variable"
@@ -48,8 +49,5 @@ async def memorize(conversation_messages: list[dict[str, Any]]) -> str | None:
4849
memorize_config=memorize_config,
4950
)
5051

51-
resource_url = await dump_conversation_resource(conversation_messages)
52-
result = await memory_service.memorize(
53-
resource_url=resource_url, modality="conversation", user={"user_id": USER_ID}
54-
)
55-
return result
52+
resource_url = dump_conversation_resource(conversation_messages)
53+
return memory_service.memorize(resource_url=resource_url, modality="conversation", user={"user_id": USER_ID})

examples/proactive/proactive.py

Lines changed: 55 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -14,17 +14,23 @@
1414
# os.environ["ANTHROPIC_API_KEY"] = ""
1515

1616
N_MESSAGES_MEMORIZE = 2
17+
RUNNING_MEMORIZATION: asyncio.Task | None = None
1718

1819

1920
async def trigger_memorize(messages: list[dict[str, any]]) -> bool:
20-
"""Background task to memorize conversation messages."""
21+
"""Create a background task to memorize conversation messages.
22+
23+
Returns True if the task was successfully created and registered.
24+
"""
25+
global RUNNING_MEMORIZATION
2126
try:
22-
await memorize(messages)
27+
memorize_awaitable = memorize(messages)
28+
RUNNING_MEMORIZATION = asyncio.create_task(memorize_awaitable)
2329
except Exception as e:
24-
print(f"\n[Background] Memorization failed: {e!r}")
30+
print(f"\n[Memory] Memorization initialization failed: {e!r}")
2531
return False
2632
else:
27-
print("\n[Background] Memorization submitted.")
33+
print("\n[Memory] Memorization task submitted.")
2834
return True
2935

3036

@@ -85,12 +91,31 @@ async def process_response(client: ClaudeSDKClient) -> list[str]:
8591

8692

8793
async def check_and_memorize(conversation_messages: list[dict[str, any]]) -> None:
88-
"""Check if memorization threshold is reached and trigger if needed."""
89-
if len(conversation_messages) >= N_MESSAGES_MEMORIZE:
90-
print(f"\n[Info] Reached {N_MESSAGES_MEMORIZE} messages, triggering memorization...")
91-
success = await trigger_memorize(conversation_messages.copy())
92-
if success:
93-
conversation_messages.clear()
94+
"""Check if memorization threshold is reached and trigger if needed.
95+
96+
Skips triggering if a previous memorization task is still running.
97+
"""
98+
global RUNNING_MEMORIZATION
99+
100+
if len(conversation_messages) < N_MESSAGES_MEMORIZE:
101+
return
102+
103+
# Check if there's a running memorization task
104+
if RUNNING_MEMORIZATION is not None:
105+
if not RUNNING_MEMORIZATION.done():
106+
print("\n[Info] Have running memorization, skipping...")
107+
return
108+
# Previous task completed, check for exceptions
109+
try:
110+
RUNNING_MEMORIZATION.result()
111+
except Exception as e:
112+
print(f"\n[Memory] Memorization failed: {e!r}")
113+
RUNNING_MEMORIZATION = None
114+
115+
print(f"\n[Info] Reached {N_MESSAGES_MEMORIZE} messages, triggering memorization...")
116+
success = await trigger_memorize(conversation_messages.copy())
117+
if success:
118+
conversation_messages.clear()
94119

95120

96121
async def run_conversation_loop(client: ClaudeSDKClient) -> list[dict[str, any]]:
@@ -139,9 +164,28 @@ async def main():
139164
async with ClaudeSDKClient(options=options) as client:
140165
remaining_messages = await run_conversation_loop(client)
141166

167+
# Wait for any running memorization task to complete
168+
global RUNNING_MEMORIZATION
169+
if RUNNING_MEMORIZATION is not None and not RUNNING_MEMORIZATION.done():
170+
print("\n[Info] Waiting for running memorization task to complete...")
171+
try:
172+
await RUNNING_MEMORIZATION
173+
print("\n[Memory] Running memorization completed successfully.")
174+
except Exception as e:
175+
print(f"\n[Memory] Running memorization failed: {e!r}")
176+
RUNNING_MEMORIZATION = None
177+
178+
# Memorize remaining messages and wait for completion
142179
if remaining_messages:
143180
print("\n[Info] Session ended, memorizing remaining messages...")
144-
await trigger_memorize(remaining_messages.copy())
181+
success = await trigger_memorize(remaining_messages.copy())
182+
if success and RUNNING_MEMORIZATION is not None:
183+
print("\n[Info] Waiting for final memorization to complete...")
184+
try:
185+
await RUNNING_MEMORIZATION
186+
print("\n[Memory] Final memorization completed successfully.")
187+
except Exception as e:
188+
print(f"\n[Memory] Final memorization failed: {e!r}")
145189

146190
print("\nDone")
147191

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ test = [
6868
[project.optional-dependencies]
6969
postgres = ["pgvector>=0.3.4", "sqlalchemy[postgresql-psycopgbinary]>=2.0.36"]
7070
langgraph = ["langgraph>=0.0.10", "langchain-core>=0.1.0"]
71+
claude = ["claude-agent-sdk>=0.1.24"]
7172

7273
[project.urls]
7374
"Homepage" = "https://github.com/NevaMind-AI/MemU"

src/memu/database/models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ class Resource(BaseRecord):
2828

2929
class MemoryItem(BaseRecord):
3030
resource_id: str | None
31-
memory_type: MemoryType
31+
memory_type: str
3232
summary: str
3333
embedding: list[float] | None = None
3434
happened_at: datetime | None = None

0 commit comments

Comments
 (0)