Skip to content

Commit 1a9f1b7

Browse files
committed
More tweaks
1 parent 0aba989 commit 1a9f1b7

File tree

11 files changed

+91
-42
lines changed

11 files changed

+91
-42
lines changed

README.md

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ To start the API using Docker Compose, follow these steps:
109109
4. Build and start the containers by running:
110110
docker-compose up --build
111111

112-
5. Once the containers are up, the API will be available at http://localhost:8000. You can also access the interactive API documentation at http://localhost:8000/docs.
112+
5. Once the containers are up, the REST API will be available at http://localhost:8000. You can also access the interactive API documentation at http://localhost:8000/docs. The MCP server will be available at http://localhost:9000/sse.
113113

114114
6. To stop the containers, press Ctrl+C in the terminal and then run:
115115
docker-compose down
@@ -127,11 +127,12 @@ You can configure the service using environment variables:
127127
| `ANTHROPIC_API_KEY` | API key for Anthropic | - |
128128
| `GENERATION_MODEL` | Model for text generation | `gpt-4o-mini` |
129129
| `EMBEDDING_MODEL` | Model for text embeddings | `text-embedding-3-small` |
130-
| `PORT` | Server port | `8000` |
130+
| `PORT` | REST API server port | `8000` |
131131
| `TOPIC_MODEL` | BERTopic model for topic extraction | `MaartenGr/BERTopic_Wikipedia` |
132132
| `NER_MODEL` | BERT model for NER | `dbmdz/bert-large-cased-finetuned-conll03-english` |
133133
| `ENABLE_TOPIC_EXTRACTION` | Enable/disable topic extraction | `True` |
134134
| `ENABLE_NER` | Enable/disable named entity recognition | `True` |
135+
| `MCP_PORT` | MCP server port |9000|
135136

136137

137138
## Development
@@ -145,16 +146,25 @@ pip install -e ".[dev]"
145146

146147
2. Set up environment variables (see Configuration section)
147148

148-
3. Run the server:
149+
3. Run the API server:
149150
```bash
150151
python -m redis_memory_server.main
151152
```
152153

154+
4. In a separate terminal, run the MCP server (use either the "stdio" or "sse" options to set the running mode):
155+
```bash
156+
python -m redis_memory_server.mcp [stdio|sse]
157+
```
158+
153159
### Running Tests
154160
```bash
155161
python -m pytest
156162
```
157163

164+
## Known Issues
165+
- The MCP server from the Python MCP SDK often refuses to shut down with Control-C if it's connected to a client
166+
- All background tasks run as async coroutines in the same process as the REST API server, using Starlette's `BackgroundTask`
167+
158168
### Contributing
159169
1. Fork the repository
160170
2. Create a feature branch

docker-compose.yml

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,29 @@ services:
2828
timeout: 10s
2929
retries: 3
3030

31+
mcp:
32+
build:
33+
context: .
34+
dockerfile: Dockerfile
35+
environment:
36+
- REDIS_URL=redis://redis:6379
37+
- PORT=9000
38+
# Add your API keys here or use a .env file
39+
- OPENAI_API_KEY=${OPENAI_API_KEY}
40+
- ANTHROPIC_API_KEY=${ANTHROPIC_API_KEY}
41+
# Optional configurations with defaults
42+
- LONG_TERM_MEMORY=True
43+
- WINDOW_SIZE=20
44+
- GENERATION_MODEL=gpt-4o-mini
45+
- EMBEDDING_MODEL=text-embedding-3-small
46+
- ENABLE_TOPIC_EXTRACTION=True
47+
- ENABLE_NER=True
48+
ports:
49+
- "9000:9000"
50+
depends_on:
51+
- redis
52+
command: ["python", "-m", "redis_memory_server.mcp", "sse"]
53+
3154
redis:
3255
image: redis/redis-stack:latest
3356
ports:

