1- import json
2- import time
3-
41from fastapi import APIRouter , BackgroundTasks , Depends , HTTPException
52
3+ from redis_memory_server import messages
64from redis_memory_server .config import settings
75from redis_memory_server .logging import get_logger
8- from redis_memory_server .models .extraction import handle_extraction
9- from redis_memory_server .models .messages import (
6+ from redis_memory_server .models import (
107 AckResponse ,
118 GetSessionsQuery ,
12- MemoryMessage ,
13- MemoryMessagesAndContext ,
14- MemoryResponse ,
159 SearchPayload ,
1610 SearchResults ,
17- index_messages ,
18- search_messages ,
11+ SessionMemory ,
12+ SessionMemoryResponse ,
1913)
20- from redis_memory_server .models .summarization import handle_compaction
2114from redis_memory_server .utils import (
22- Keys ,
23- get_model_client ,
2415 get_openai_client ,
2516 get_redis_conn ,
2617)
3324
@router.get("/sessions/", response_model=list[str])
async def list_sessions(
    options: GetSessionsQuery = Depends(),
):
    """
    Get a list of session IDs, with optional pagination.

    Args:
        options: Query parameters (page, size, namespace)

    Returns:
        List of session IDs
    """
    # TODO: Pydantic should validate this
    if options.page > 100:
        raise HTTPException(status_code=400, detail="Page must not exceed 100")

    # Delegate the actual Redis lookup to the messages module.
    return await messages.list_sessions(
        redis=get_redis_conn(),
        page=options.page,
        size=options.size,
        namespace=options.namespace,
    )
@router.get("/sessions/{session_id}/memory", response_model=SessionMemoryResponse)
async def get_session_memory(
    session_id: str,
    namespace: str | None = None,
    window_size: int = settings.window_size,
):
    """
    Get memory for a session.

    This includes stored conversation history and context.

    Args:
        session_id: The session ID
        window_size: The number of messages to include in the response
        namespace: The namespace to use for the session

    Returns:
        Conversation history and context

    Raises:
        HTTPException: 404 if the session is not found
    """
    redis = get_redis_conn()

    # A falsy result from the messages layer means no such session exists.
    if not (
        session := await messages.get_session_memory(
            redis=redis,
            session_id=session_id,
            window_size=window_size,
            namespace=namespace,
        )
    ):
        raise HTTPException(status_code=404, detail="Session not found")

    return session
@router.put("/sessions/{session_id}/memory", response_model=AckResponse)
async def put_session_memory(
    session_id: str,
    memory: SessionMemory,
    background_tasks: BackgroundTasks,
):
    """
    Set session memory. Replaces existing session memory.

    Args:
        session_id: The session ID
        memory: Messages and context to save
        background_tasks: FastAPI task queue for any deferred work the
            messages layer schedules (e.g. summarization/indexing)

    Returns:
        Acknowledgement response
    """
    # Persistence (and any background task scheduling) lives in the
    # messages module; this route is a thin HTTP adapter around it.
    await messages.set_session_memory(
        redis=get_redis_conn(),
        session_id=session_id,
        memory=memory,
        background_tasks=background_tasks,
    )

    return AckResponse(status="ok")
248112@router .delete ("/sessions/{session_id}/memory" , response_model = AckResponse )
249- async def delete_memory (
113+ async def delete_session_memory (
250114 session_id : str ,
251115 namespace : str | None = None ,
252116):
@@ -261,38 +125,23 @@ async def delete_memory(
261125 Acknowledgement response
262126 """
263127 redis = get_redis_conn ()
264- try :
265- # Define keys
266- messages_key = Keys .messages_key (session_id )
267- context_key = Keys .context_key (session_id )
268- token_count_key = Keys .token_count_key (session_id )
269- sessions_key = f"sessions:{ namespace } " if namespace else "sessions"
270-
271- # Create pipeline for deletion
272- pipe = redis .pipeline ()
273- pipe .delete (messages_key , context_key , token_count_key )
274- pipe .zrem (sessions_key , session_id )
275- await pipe .execute ()
276-
277- return AckResponse (status = "ok" )
278- except Exception as e :
279- logger .error (f"Error deleting memory for session { session_id } : { e } " )
280- raise HTTPException (status_code = 500 , detail = "Internal server error" ) from e
281-
282-
283- @router .post ("/sessions/{session_id}/search" , response_model = SearchResults )
284- async def search_session_messages (
285- session_id : str ,
286- payload : SearchPayload ,
287- namespace : str | None = None ,
288- ):
128+ await messages .delete_session_memory (
129+ redis = redis ,
130+ session_id = session_id ,
131+ namespace = namespace ,
132+ )
133+ return AckResponse (status = "ok" )
134+
135+
136+ @router .post ("/messages/search" , response_model = SearchResults )
137+ async def messages_search (payload : SearchPayload ):
289138 """
290- Run a semantic search on the messages in a session
139+ Run a semantic search on messages
140+
141+ TODO: Infer topics for `text`
291142
292143 Args:
293- session_id: The session ID
294- payload: Search payload with text to search for
295- namespace: Optional namespace for the session
144+ payload: Search payload
296145
297146 Returns:
298147 List of search results
@@ -305,14 +154,8 @@ async def search_session_messages(
305154 # For embeddings, we always use OpenAI models since Anthropic doesn't support embeddings
306155 client = await get_openai_client ()
307156
308- try :
309- return await search_messages (
310- payload .text ,
311- client ,
312- redis ,
313- session_id = session_id ,
314- namespace = namespace ,
315- )
316- except Exception as e :
317- logger .error (f"Error in retrieval API: { e } " )
318- raise HTTPException (status_code = 500 , detail = "Internal server error" ) from e
157+ return await messages .search_messages (
158+ client = client ,
159+ redis_conn = redis ,
160+ ** payload .model_dump (exclude_none = True ),
161+ )
0 commit comments