Skip to content

Commit dec68de

Browse files
committed
Add MCP server, fix bugs, reorg modules
1 parent c8e1b8b commit dec68de

20 files changed

+919
-695
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ dependencies = [
2626
"transformers>=4.30.0",
2727
"numba>=0.60.0",
2828
"nanoid>=2.0.0",
29+
"mcp>=1.6.0",
2930
]
3031

3132
[project.optional-dependencies]
Lines changed: 69 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -4,17 +4,20 @@
44
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException
55

66
from redis_memory_server.config import settings
7-
from redis_memory_server.extraction import handle_extraction
87
from redis_memory_server.logging import get_logger
9-
from redis_memory_server.long_term_memory import index_messages
10-
from redis_memory_server.models import (
8+
from redis_memory_server.models.extraction import handle_extraction
9+
from redis_memory_server.models.messages import (
1110
AckResponse,
1211
GetSessionsQuery,
1312
MemoryMessage,
1413
MemoryMessagesAndContext,
1514
MemoryResponse,
15+
SearchPayload,
16+
SearchResults,
17+
index_messages,
18+
search_messages,
1619
)
17-
from redis_memory_server.summarization import handle_compaction
20+
from redis_memory_server.models.summarization import handle_compaction
1821
from redis_memory_server.utils import (
1922
Keys,
2023
get_model_client,
@@ -29,11 +32,11 @@
2932

3033

3134
@router.get("/sessions/", response_model=list[str])
32-
async def get_sessions(
35+
async def list_sessions(
3336
pagination: GetSessionsQuery = Depends(),
3437
):
3538
"""
36-
Get a list of session IDs, with optional pagination
39+
Get a list of session IDs, with optional pagination.
3740
3841
Args:
3942
pagination: Pagination parameters (page, size, namespace)
@@ -52,9 +55,7 @@ async def get_sessions(
5255
end = pagination.page * pagination.size - 1
5356

5457
# Set key based on namespace
55-
sessions_key = (
56-
f"sessions:{pagination.namespace}" if pagination.namespace else "sessions"
57-
)
58+
sessions_key = Keys.sessions_key(namespace=pagination.namespace)
5859

5960
try:
6061
# Get session IDs from Redis
@@ -69,26 +70,35 @@ async def get_sessions(
6970

7071

7172
@router.get("/sessions/{session_id}/memory", response_model=MemoryResponse)
72-
async def get_memory(session_id: str):
73+
async def get_session_memory(session_id: str, namespace: str | None = None):
7374
"""
74-
Get memory for a session
75+
Get memory for a session.
76+
77+
This includes stored conversation history and context.
7578
7679
Args:
7780
session_id: The session ID
7881
7982
Returns:
80-
Memory response with messages and context
83+
Conversation history and context
8184
"""
8285
redis = get_redis_conn()
8386

8487
try:
8588
# Define keys
86-
messages_key = Keys.messages_key(session_id)
87-
context_key = Keys.context_key(session_id)
88-
token_count_key = Keys.token_count_key(session_id)
89+
sessions_key = Keys.sessions_key(namespace=namespace)
90+
messages_key = Keys.messages_key(session_id, namespace=namespace)
91+
context_key = Keys.context_key(session_id, namespace=namespace)
92+
token_count_key = Keys.token_count_key(session_id, namespace=namespace)
93+
94+
# TODO: Use a hash
95+
session_exists = await redis.zscore(sessions_key, session_id)
96+
if not session_exists:
97+
raise HTTPException(status_code=404, detail="Session not found")
8998

9099
# Get data from Redis in a pipeline
91100
pipe = redis.pipeline()
101+
# TODO: Make window size configurable via API parameter
92102
pipe.lrange(messages_key, 0, settings.window_size - 1) # Get messages
93103
pipe.mget(context_key, token_count_key) # Get context and token count
94104
results = await pipe.execute()
@@ -147,6 +157,8 @@ async def get_memory(session_id: str):
147157
tokens=tokens,
148158
)
149159

160+
except HTTPException as e:
161+
raise e
150162
except Exception as e:
151163
logger.error(f"Error getting memory for session {session_id}: {e}")
152164
raise HTTPException(status_code=500, detail="Internal server error") from e
@@ -178,20 +190,17 @@ async def post_memory(
178190
context_key = Keys.context_key(session_id)
179191
sessions_key = f"sessions:{namespace}" if namespace else "sessions"
180192

181-
# Check if new context is provided
182193
if memory_messages.context is not None:
183194
await redis.set(context_key, memory_messages.context)
184195

185-
# Add session to sessions set with timestamp
186196
current_time = int(time.time())
187197
await redis.zadd(sessions_key, {session_id: current_time})
188198

189-
# Get model client for extraction
190199
model_client = await get_model_client(settings.generation_model)
191-
192200
messages_json = []
193201

194202
# Process messages for topic/entity extraction
203+
# TODO: Use a distributed background task
195204
for msg in memory_messages.messages:
196205
# Handle extraction in background for each message
197206
msg = await handle_extraction(msg)
@@ -202,7 +211,7 @@ async def post_memory(
202211
messages_json.append(json.dumps(msg_dict))
203212

204213
# Add messages to list
205-
await redis.lpush(messages_key, *messages_json) # type: ignore
214+
await redis.rpush(messages_key, *messages_json) # type: ignore
206215

207216
# Check if window size is exceeded
208217
current_size = await redis.llen(messages_key) # type: ignore
@@ -218,6 +227,7 @@ async def post_memory(
218227
)
219228

220229
# If long-term memory is enabled, index messages
230+
# TODO: Use a distributed background task
221231
if settings.long_term_memory:
222232
embedding_client = await get_openai_client()
223233
background_tasks.add_task(
@@ -226,6 +236,7 @@ async def post_memory(
226236
session_id,
227237
embedding_client,
228238
redis,
239+
namespace,
229240
)
230241

231242
return AckResponse(status="ok")
@@ -267,3 +278,41 @@ async def delete_memory(
267278
except Exception as e:
268279
logger.error(f"Error deleting memory for session {session_id}: {e}")
269280
raise HTTPException(status_code=500, detail="Internal server error") from e
281+
282+
283+
@router.post("/sessions/{session_id}/search", response_model=SearchResults)
284+
async def search_session_messages(
285+
session_id: str,
286+
payload: SearchPayload,
287+
namespace: str | None = None,
288+
):
289+
"""
290+
Run a semantic search on the messages in a session
291+
292+
Args:
293+
session_id: The session ID
294+
payload: Search payload with text to search for
295+
namespace: Optional namespace for the session
296+
297+
Returns:
298+
List of search results
299+
"""
300+
redis = get_redis_conn()
301+
302+
if not settings.long_term_memory:
303+
raise HTTPException(status_code=400, detail="Long term memory is disabled")
304+
305+
# For embeddings, we always use OpenAI models since Anthropic doesn't support embeddings
306+
client = await get_openai_client()
307+
308+
try:
309+
return await search_messages(
310+
payload.text,
311+
client,
312+
redis,
313+
session_id=session_id,
314+
namespace=namespace,
315+
)
316+
except Exception as e:
317+
logger.error(f"Error in retrieval API: {e}")
318+
raise HTTPException(status_code=500, detail="Internal server error") from e

redis_memory_server/healthcheck.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from fastapi import APIRouter
44

5-
from redis_memory_server.models import HealthCheckResponse
5+
from redis_memory_server.models.messages import HealthCheckResponse
66

77

88
router = APIRouter()
Lines changed: 1 addition & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -7,88 +7,12 @@
77
import anthropic
88
import numpy as np
99
from openai import AsyncOpenAI
10-
from pydantic import BaseModel, Field
10+
from pydantic import BaseModel
1111

1212

13-
# Setup logging
1413
logger = logging.getLogger(__name__)
1514

1615

17-
class MemoryMessage(BaseModel):
18-
"""A message in the memory system"""
19-
20-
role: str
21-
content: str
22-
topics: list[str] = Field(
23-
default_factory=list, description="List of topics associated with this message"
24-
)
25-
entities: list[str] = Field(
26-
default_factory=list, description="List of entities mentioned in this message"
27-
)
28-
29-
30-
class MemoryMessagesAndContext(BaseModel):
31-
"""Request payload for adding messages to memory"""
32-
33-
messages: list[MemoryMessage]
34-
context: str | None = None
35-
36-
37-
class MemoryResponse(BaseModel):
38-
"""Response containing messages and context"""
39-
40-
messages: list[MemoryMessage]
41-
context: str | None = None
42-
tokens: int | None = None
43-
44-
45-
class SearchPayload(BaseModel):
46-
"""Payload for semantic search"""
47-
48-
text: str
49-
50-
51-
class HealthCheckResponse(BaseModel):
52-
"""Response for health check endpoint"""
53-
54-
now: int
55-
56-
57-
class AckResponse(BaseModel):
58-
"""Generic acknowledgement response"""
59-
60-
status: str
61-
62-
63-
class RedisearchResult(BaseModel):
64-
"""Result from a redisearch query"""
65-
66-
role: str
67-
content: str
68-
dist: float
69-
70-
71-
class SearchResults(BaseModel):
72-
"""Results from a redisearch query"""
73-
74-
docs: list[RedisearchResult]
75-
total: int
76-
77-
78-
class NamespaceQuery(BaseModel):
79-
"""Query parameters for namespace"""
80-
81-
namespace: str | None = None
82-
83-
84-
class GetSessionsQuery(BaseModel):
85-
"""Query parameters for getting sessions"""
86-
87-
page: int = Field(default=1)
88-
size: int = Field(default=20)
89-
namespace: str | None = None
90-
91-
9216
class ModelProvider(str, Enum):
9317
"""Type of model provider"""
9418

redis_memory_server/main.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
11
import os
22

33
import uvicorn
4-
from fastapi import FastAPI
4+
from fastapi import BackgroundTasks, FastAPI
55

66
from redis_memory_server import utils
7+
from redis_memory_server.api import router as memory_router
78
from redis_memory_server.config import settings
89
from redis_memory_server.healthcheck import router as health_router
10+
from redis_memory_server.llms import MODEL_CONFIGS, ModelProvider
911
from redis_memory_server.logging import configure_logging, get_logger
10-
from redis_memory_server.memory import router as memory_router
11-
from redis_memory_server.models import MODEL_CONFIGS, ModelProvider
12-
from redis_memory_server.retrieval import router as retrieval_router
12+
from redis_memory_server.mcp import mcp_app
1313
from redis_memory_server.utils import ensure_redisearch_index, get_redis_conn
1414

1515

@@ -124,7 +124,19 @@ async def shutdown_event():
124124

125125
app.include_router(health_router)
126126
app.include_router(memory_router)
127-
app.include_router(retrieval_router)
127+
128+
129+
# Set up MCP routes
130+
@app.middleware("http")
131+
async def mcp_middleware(request, call_next):
132+
"""Middleware to inject BackgroundTasks into MCP handler"""
133+
background_tasks = BackgroundTasks()
134+
request.state.background_tasks = background_tasks
135+
return await call_next(request)
136+
137+
138+
# Mount MCP server
139+
app.mount("/mcp", mcp_app.sse_app())
128140

129141

130142
def on_start_logger(port: int):

0 commit comments

Comments (0)