Skip to content

Commit dd9b76e

Browse files
committed
Test fixes for ttls
1 parent f1da330 commit dd9b76e

File tree

3 files changed

+194
-9
lines changed

3 files changed

+194
-9
lines changed

.github/workflows/python-tests.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,14 +35,14 @@ jobs:
3535
fail-fast: false
3636
matrix:
3737
python-version: [3.12]
38-
redis-version: ['6.2.6-v9', '8.0.3', 'latest']
38+
redis-version: ['redis-stack:6.2.6-v9', 'redis:8.0.3', 'redis:latest']
3939

4040
steps:
4141
- uses: actions/checkout@v3
4242

4343
- name: Set Redis image name
4444
run: |
45-
echo "REDIS_IMAGE=redis:${{ matrix.redis-version }}" >> $GITHUB_ENV
45+
echo "REDIS_IMAGE=${{ matrix.redis-version }}" >> $GITHUB_ENV
4646
4747
- name: Set up Python
4848
uses: actions/setup-python@v4

agent_memory_server/working_memory.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -188,23 +188,24 @@ async def set_working_memory(
188188
}
189189

190190
try:
191-
if working_memory.ttl_seconds:
191+
if working_memory.ttl_seconds is not None:
192192
# Store with TTL
193193
await redis_client.setex(
194194
key,
195195
working_memory.ttl_seconds,
196-
json.dumps(
197-
data, default=json_datetime_handler
198-
), # Add custom handler for any remaining datetime objects
196+
json.dumps(data, default=json_datetime_handler),
197+
)
198+
logger.info(
199+
f"Set working memory for session {working_memory.session_id} with TTL {working_memory.ttl_seconds}s"
199200
)
200201
else:
201202
await redis_client.set(
202203
key,
203204
json.dumps(data, default=json_datetime_handler),
204205
)
205-
logger.info(
206-
f"Set working memory for session {working_memory.session_id} with TTL {working_memory.ttl_seconds}s"
207-
)
206+
logger.info(
207+
f"Set working memory for session {working_memory.session_id} with no TTL"
208+
)
208209
except Exception as e:
209210
logger.error(
210211
f"Error setting working memory for session {working_memory.session_id}: {e}"

tests/test_working_memory.py

Lines changed: 184 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
"""Tests for working memory functionality."""
22

3+
import asyncio
4+
35
import pytest
46
from pydantic import ValidationError
57

68
from agent_memory_server.models import MemoryRecord, MemoryTypeEnum, WorkingMemory
9+
from agent_memory_server.utils.keys import Keys
710
from agent_memory_server.working_memory import (
811
delete_working_memory,
912
get_working_memory,
@@ -60,6 +63,7 @@ async def test_set_and_get_working_memory(self, async_redis_client):
6063
assert retrieved_mem.memories[0].id == "client-1"
6164
assert retrieved_mem.memories[1].text == "User is working on a Python project"
6265
assert retrieved_mem.memories[1].id == "client-2"
66+
assert retrieved_mem.ttl_seconds == 1800 # Verify TTL is preserved
6367

6468
@pytest.mark.asyncio
6569
async def test_get_nonexistent_working_memory(self, async_redis_client):
@@ -155,3 +159,183 @@ async def test_working_memory_validation(self, async_redis_client):
155159
assert retrieved is not None
156160
assert len(retrieved.memories) == 1
157161
assert retrieved.memories[0].id == "test-memory-1"
162+
163+
@pytest.mark.asyncio
164+
async def test_working_memory_ttl_none(self, async_redis_client):
165+
"""Test working memory without TTL (persistent)"""
166+
session_id = "test-session-no-ttl"
167+
namespace = "test-namespace"
168+
169+
memories = [
170+
MemoryRecord(
171+
text="Persistent memory",
172+
id="persistent-1",
173+
memory_type=MemoryTypeEnum.SEMANTIC,
174+
),
175+
]
176+
177+
working_mem = WorkingMemory(
178+
memories=memories,
179+
session_id=session_id,
180+
namespace=namespace,
181+
ttl_seconds=None, # No TTL - should be persistent
182+
)
183+
184+
await set_working_memory(working_mem, redis_client=async_redis_client)
185+
186+
# Get working memory and verify TTL is None
187+
retrieved_mem = await get_working_memory(
188+
session_id=session_id,
189+
namespace=namespace,
190+
redis_client=async_redis_client,
191+
)
192+
193+
assert retrieved_mem is not None
194+
assert retrieved_mem.ttl_seconds is None
195+
196+
# Verify the Redis key has no TTL set (-1 means no TTL)
197+
key = Keys.working_memory_key(
198+
session_id=session_id,
199+
namespace=namespace,
200+
)
201+
ttl = await async_redis_client.ttl(key)
202+
assert ttl == -1 # No TTL set
203+
204+
@pytest.mark.asyncio
205+
async def test_working_memory_ttl_set(self, async_redis_client):
206+
"""Test working memory with TTL set"""
207+
session_id = "test-session-with-ttl"
208+
namespace = "test-namespace"
209+
210+
memories = [
211+
MemoryRecord(
212+
text="Memory with TTL",
213+
id="ttl-memory-1",
214+
memory_type=MemoryTypeEnum.SEMANTIC,
215+
),
216+
]
217+
218+
ttl_seconds = 60 # 1 minute
219+
working_mem = WorkingMemory(
220+
memories=memories,
221+
session_id=session_id,
222+
namespace=namespace,
223+
ttl_seconds=ttl_seconds,
224+
)
225+
226+
await set_working_memory(working_mem, redis_client=async_redis_client)
227+
228+
# Get working memory and verify TTL is preserved
229+
retrieved_mem = await get_working_memory(
230+
session_id=session_id,
231+
namespace=namespace,
232+
redis_client=async_redis_client,
233+
)
234+
235+
assert retrieved_mem is not None
236+
assert retrieved_mem.ttl_seconds == ttl_seconds
237+
238+
# Verify the Redis key has TTL set (should be <= 60 seconds)
239+
key = Keys.working_memory_key(
240+
session_id=session_id,
241+
namespace=namespace,
242+
)
243+
ttl = await async_redis_client.ttl(key)
244+
assert 0 < ttl <= ttl_seconds
245+
246+
@pytest.mark.asyncio
247+
async def test_working_memory_ttl_expiration(self, async_redis_client):
248+
"""Test working memory expires after TTL"""
249+
session_id = "test-session-expire"
250+
namespace = "test-namespace"
251+
252+
memories = [
253+
MemoryRecord(
254+
text="Memory that expires",
255+
id="expire-memory-1",
256+
memory_type=MemoryTypeEnum.SEMANTIC,
257+
),
258+
]
259+
260+
ttl_seconds = 1 # 1 second
261+
working_mem = WorkingMemory(
262+
memories=memories,
263+
session_id=session_id,
264+
namespace=namespace,
265+
ttl_seconds=ttl_seconds,
266+
)
267+
268+
await set_working_memory(working_mem, redis_client=async_redis_client)
269+
270+
# Verify it exists immediately
271+
retrieved_mem = await get_working_memory(
272+
session_id=session_id,
273+
namespace=namespace,
274+
redis_client=async_redis_client,
275+
)
276+
assert retrieved_mem is not None
277+
278+
# Wait for TTL to expire
279+
await asyncio.sleep(1.1)
280+
281+
# Verify it's gone after TTL
282+
retrieved_mem = await get_working_memory(
283+
session_id=session_id,
284+
namespace=namespace,
285+
redis_client=async_redis_client,
286+
)
287+
assert retrieved_mem is None
288+
289+
@pytest.mark.asyncio
290+
async def test_working_memory_ttl_update_preserves_ttl(self, async_redis_client):
291+
"""Test that updating working memory preserves TTL"""
292+
session_id = "test-session-update-ttl"
293+
namespace = "test-namespace"
294+
295+
memories = [
296+
MemoryRecord(
297+
text="Original memory",
298+
id="original-memory-1",
299+
memory_type=MemoryTypeEnum.SEMANTIC,
300+
),
301+
]
302+
303+
ttl_seconds = 120 # 2 minutes
304+
working_mem = WorkingMemory(
305+
memories=memories,
306+
session_id=session_id,
307+
namespace=namespace,
308+
ttl_seconds=ttl_seconds,
309+
)
310+
311+
await set_working_memory(working_mem, redis_client=async_redis_client)
312+
313+
# Update the working memory
314+
working_mem.memories.append(
315+
MemoryRecord(
316+
text="Updated memory",
317+
id="updated-memory-1",
318+
memory_type=MemoryTypeEnum.SEMANTIC,
319+
)
320+
)
321+
322+
await set_working_memory(working_mem, redis_client=async_redis_client)
323+
324+
# Get updated working memory and verify TTL is preserved
325+
retrieved_mem = await get_working_memory(
326+
session_id=session_id,
327+
namespace=namespace,
328+
redis_client=async_redis_client,
329+
)
330+
331+
assert retrieved_mem is not None
332+
assert retrieved_mem.ttl_seconds == ttl_seconds
333+
assert len(retrieved_mem.memories) == 2
334+
335+
# Verify the Redis key still has TTL set
336+
key = Keys.working_memory_key(
337+
session_id=session_id,
338+
namespace=namespace,
339+
)
340+
ttl = await async_redis_client.ttl(key)
341+
assert 0 < ttl <= ttl_seconds

0 commit comments

Comments
 (0)