@@ -91,8 +91,9 @@ async def health(self) -> bool:
9191 def _session_key (self , user_id : str , session_id : str ):
9292 return f"session:{ user_id } :{ session_id } "
9393
94- def _index_key (self , user_id : str ):
95- return f"session_index:{ user_id } "
94+ def _session_pattern (self , user_id : str ):
95+ """Generate the pattern for scanning session keys for a user."""
96+ return f"session:{ user_id } :*"
9697
9798 def _session_to_json (self , session : Session ) -> str :
9899 return session .model_dump_json ()
@@ -114,7 +115,6 @@ async def create_session(
114115 key = self ._session_key (user_id , sid )
115116
116117 await self ._redis .set (key , self ._session_to_json (session ))
117- await self ._redis .sadd (self ._index_key (user_id ), sid )
118118
119119 # Set TTL for the session key if configured
120120 if self ._ttl_seconds is not None :
@@ -130,15 +130,7 @@ async def get_session(
130130 key = self ._session_key (user_id , session_id )
131131 session_json = await self ._redis .get (key )
132132 if session_json is None :
133- session = Session (id = session_id , user_id = user_id )
134- await self ._redis .set (key , self ._session_to_json (session ))
135- await self ._redis .sadd (self ._index_key (user_id ), session_id )
136-
137- # Set TTL for the session key if configured
138- if self ._ttl_seconds is not None :
139- await self ._redis .expire (key , self ._ttl_seconds )
140-
141- return session
133+ return None
142134
143135 session = self ._session_from_json (session_json )
144136
@@ -151,19 +143,33 @@ async def get_session(
151143 async def delete_session (self , user_id : str , session_id : str ):
152144 key = self ._session_key (user_id , session_id )
153145 await self ._redis .delete (key )
154- await self ._redis .srem (self ._index_key (user_id ), session_id )
155146
156147 async def list_sessions (self , user_id : str ) -> list [Session ]:
157- idx_key = self ._index_key (user_id )
158- session_ids = await self ._redis .smembers (idx_key )
148+ """List all sessions for a user by scanning session keys.
149+
150+ Uses SCAN to find all session:{user_id}:* keys. Expired sessions
151+ naturally disappear as their keys expire, avoiding stale entries.
152+ """
153+ pattern = self ._session_pattern (user_id )
159154 sessions = []
160- for sid in session_ids :
161- key = self ._session_key (user_id , sid )
162- session_json = await self ._redis .get (key )
163- if session_json :
164- session = self ._session_from_json (session_json )
165- session .messages = []
166- sessions .append (session )
155+ cursor = 0
156+
157+ while True :
158+ cursor , keys = await self ._redis .scan (
159+ cursor ,
160+ match = pattern ,
161+ count = 100 ,
162+ )
163+ for key in keys :
164+ session_json = await self ._redis .get (key )
165+ if session_json :
166+ session = self ._session_from_json (session_json )
167+ session .messages = []
168+ sessions .append (session )
169+
170+ if cursor == 0 :
171+ break
172+
167173 return sessions
168174
169175 async def append_message (
@@ -192,49 +198,54 @@ async def append_message(
192198 key = self ._session_key (user_id , session_id )
193199
194200 session_json = await self ._redis .get (key )
195- if session_json :
196- stored_session = self ._session_from_json (session_json )
197- stored_session .messages .extend (norm_message )
198-
199- # Limit the number of messages per session to prevent memory issues
200- if self ._max_messages_per_session is not None :
201- if (
202- len (stored_session .messages )
203- > self ._max_messages_per_session
204- ):
205- # Keep only the most recent messages
206- stored_session .messages = stored_session .messages [
207- - self ._max_messages_per_session :
208- ]
209-
210- await self ._redis .set (key , self ._session_to_json (stored_session ))
211- await self ._redis .sadd (self ._index_key (user_id ), session_id )
212-
213- # Set TTL for the session key if configured
214- if self ._ttl_seconds is not None :
215- await self ._redis .expire (key , self ._ttl_seconds )
216- else :
217- print (
218- f"Warning: Session { session .id } not found in storage for "
219- f"append_message." ,
201+ if session_json is None :
202+ raise RuntimeError (
203+ f"Session { session_id } not found or has expired for user "
204+ f"{ user_id } . Previous memory/state has been lost. "
205+ f"Please create a new session." ,
220206 )
221207
208+ stored_session = self ._session_from_json (session_json )
209+ stored_session .messages .extend (norm_message )
210+
211+ # Limit the number of messages per session to prevent memory issues
212+ if self ._max_messages_per_session is not None :
213+ if len (stored_session .messages ) > self ._max_messages_per_session :
214+ # Keep only the most recent messages
215+ stored_session .messages = stored_session .messages [
216+ - self ._max_messages_per_session :
217+ ]
218+
219+ await self ._redis .set (key , self ._session_to_json (stored_session ))
220+
221+ # Set TTL for the session key if configured
222+ if self ._ttl_seconds is not None :
223+ await self ._redis .expire (key , self ._ttl_seconds )
224+
222225 async def delete_user_sessions (self , user_id : str ) -> None :
223226 """
224227 Deletes all session history data for a specific user.
225228
229+ Uses SCAN to find all session keys for the user and deletes them.
230+
226231 Args:
227232 user_id (str): The ID of the user whose session history data should
228233 be deleted
229234 """
230235 if not self ._redis :
231236 raise RuntimeError ("Redis connection is not available" )
232237
233- index_key = self ._index_key (user_id )
234- session_ids = await self . _redis . smembers ( index_key )
238+ pattern = self ._session_pattern (user_id )
239+ cursor = 0
235240
236- for session_id in session_ids :
237- key = self ._session_key (user_id , session_id )
238- await self ._redis .delete (key )
241+ while True :
242+ cursor , keys = await self ._redis .scan (
243+ cursor ,
244+ match = pattern ,
245+ count = 100 ,
246+ )
247+ if keys :
248+ await self ._redis .delete (* keys )
239249
240- await self ._redis .delete (index_key )
250+ if cursor == 0 :
251+ break
0 commit comments