Skip to content

Commit 22a2fdd

Browse files
committed
fix: build issues & refactor
1 parent 7377075 commit 22a2fdd

2 files changed

Lines changed: 104 additions & 71 deletions

File tree

examples/proactive/memory/platform/memorize.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,13 @@
22

33
import aiohttp
44

5+
from ..config import memorize_config
6+
57
BASE_URL = "https://api.memu.so"
68
API_KEY = "your memu api key"
79
USER_ID = "claude_user"
810
AGENT_ID = "claude_agent"
911

10-
from ..config import memorize_config
11-
1212

1313
async def memorize(conversation_messages: list[dict[str, Any]]) -> str | None:
1414
payload = {

examples/proactive/proactive.py

Lines changed: 102 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,107 @@ async def trigger_memorize(messages: list[dict[str, any]]) -> bool:
2020
"""Background task to memorize conversation messages."""
2121
try:
2222
await memorize(messages)
23-
print("\n[Background] Memorization submitted.")
24-
return True
2523
except Exception as e:
2624
print(f"\n[Background] Memorization failed: {e!r}")
2725
return False
26+
else:
27+
print("\n[Background] Memorization submitted.")
28+
return True
29+
30+
31+
async def get_next_input(iteration: int) -> tuple[str | None, bool]:
32+
"""
33+
Get the next input for the conversation.
34+
35+
Returns:
36+
tuple of (input_text, should_break)
37+
- input_text: The user input or todo-based input, None if should continue
38+
- should_break: True if the loop should break
39+
"""
40+
if iteration == 0:
41+
return await get_user_input()
42+
43+
todos = await _get_todos()
44+
if todos:
45+
return f"Please continue with the following todos:\n{todos}", False
46+
47+
return await get_user_input()
48+
49+
50+
async def get_user_input() -> tuple[str | None, bool]:
51+
"""
52+
Get input from the user.
53+
54+
Returns:
55+
tuple of (input_text, should_break)
56+
"""
57+
try:
58+
user_input = input("\nYou: ").strip()
59+
except EOFError:
60+
return None, True
61+
62+
if not user_input:
63+
return None, False
64+
65+
if user_input.lower() in ("quit", "exit"):
66+
return None, True
67+
68+
return user_input, False
69+
70+
71+
async def process_response(client: ClaudeSDKClient) -> list[str]:
72+
"""Process the assistant response and return collected text parts."""
73+
assistant_text_parts: list[str] = []
74+
75+
async for message in client.receive_response():
76+
if isinstance(message, AssistantMessage):
77+
for block in message.content:
78+
if isinstance(block, TextBlock):
79+
print(f"Claude: {block.text}")
80+
assistant_text_parts.append(block.text)
81+
elif isinstance(message, ResultMessage):
82+
print(f"Result: {message.result}")
83+
84+
return assistant_text_parts
85+
86+
87+
async def check_and_memorize(conversation_messages: list[dict[str, any]]) -> None:
88+
"""Check if memorization threshold is reached and trigger if needed."""
89+
if len(conversation_messages) >= N_MESSAGES_MEMORIZE:
90+
print(f"\n[Info] Reached {N_MESSAGES_MEMORIZE} messages, triggering memorization...")
91+
success = await trigger_memorize(conversation_messages.copy())
92+
if success:
93+
conversation_messages.clear()
94+
95+
96+
async def run_conversation_loop(client: ClaudeSDKClient) -> list[dict[str, any]]:
97+
"""Run the main conversation loop."""
98+
conversation_messages: list[dict[str, any]] = []
99+
iteration = 0
100+
101+
while True:
102+
user_input, should_break = await get_next_input(iteration)
103+
104+
if should_break:
105+
break
106+
if user_input is None:
107+
continue
108+
109+
conversation_messages.append({"role": "user", "content": user_input})
110+
await client.query(user_input)
111+
112+
assistant_text_parts = await process_response(client)
113+
114+
if assistant_text_parts:
115+
conversation_messages.append({
116+
"role": "assistant",
117+
"content": "\n".join(assistant_text_parts),
118+
})
119+
120+
await check_and_memorize(conversation_messages)
121+
iteration += 1
122+
123+
return conversation_messages
28124

29125

30126
async def main():
@@ -36,79 +132,16 @@ async def main():
36132
],
37133
)
38134

39-
conversation_messages: list[dict[str, any]] = []
40-
pending_tasks: list[asyncio.Task] = []
41-
42135
print("Claude Autorun")
43136
print("Type 'quit' or 'exit' to end the session.")
44137
print("-" * 40)
45138

46-
round = 0
47139
async with ClaudeSDKClient(options=options) as client:
48-
while True:
49-
want_user_input = False
50-
51-
if round == 0:
52-
want_user_input = True
53-
else:
54-
todos = await _get_todos()
55-
if todos:
56-
user_input = f"Please continue with the following todos:\n{todos}"
57-
else:
58-
want_user_input = True
59-
60-
if want_user_input:
61-
try:
62-
user_input = input("\nYou: ").strip()
63-
except EOFError:
64-
break
65-
66-
if not user_input:
67-
continue
68-
69-
if user_input.lower() in ("quit", "exit"):
70-
break
71-
72-
# Record user message
73-
conversation_messages.append({"role": "user", "content": user_input})
74-
75-
# Send query to Claude
76-
await client.query(user_input)
77-
78-
# Collect assistant response
79-
assistant_text_parts: list[str] = []
80-
81-
async for message in client.receive_response():
82-
if isinstance(message, AssistantMessage):
83-
for block in message.content:
84-
if isinstance(block, TextBlock):
85-
print(f"Claude: {block.text}")
86-
assistant_text_parts.append(block.text)
87-
elif isinstance(message, ResultMessage):
88-
print(f"Result: {message.result}")
89-
90-
# Record assistant message
91-
if assistant_text_parts:
92-
conversation_messages.append({"role": "assistant", "content": "\n".join(assistant_text_parts)})
93-
94-
# Check if we should trigger memorization
95-
if len(conversation_messages) >= N_MESSAGES_MEMORIZE:
96-
print(f"\n[Info] Reached {N_MESSAGES_MEMORIZE} messages, triggering memorization...")
97-
success = await trigger_memorize(conversation_messages.copy())
98-
if success:
99-
conversation_messages.clear()
100-
101-
round += 1
102-
103-
# User quit - memorize remaining messages if any
104-
if conversation_messages:
105-
print("\n[Info] Session ended, memorizing remaining messages...")
106-
success = await trigger_memorize(conversation_messages.copy())
140+
remaining_messages = await run_conversation_loop(client)
107141

108-
# Wait for all pending memorization tasks to complete
109-
if pending_tasks:
110-
print("[Info] Waiting for memorization tasks to complete...")
111-
await asyncio.gather(*pending_tasks, return_exceptions=True)
142+
if remaining_messages:
143+
print("\n[Info] Session ended, memorizing remaining messages...")
144+
await trigger_memorize(remaining_messages.copy())
112145

113146
print("\nDone")
114147

0 commit comments

Comments
 (0)