Skip to content

Commit 6c9900f

Browse files
committed
optimize the exception
1 parent f311845 commit 6c9900f

File tree

3 files changed

+42
-19
lines changed

3 files changed

+42
-19
lines changed

src/agentscope_runtime/engine/services/memory/redis_memory_service.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,8 @@ async def search_memory(
139139
messages: list,
140140
filters: Optional[Dict[str, Any]] = None,
141141
) -> list:
142+
if not self._redis:
143+
raise RuntimeError("Redis connection is not available")
142144
key = self._user_key(user_id)
143145
if (
144146
not messages
@@ -192,9 +194,8 @@ async def search_memory(
192194

193195
# Refresh TTL on read to extend lifetime of actively used data,
194196
# if a TTL is configured and there is existing data for this key.
195-
ttl_seconds = getattr(self, "_ttl", None)
196-
if ttl_seconds and hash_keys:
197-
await self._redis.expire(key, ttl_seconds)
197+
if self._ttl_seconds is not None and hash_keys:
198+
await self._redis.expire(key, self._ttl_seconds)
198199

199200
return result
200201

@@ -211,6 +212,8 @@ async def list_memory(
211212
user_id: str,
212213
filters: Optional[Dict[str, Any]] = None,
213214
) -> list:
215+
if not self._redis:
216+
raise RuntimeError("Redis connection is not available")
214217
key = self._user_key(user_id)
215218
page_num = filters.get("page_num", 1) if filters else 1
216219
page_size = filters.get("page_size", 10) if filters else 10
@@ -236,7 +239,7 @@ async def list_memory(
236239

237240
# Refresh TTL on active use to keep memory alive,
238241
# mirroring get_session behavior
239-
if getattr(self, "_ttl_seconds", None):
242+
if self._ttl_seconds is not None and hash_keys:
240243
await self._redis.expire(key, self._ttl_seconds)
241244
return all_msgs[start_index:end_index]
242245

@@ -245,6 +248,8 @@ async def delete_memory(
245248
user_id: str,
246249
session_id: Optional[str] = None,
247250
) -> None:
251+
if not self._redis:
252+
raise RuntimeError("Redis connection is not available")
248253
key = self._user_key(user_id)
249254
if session_id:
250255
await self._redis.hdel(key, session_id)

src/agentscope_runtime/engine/services/session_history/redis_session_history_service.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,8 @@ async def create_session(
106106
user_id: str,
107107
session_id: Optional[str] = None,
108108
) -> Session:
109+
if not self._redis:
110+
raise RuntimeError("Redis connection is not available")
109111
if session_id and session_id.strip():
110112
sid = session_id.strip()
111113
else:
@@ -127,6 +129,8 @@ async def get_session(
127129
user_id: str,
128130
session_id: str,
129131
) -> Optional[Session]:
132+
if not self._redis:
133+
raise RuntimeError("Redis connection is not available")
130134
key = self._session_key(user_id, session_id)
131135
session_json = await self._redis.get(key)
132136
if session_json is None:
@@ -141,6 +145,8 @@ async def get_session(
141145
return session
142146

143147
async def delete_session(self, user_id: str, session_id: str):
148+
if not self._redis:
149+
raise RuntimeError("Redis connection is not available")
144150
key = self._session_key(user_id, session_id)
145151
await self._redis.delete(key)
146152

@@ -150,6 +156,8 @@ async def list_sessions(self, user_id: str) -> list[Session]:
150156
Uses SCAN to find all session:{user_id}:* keys. Expired sessions
151157
naturally disappear as their keys expire, avoiding stale entries.
152158
"""
159+
if not self._redis:
160+
raise RuntimeError("Redis connection is not available")
153161
pattern = self._session_pattern(user_id)
154162
sessions = []
155163
cursor = 0
@@ -182,6 +190,8 @@ async def append_message(
182190
List[Dict[str, Any]],
183191
],
184192
):
193+
if not self._redis:
194+
raise RuntimeError("Redis connection is not available")
185195
if not isinstance(message, list):
186196
message = [message]
187197
norm_message = []
@@ -199,14 +209,16 @@ async def append_message(
199209

200210
session_json = await self._redis.get(key)
201211
if session_json is None:
202-
raise RuntimeError(
203-
f"Session {session_id} not found or has expired for user "
204-
f"{user_id}. Previous memory/state has been lost. "
205-
f"Please create a new session.",
212+
# Session expired or not found, treat as a new session
213+
# Create a new session with the current messages
214+
stored_session = Session(
215+
id=session_id,
216+
user_id=user_id,
217+
messages=norm_message.copy(),
206218
)
207-
208-
stored_session = self._session_from_json(session_json)
209-
stored_session.messages.extend(norm_message)
219+
else:
220+
stored_session = self._session_from_json(session_json)
221+
stored_session.messages.extend(norm_message)
210222

211223
# Limit the number of messages per session to prevent memory issues
212224
if self._max_messages_per_session is not None:

tests/unit/test_redis_session_history_service.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -264,19 +264,25 @@ async def test_append_message(
264264
for i, msg in enumerate(stored_session.messages[2:]):
265265
assert msg.content == messages3[i].get("content")
266266

267-
# Test appending to a non-existent session (should raise RuntimeError)
267+
# Test appending to a non-existent session (should create new session)
268268
non_existent_session = Session(
269269
id="non_existent",
270270
user_id=user_id,
271271
messages=[],
272272
)
273-
# This should raise a RuntimeError indicating
274-
# the session is missing/expired
275-
with pytest.raises(RuntimeError, match="not found or has expired"):
276-
await session_history_service.append_message(
277-
non_existent_session,
278-
message1,
279-
)
273+
# This should not raise an error, but create a new session
274+
await session_history_service.append_message(
275+
non_existent_session,
276+
message1,
277+
)
278+
# Verify the session was created with the message
279+
retrieved_session = await session_history_service.get_session(
280+
user_id,
281+
"non_existent",
282+
)
283+
assert retrieved_session is not None
284+
assert len(retrieved_session.messages) == 1
285+
assert retrieved_session.messages[0].content == message1.get("content")
280286

281287

282288
@pytest.mark.asyncio

0 commit comments

Comments
 (0)