redis_memory_server/api.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,9 @@ async def delete_session_memory(
134134

135135

136136
@router.post("/long-term-memory", response_model=AckResponse)
137-
async def create_long_term_memory(payload: CreateLongTermMemoryPayload):
137+
async def create_long_term_memory(
138+
payload: CreateLongTermMemoryPayload, background_tasks: BackgroundTasks
139+
):
138140
"""
139141
Create a long-term memory
140142
@@ -152,6 +154,7 @@ async def create_long_term_memory(payload: CreateLongTermMemoryPayload):
152154
await long_term_memory.index_long_term_memories(
153155
redis=redis,
154156
memories=payload.memories,
157+
background_tasks=background_tasks,
155158
)
156159
return AckResponse(status="ok")
157160

redis_memory_server/config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ class Settings(BaseSettings):
1616
generation_model: str = "gpt-4o-mini"
1717
embedding_model: str = "text-embedding-3-small"
1818
port: int = 8000
19+
mcp_port: int = 9000
1920

2021
# Topic and NER model settings
2122
topic_model: str = "MaartenGr/BERTopic_Wikipedia"

redis_memory_server/long_term_memory.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ async def extract_memory_structure(
3636
entities_joined = ",".join(entities) if entities else ""
3737

3838
await redis.hset(
39-
Keys.memory_key(_id, ""),
39+
Keys.memory_key(_id, namespace),
4040
mapping={
4141
"topics": topics_joined,
4242
"entities": entities_joined,
@@ -62,7 +62,6 @@ async def index_long_term_memories(
6262
id_ = memory.id_ if memory.id_ else nanoid.generate()
6363
key = Keys.memory_key(id_, memory.namespace)
6464
vector = embedding.tobytes()
65-
id_ = memory.id_ if memory.id_ else nanoid.generate()
6665

6766
await pipe.hset( # type: ignore
6867
key,
@@ -190,8 +189,12 @@ async def search_long_term_memories(
190189
user_id=doc.user_id,
191190
session_id=doc.session_id,
192191
namespace=doc.namespace,
193-
topics=doc.topics.split(",") if doc.topics else [],
194-
entities=doc.entities.split(",") if doc.entities else [],
192+
topics=doc.get("topics", "").split(",")
193+
if doc.get("topics")
194+
else [],
195+
entities=doc.get("entities", "").split(",")
196+
if doc.get("entities")
197+
else [],
195198
)
196199
)
197200

redis_memory_server/main.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
from redis_memory_server.healthcheck import router as health_router
1111
from redis_memory_server.llms import MODEL_CONFIGS, ModelProvider
1212
from redis_memory_server.logging import configure_logging, get_logger
13-
from redis_memory_server.mcp import mcp_app
1413
from redis_memory_server.utils import ensure_redisearch_index, get_redis_conn
1514

1615

@@ -125,10 +124,6 @@ async def lifespan(app: FastAPI):
125124
app.include_router(memory_router)
126125

127126

128-
# Mount MCP server
129-
app.mount("/", mcp_app.sse_app())
130-
131-
132127
def on_start_logger(port: int):
133128
"""Log startup information"""
134129
print("\n-----------------------------------")

redis_memory_server/mcp.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import logging
2+
import sys
23

34
from fastapi import HTTPException
45
from mcp.server.fastmcp import FastMCP
@@ -10,6 +11,7 @@
1011
get_session_memory as core_get_session_memory,
1112
search_long_term_memory as core_search_long_term_memory,
1213
)
14+
from redis_memory_server.config import settings
1315
from redis_memory_server.models import (
1416
AckResponse,
1517
CreateLongTermMemoryPayload,
@@ -20,7 +22,7 @@
2022

2123

2224
logger = logging.getLogger(__name__)
23-
mcp_app = FastMCP("Redis Agent Memory Server")
25+
mcp_app = FastMCP("Redis Agent Memory Server", port=settings.mcp_port)
2426

2527

2628
@mcp_app.tool()
@@ -146,3 +148,10 @@ async def memory_prompt(
146148
)
147149

148150
return messages
151+
152+
153+
if __name__ == "__main__":
154+
if len(sys.argv) > 1 and sys.argv[1] == "sse":
155+
mcp_app.run(transport="sse")
156+
else:
157+
mcp_app.run(transport="stdio")

redis_memory_server/messages.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import time
44

55
from fastapi import BackgroundTasks
6+
from redis import WatchError
67
from redis.asyncio import Redis
78

89
from redis_memory_server.config import settings
@@ -116,16 +117,25 @@ async def set_session_memory(
116117
messages_key = Keys.messages_key(session_id, namespace=memory.namespace)
117118
metadata_key = Keys.metadata_key(session_id, namespace=memory.namespace)
118119
messages_json = [json.dumps(msg.model_dump()) for msg in memory.messages]
119-
120120
metadata = memory.model_dump(
121121
exclude_none=True,
122122
exclude={"messages"},
123123
)
124124

125-
current_time = int(time.time())
126-
await redis.zadd(sessions_key, {session_id: current_time})
127-
await redis.rpush(messages_key, *messages_json) # type: ignore
128-
await redis.hset(metadata_key, mapping=metadata) # type: ignore
125+
async with redis.pipeline(transaction=True) as pipe:
126+
await pipe.watch(messages_key, metadata_key)
127+
pipe.multi()
128+
129+
while True:
130+
try:
131+
current_time = int(time.time())
132+
pipe.zadd(sessions_key, {session_id: current_time})
133+
pipe.rpush(messages_key, *messages_json) # type: ignore
134+
pipe.hset(metadata_key, mapping=metadata) # type: ignore
135+
await pipe.execute()
136+
except WatchError:
137+
continue
138+
break
129139

130140
# Check if window size is exceeded
131141
current_size = await redis.llen(messages_key) # type: ignore
@@ -155,6 +165,7 @@ async def set_session_memory(
155165
)
156166
for msg in memory.messages
157167
],
168+
background_tasks,
158169
)
159170

160171

tests/conftest.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,7 @@ async def session(use_test_redis_connection, async_redis_client):
118118
namespace="test-namespace",
119119
),
120120
],
121+
background_tasks=BackgroundTasks(),
121122
)
122123

123124
# Add messages to session memory

tests/test_long_term_memory.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import nanoid
66
import numpy as np
77
import pytest
8+
from fastapi import BackgroundTasks
89
from redis.commands.search.document import Document
910

1011
from redis_memory_server.long_term_memory import (
@@ -40,6 +41,7 @@ async def test_index_messages(
4041
await index_long_term_memories(
4142
mock_async_redis_client,
4243
long_term_memories,
44+
background_tasks=BackgroundTasks(),
4345
)
4446

4547
# Check that create_embedding was called with the right arguments
@@ -169,6 +171,7 @@ async def test_search_messages(self, async_redis_client):
169171
await index_long_term_memories(
170172
async_redis_client,
171173
long_term_memories,
174+
background_tasks=BackgroundTasks(),
172175
)
173176

174177
results = await search_long_term_memories(
@@ -194,6 +197,7 @@ async def test_search_messages_with_distance_threshold(self, async_redis_client)
194197
await index_long_term_memories(
195198
async_redis_client,
196199
long_term_memories,
200+
background_tasks=BackgroundTasks(),
197201
)
198202

199203
results = await search_long_term_memories(

0 commit comments

Comments
 (0)