Skip to content

Commit 1ea8327

Browse files
committed
optimize session_index
1 parent f72f2e1 commit 1ea8327

File tree

2 files changed

+81
-89
lines changed

2 files changed

+81
-89
lines changed

src/agentscope_runtime/engine/services/session_history/redis_session_history_service.py

Lines changed: 64 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -91,8 +91,9 @@ async def health(self) -> bool:
9191
def _session_key(self, user_id: str, session_id: str):
9292
return f"session:{user_id}:{session_id}"
9393

94-
def _index_key(self, user_id: str):
95-
return f"session_index:{user_id}"
94+
def _session_pattern(self, user_id: str):
95+
"""Generate the pattern for scanning session keys for a user."""
96+
return f"session:{user_id}:*"
9697

9798
def _session_to_json(self, session: Session) -> str:
9899
return session.model_dump_json()
@@ -114,7 +115,6 @@ async def create_session(
114115
key = self._session_key(user_id, sid)
115116

116117
await self._redis.set(key, self._session_to_json(session))
117-
await self._redis.sadd(self._index_key(user_id), sid)
118118

119119
# Set TTL for the session key if configured
120120
if self._ttl_seconds is not None:
@@ -130,15 +130,7 @@ async def get_session(
130130
key = self._session_key(user_id, session_id)
131131
session_json = await self._redis.get(key)
132132
if session_json is None:
133-
session = Session(id=session_id, user_id=user_id)
134-
await self._redis.set(key, self._session_to_json(session))
135-
await self._redis.sadd(self._index_key(user_id), session_id)
136-
137-
# Set TTL for the session key if configured
138-
if self._ttl_seconds is not None:
139-
await self._redis.expire(key, self._ttl_seconds)
140-
141-
return session
133+
return None
142134

143135
session = self._session_from_json(session_json)
144136

@@ -151,19 +143,33 @@ async def get_session(
151143
async def delete_session(self, user_id: str, session_id: str):
152144
key = self._session_key(user_id, session_id)
153145
await self._redis.delete(key)
154-
await self._redis.srem(self._index_key(user_id), session_id)
155146

156147
async def list_sessions(self, user_id: str) -> list[Session]:
157-
idx_key = self._index_key(user_id)
158-
session_ids = await self._redis.smembers(idx_key)
148+
"""List all sessions for a user by scanning session keys.
149+
150+
Uses SCAN to find all session:{user_id}:* keys. Expired sessions
151+
naturally disappear as their keys expire, avoiding stale entries.
152+
"""
153+
pattern = self._session_pattern(user_id)
159154
sessions = []
160-
for sid in session_ids:
161-
key = self._session_key(user_id, sid)
162-
session_json = await self._redis.get(key)
163-
if session_json:
164-
session = self._session_from_json(session_json)
165-
session.messages = []
166-
sessions.append(session)
155+
cursor = 0
156+
157+
while True:
158+
cursor, keys = await self._redis.scan(
159+
cursor,
160+
match=pattern,
161+
count=100,
162+
)
163+
for key in keys:
164+
session_json = await self._redis.get(key)
165+
if session_json:
166+
session = self._session_from_json(session_json)
167+
session.messages = []
168+
sessions.append(session)
169+
170+
if cursor == 0:
171+
break
172+
167173
return sessions
168174

169175
async def append_message(
@@ -192,49 +198,54 @@ async def append_message(
192198
key = self._session_key(user_id, session_id)
193199

194200
session_json = await self._redis.get(key)
195-
if session_json:
196-
stored_session = self._session_from_json(session_json)
197-
stored_session.messages.extend(norm_message)
198-
199-
# Limit the number of messages per session to prevent memory issues
200-
if self._max_messages_per_session is not None:
201-
if (
202-
len(stored_session.messages)
203-
> self._max_messages_per_session
204-
):
205-
# Keep only the most recent messages
206-
stored_session.messages = stored_session.messages[
207-
-self._max_messages_per_session :
208-
]
209-
210-
await self._redis.set(key, self._session_to_json(stored_session))
211-
await self._redis.sadd(self._index_key(user_id), session_id)
212-
213-
# Set TTL for the session key if configured
214-
if self._ttl_seconds is not None:
215-
await self._redis.expire(key, self._ttl_seconds)
216-
else:
217-
print(
218-
f"Warning: Session {session.id} not found in storage for "
219-
f"append_message.",
201+
if session_json is None:
202+
raise RuntimeError(
203+
f"Session {session_id} not found or has expired for user "
204+
f"{user_id}. Previous memory/state has been lost. "
205+
f"Please create a new session.",
220206
)
221207

208+
stored_session = self._session_from_json(session_json)
209+
stored_session.messages.extend(norm_message)
210+
211+
# Limit the number of messages per session to prevent memory issues
212+
if self._max_messages_per_session is not None:
213+
if len(stored_session.messages) > self._max_messages_per_session:
214+
# Keep only the most recent messages
215+
stored_session.messages = stored_session.messages[
216+
-self._max_messages_per_session :
217+
]
218+
219+
await self._redis.set(key, self._session_to_json(stored_session))
220+
221+
# Set TTL for the session key if configured
222+
if self._ttl_seconds is not None:
223+
await self._redis.expire(key, self._ttl_seconds)
224+
222225
async def delete_user_sessions(self, user_id: str) -> None:
223226
"""
224227
Deletes all session history data for a specific user.
225228
229+
Uses SCAN to find all session keys for the user and deletes them.
230+
226231
Args:
227232
user_id (str): The ID of the user whose session history data should
228233
be deleted
229234
"""
230235
if not self._redis:
231236
raise RuntimeError("Redis connection is not available")
232237

233-
index_key = self._index_key(user_id)
234-
session_ids = await self._redis.smembers(index_key)
238+
pattern = self._session_pattern(user_id)
239+
cursor = 0
235240

236-
for session_id in session_ids:
237-
key = self._session_key(user_id, session_id)
238-
await self._redis.delete(key)
241+
while True:
242+
cursor, keys = await self._redis.scan(
243+
cursor,
244+
match=pattern,
245+
count=100,
246+
)
247+
if keys:
248+
await self._redis.delete(*keys)
239249

240-
await self._redis.delete(index_key)
250+
if cursor == 0:
251+
break

tests/unit/test_redis_session_history_service.py

Lines changed: 17 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -118,25 +118,19 @@ async def test_get_session(
118118
assert refetched_session is not None
119119
assert refetched_session.messages == []
120120

121-
# Test getting a non-existent session (should create a new one)
121+
# Test getting a non-existent session (should return None)
122122
non_existent_session = await session_history_service.get_session(
123123
user_id,
124124
"non_existent_id",
125125
)
126-
assert non_existent_session is not None
127-
assert non_existent_session.id == "non_existent_id"
128-
assert non_existent_session.user_id == user_id
129-
assert non_existent_session.messages == []
126+
assert non_existent_session is None
130127

131-
# Test getting a session for a different user (should create a new one)
128+
# Test getting a session for a different user (should return None)
132129
other_user_session = await session_history_service.get_session(
133130
"other_user",
134131
created_session.id,
135132
)
136-
assert other_user_session is not None
137-
assert other_user_session.id == created_session.id
138-
assert other_user_session.user_id == "other_user"
139-
assert other_user_session.messages == []
133+
assert other_user_session is None
140134

141135

142136
@pytest.mark.asyncio
@@ -156,17 +150,12 @@ async def test_delete_session(
156150

157151
await session_history_service.delete_session(user_id, session.id)
158152

159-
# Ensure session is deleted - get_session will create a new empty session
153+
# Ensure session is deleted - get_session should return None
160154
retrieved_session = await session_history_service.get_session(
161155
user_id,
162156
session.id,
163157
)
164-
assert retrieved_session is not None
165-
assert retrieved_session.id == session.id
166-
assert retrieved_session.user_id == user_id
167-
assert (
168-
retrieved_session.messages == []
169-
) # Should be empty as it's a new session
158+
assert retrieved_session is None
170159

171160
# Test deleting a non-existent session (should not raise error)
172161
await session_history_service.delete_session(user_id, "non_existent_id")
@@ -275,26 +264,19 @@ async def test_append_message(
275264
for i, msg in enumerate(stored_session.messages[2:]):
276265
assert msg.content == messages3[i].get("content")
277266

278-
# Test appending to a non-existent session
267+
# Test appending to a non-existent session (should raise RuntimeError)
279268
non_existent_session = Session(
280269
id="non_existent",
281270
user_id=user_id,
282271
messages=[],
283272
)
284-
# This should not raise an error, but print a warning.
285-
await session_history_service.append_message(
286-
non_existent_session,
287-
message1,
288-
)
289-
# get_session will create a new session, not the one we tried to append to
290-
retrieved_session = await session_history_service.get_session(
291-
user_id,
292-
"non_existent",
293-
)
294-
assert retrieved_session is not None
295-
assert (
296-
retrieved_session.messages == []
297-
) # Empty as it's a newly created session
273+
# This should raise a RuntimeError indicating
274+
# the session is missing/expired
275+
with pytest.raises(RuntimeError, match="not found or has expired"):
276+
await session_history_service.append_message(
277+
non_existent_session,
278+
message1,
279+
)
298280

299281

300282
@pytest.mark.asyncio
@@ -347,12 +329,11 @@ async def test_ttl_expiration():
347329
key_exists = await fake_redis.exists(key)
348330
assert key_exists == 0, "Key should be expired and deleted"
349331

350-
# Verify get_session creates a new session after expiry
332+
# Verify get_session returns None after expiry
351333
retrieved_after_expiry = await service.get_session(user_id, session.id)
352-
assert retrieved_after_expiry is not None
353334
assert (
354-
retrieved_after_expiry.messages == []
355-
), "Session should be empty after expiry"
335+
retrieved_after_expiry is None
336+
), "Session should return None after expiry"
356337

357338
finally:
358339
await service.stop()

0 commit comments

Comments
 (0)