-
Notifications
You must be signed in to change notification settings - Fork 44
Expand file tree
/
Copy pathworking_memory.py
More file actions
244 lines (204 loc) · 7.55 KB
/
working_memory.py
File metadata and controls
244 lines (204 loc) · 7.55 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
"""Working memory management for sessions."""
import json
import logging
import time
from datetime import UTC, datetime
from redis.asyncio import Redis
from agent_memory_server.models import MemoryMessage, MemoryRecord, WorkingMemory
from agent_memory_server.utils.keys import Keys
from agent_memory_server.utils.redis import get_redis_conn
logger = logging.getLogger(__name__)
def json_datetime_handler(obj):
    """Serialize ``datetime`` values to ISO-8601 strings for ``json.dumps``.

    Intended as the ``default=`` hook of ``json.dumps``.

    Raises:
        TypeError: if *obj* is not a ``datetime`` instance.
    """
    if not isinstance(obj, datetime):
        raise TypeError(f"Object of type {type(obj)} is not JSON serializable")
    return obj.isoformat()
async def list_sessions(
    redis,
    limit: int = 10,
    offset: int = 0,
    namespace: str | None = None,
    user_id: str | None = None,
) -> tuple[int, list[str]]:
    """
    List sessions

    Args:
        redis: Redis client
        limit: Maximum number of sessions to return
        offset: Offset for pagination
        namespace: Optional namespace filter
        user_id: Optional user ID filter (not yet implemented - sessions are stored in sorted sets)

    Returns:
        Tuple of (total_count, session_ids)

    Note:
        The user_id parameter is accepted for API compatibility but filtering by user_id
        is not yet implemented. This would require changing how sessions are stored to
        enable efficient user_id-based filtering.
    """
    # ZRANGE takes inclusive, 0-based start/end indices.
    first = offset
    last = offset + limit - 1

    # TODO: This should take a user_id
    sessions_key = Keys.sessions_key(namespace=namespace)

    # Fetch the total count and the requested page in one round trip.
    async with redis.pipeline() as pipe:
        pipe.zcard(sessions_key)
        pipe.zrange(sessions_key, first, last)
        total, raw_ids = await pipe.execute()

    session_ids = [
        sid.decode("utf-8") if isinstance(sid, bytes) else sid for sid in raw_ids
    ]
    return total, session_ids
async def get_working_memory(
    session_id: str,
    user_id: str | None = None,
    namespace: str | None = None,
    redis_client: Redis | None = None,
) -> WorkingMemory | None:
    """
    Get working memory for a session.

    Args:
        session_id: The session ID
        user_id: Optional user ID for the session
        namespace: Optional namespace for the session
        redis_client: Optional Redis client

    Returns:
        WorkingMemory object, or None if not found or if deserialization fails
        (failures are logged, not raised).
    """
    if not redis_client:
        redis_client = await get_redis_conn()

    key = Keys.working_memory_key(
        session_id=session_id,
        user_id=user_id,
        namespace=namespace,
    )

    try:
        data = await redis_client.get(key)
        if not data:
            logger.debug(
                f"No working memory found for parameters: {session_id}, {user_id}, {namespace}"
            )
            return None

        # Parse the JSON data
        working_memory_data = json.loads(data)

        # Rehydrate nested models from their stored dict form.
        memories = [
            MemoryRecord(**memory_data)
            for memory_data in working_memory_data.get("memories", [])
        ]
        messages = [
            MemoryMessage(**message_data)
            for message_data in working_memory_data.get("messages", [])
        ]

        # Timestamps are stored as epoch seconds; fall back to "now" when a
        # field is missing (compute the fallback once for consistency).
        now = int(time.time())

        return WorkingMemory(
            messages=messages,
            memories=memories,
            context=working_memory_data.get("context"),
            user_id=working_memory_data.get("user_id"),
            tokens=working_memory_data.get("tokens", 0),
            session_id=session_id,
            namespace=namespace,
            ttl_seconds=working_memory_data.get("ttl_seconds"),
            data=working_memory_data.get("data") or {},
            last_accessed=datetime.fromtimestamp(
                working_memory_data.get("last_accessed", now), UTC
            ),
            created_at=datetime.fromtimestamp(
                working_memory_data.get("created_at", now), UTC
            ),
            updated_at=datetime.fromtimestamp(
                working_memory_data.get("updated_at", now), UTC
            ),
        )
    except Exception:
        # NOTE(review): broad catch preserves the original best-effort contract —
        # callers treat any failure as "no working memory". logger.exception
        # records the traceback so the root cause is not lost.
        logger.exception(f"Error getting working memory for session {session_id}")
        return None
async def set_working_memory(
    working_memory: WorkingMemory,
    redis_client: Redis | None = None,
) -> None:
    """
    Set working memory for a session with TTL.

    Args:
        working_memory: WorkingMemory object to store
        redis_client: Optional Redis client

    Raises:
        ValueError: if any memory record is missing an ID.
    """
    if not redis_client:
        redis_client = await get_redis_conn()

    # Validate that all memories have id (Stage 3 requirement)
    if any(not memory.id for memory in working_memory.memories):
        raise ValueError("All memory records in working memory must have an ID")

    key = Keys.working_memory_key(
        session_id=working_memory.session_id,
        user_id=working_memory.user_id,
        namespace=working_memory.namespace,
    )

    # Stamp the modification time before serializing.
    working_memory.updated_at = datetime.now(UTC)

    # Flatten to plain JSON types; datetimes become epoch seconds.
    document = {
        "messages": [
            msg.model_dump(mode="json") for msg in working_memory.messages
        ],
        "memories": [
            mem.model_dump(mode="json") for mem in working_memory.memories
        ],
        "context": working_memory.context,
        "user_id": working_memory.user_id,
        "tokens": working_memory.tokens,
        "session_id": working_memory.session_id,
        "namespace": working_memory.namespace,
        "ttl_seconds": working_memory.ttl_seconds,
        "data": working_memory.data or {},
        "last_accessed": int(working_memory.last_accessed.timestamp()),
        "created_at": int(working_memory.created_at.timestamp()),
        "updated_at": int(working_memory.updated_at.timestamp()),
    }

    ttl = working_memory.ttl_seconds
    try:
        payload = json.dumps(document, default=json_datetime_handler)
        if ttl is not None:
            # Store with TTL
            await redis_client.setex(key, ttl, payload)
            logger.info(
                f"Set working memory for session {working_memory.session_id} with TTL {working_memory.ttl_seconds}s"
            )
        else:
            await redis_client.set(key, payload)
            logger.info(
                f"Set working memory for session {working_memory.session_id} with no TTL"
            )
    except Exception as e:
        logger.error(
            f"Error setting working memory for session {working_memory.session_id}: {e}"
        )
        raise
async def delete_working_memory(
    session_id: str,
    user_id: str | None = None,
    namespace: str | None = None,
    redis_client: Redis | None = None,
) -> None:
    """
    Remove a session's working memory entry from Redis.

    Args:
        session_id: The session ID
        user_id: Optional user ID for the session
        namespace: Optional namespace for the session
        redis_client: Optional Redis client
    """
    redis_client = redis_client or await get_redis_conn()

    key = Keys.working_memory_key(
        session_id=session_id, user_id=user_id, namespace=namespace
    )

    try:
        await redis_client.delete(key)
        logger.info(f"Deleted working memory for session {session_id}")
    except Exception as e:
        logger.error(f"Error deleting working memory for session {session_id}: {e}")
        raise