
Commit d209a17

Move into package

1 parent f0c26ce · commit d209a17

20 files changed, +637 -70 lines

pyproject.toml

Lines changed: 54 additions & 0 deletions
@@ -1,3 +1,54 @@
+[build-system]
+requires = ["hatchling"]
+build-backend = "hatchling.build"
+
+[project]
+name = "redis-memory-server"
+version = "0.1.0"
+description = "A memory system for conversational AI using Redis"
+readme = "README.md"
+requires-python = ">=3.10"
+license = "MIT"
+authors = [
+    { name = "Andrew Brookins", email = "[email protected]" },
+]
+dependencies = [
+    "fastapi>=0.109.0",
+    "uvicorn>=0.27.0",
+    "redis>=5.0.1",
+    "openai>=1.0.0",
+    "anthropic>=0.18.1",
+    "python-dotenv>=1.0.0",
+    "structlog>=24.1.0",
+    "tiktoken>=0.5.2",
+    "numpy>=1.26.0",
+]
+
+[project.optional-dependencies]
+dev = [
+    "pytest>=7.4.0",
+    "pytest-asyncio>=0.23.0",
+    "pytest-xdist>=3.5.0",
+    "black>=24.1.0",
+    "ruff>=0.2.0",
+    "testcontainers>=3.7.0",
+    "pre-commit>=3.6.0",
+]
+
+[tool.hatch.build.targets.wheel]
+packages = ["redis_memory_server"]
+
+[tool.pytest.ini_options]
+addopts = "-v"
+testpaths = ["tests"]
+python_files = ["test_*.py"]
+asyncio_mode = "auto"
+
+[tool.black]
+line-length = 88
+target-version = ['py310']
+include = '\.pyi?$'
+
 [tool.ruff]
 # Exclude a variety of commonly ignored directories
 exclude = [
@@ -49,3 +100,6 @@ ban-relative-imports = "all"
 quote-style = "double"
 # Use spaces for indentation
 indent-style = "space"
+
+[tool.ruff.lint.per-file-ignores]
+"__init__.py" = ["F401"]

redis_memory_server/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
+"""Redis Memory Server - A memory system for conversational AI."""
+
+__version__ = "0.1.0"

config.py renamed to redis_memory_server/config.py

Lines changed: 4 additions & 0 deletions
@@ -1,8 +1,12 @@
 import os
 
+from dotenv import load_dotenv
 from pydantic_settings import BaseSettings
 
 
+load_dotenv()
+
+
 class Settings(BaseSettings):
     redis_url: str = "redis://localhost:6379"
     long_term_memory: bool = True
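
Note: moving load_dotenv() into config.py means values from a local .env file are already exported to the process environment by the time pydantic-settings constructs Settings, which is why the separate load_dotenv() call disappears from main.py further down. A rough sketch of why the ordering matters; DemoSettings is a hypothetical stand-in that mirrors the Settings class above:

    from dotenv import load_dotenv
    from pydantic_settings import BaseSettings

    load_dotenv()  # export values from .env (e.g. REDIS_URL) into os.environ


    class DemoSettings(BaseSettings):  # hypothetical, mirrors Settings
        redis_url: str = "redis://localhost:6379"


    # Environment variables exported above take precedence over the default.
    print(DemoSettings().redis_url)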

redis_memory_server/extraction.py

Lines changed: 104 additions & 0 deletions
@@ -0,0 +1,104 @@
+import json
+import logging
+
+from redis_memory_server.models import (
+    AnthropicClientWrapper,
+    MemoryMessage,
+    OpenAIClientWrapper,
+)
+
+
+logger = logging.getLogger(__name__)
+
+EXTRACTION_PROMPT = """Analyze the following message and extract:
+1. Key topics (as single words or short phrases)
+2. Named entities (people, places, organizations, etc.)
+
+Message: {message}
+
+Respond in JSON format:
+{{
+    "topics": ["topic1", "topic2", ...],
+    "entities": ["entity1", "entity2", ...]
+}}
+
+Keep topics and entities concise and relevant."""
+
+
+async def extract_topics_and_entities(
+    message: str,
+    model_client: OpenAIClientWrapper | AnthropicClientWrapper,
+) -> tuple[list[str], list[str]]:
+    """
+    Extract topics and entities from a message using the LLM.
+
+    Args:
+        message: The message to analyze
+        model_client: The LLM client to use
+
+    Returns:
+        Tuple of (topics, entities) lists
+    """
+    try:
+        # Get LLM response
+        response = await model_client.create_chat_completion(
+            "gpt-4o-mini",  # TODO: Make configurable
+            EXTRACTION_PROMPT.format(message=message),
+        )
+
+        # Parse JSON response from content field
+        content = response.choices[0]["message"]["content"].strip()
+        result = json.loads(content)
+
+        # Extract and validate topics and entities
+        topics = result.get("topics", [])
+        entities = result.get("entities", [])
+
+        # Ensure we have lists
+        if not isinstance(topics, list) or not isinstance(entities, list):
+            logger.error("Invalid extraction response format")
+            return [], []
+
+        return topics, entities
+
+    except Exception as e:
+        logger.error(f"Error in topic/entity extraction: {e}")
+        return [], []
+
+
+async def handle_extraction(
+    message: MemoryMessage,
+    model_client: OpenAIClientWrapper | AnthropicClientWrapper,
+) -> MemoryMessage:
+    """
+    Handle topic and entity extraction for a message.
+
+    Args:
+        message: The message to process
+        model_client: The LLM client to use
+
+    Returns:
+        Updated message with extracted topics and entities
+    """
+    # Skip if message already has both topics and entities
+    if message.topics and message.entities:
+        return message
+
+    # Extract topics and entities
+    extracted_topics, extracted_entities = await extract_topics_and_entities(
+        message.content, model_client
+    )
+
+    # Merge with existing topics and entities
+    message.topics = (
+        list(set(message.topics + extracted_topics))
+        if message.topics
+        else extracted_topics
+    )
+    message.entities = (
+        list(set(message.entities + extracted_entities))
+        if message.entities
+        else extracted_entities
+    )
+
+    return message
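
Note: a minimal usage sketch of the new module (not part of the commit). It obtains a client the same way memory.py does below, and it assumes MemoryMessage accepts role and content keyword arguments and that settings.generation_model resolves to a usable model:

    import asyncio

    from redis_memory_server.config import settings
    from redis_memory_server.extraction import handle_extraction
    from redis_memory_server.models import MemoryMessage
    from redis_memory_server.utils import get_model_client


    async def main() -> None:
        client = await get_model_client(settings.generation_model)
        # Field names here are assumed from how this diff reads messages.
        msg = MemoryMessage(role="user", content="I moved to Portland last week.")
        enriched = await handle_extraction(msg, client)
        print(enriched.topics, enriched.entities)


    asyncio.run(main())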
healthcheck.py renamed to redis_memory_server/healthcheck.py

Lines changed: 1 addition & 1 deletion
@@ -2,7 +2,7 @@
 
 from fastapi import APIRouter
 
-from models import HealthCheckResponse
+from redis_memory_server.models import HealthCheckResponse
 
 
 router = APIRouter()
long_term_memory.py renamed to redis_memory_server/long_term_memory.py

Lines changed: 2 additions & 2 deletions
@@ -4,13 +4,13 @@
 from redis.asyncio import Redis
 from redis.commands.search.query import Query
 
-from models import (
+from redis_memory_server.models import (
     MemoryMessage,
     OpenAIClientWrapper,
     RedisearchResult,
     SearchResults,
 )
-from utils import REDIS_INDEX_NAME, Keys
+from redis_memory_server.utils import REDIS_INDEX_NAME, Keys
 
 
 logger = logging.getLogger(__name__)

main.py renamed to redis_memory_server/main.py

Lines changed: 7 additions & 12 deletions
@@ -2,20 +2,15 @@
 
 import structlog
 import uvicorn
-from dotenv import load_dotenv
 from fastapi import FastAPI
 
-import utils
-from models import MODEL_CONFIGS, ModelProvider
-
-
-load_dotenv()
-
-from config import settings
-from healthcheck import router as health_router
-from memory import router as memory_router
-from retrieval import router as retrieval_router
-from utils import ensure_redisearch_index, get_redis_conn
+from redis_memory_server import utils
+from redis_memory_server.config import settings
+from redis_memory_server.healthcheck import router as health_router
+from redis_memory_server.memory import router as memory_router
+from redis_memory_server.models import MODEL_CONFIGS, ModelProvider
+from redis_memory_server.retrieval import router as retrieval_router
+from redis_memory_server.utils import ensure_redisearch_index, get_redis_conn
 
 
 # Configure logging

memory.py renamed to redis_memory_server/memory.py

Lines changed: 27 additions & 8 deletions
@@ -4,17 +4,23 @@
 
 from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException
 
-from config import settings
-from long_term_memory import index_messages
-from models import (
+from redis_memory_server.config import settings
+from redis_memory_server.extraction import handle_extraction
+from redis_memory_server.long_term_memory import index_messages
+from redis_memory_server.models import (
     AckResponse,
     GetSessionsQuery,
     MemoryMessage,
     MemoryMessagesAndContext,
     MemoryResponse,
 )
-from summarization import handle_compaction
-from utils import Keys, get_model_client, get_openai_client, get_redis_conn
+from redis_memory_server.summarization import handle_compaction
+from redis_memory_server.utils import (
+    Keys,
+    get_model_client,
+    get_openai_client,
+    get_redis_conn,
+)
 
 
 logger = logging.getLogger(__name__)
@@ -180,9 +186,23 @@ async def post_memory(
     current_time = int(time.time())
     await redis.zadd(sessions_key, {session_id: current_time})
 
+    # Get model client for extraction
+    model_client = await get_model_client(settings.generation_model)
+
+    # Process messages for topic/entity extraction
+    processed_messages = []
+    for msg in memory_messages.messages:
+        # Handle extraction in background for each message
+        background_tasks.add_task(
+            handle_extraction,
+            msg,
+            model_client,
+        )
+        processed_messages.append(msg)
+
     # Convert messages to JSON, handling topics and entities
     messages_json = []
-    for msg in memory_messages.messages:
+    for msg in processed_messages:
         msg_dict = msg.model_dump()
         # Convert lists to comma-separated strings for TAG fields
         msg_dict["topics"] = ",".join(msg.topics) if msg.topics else ""
@@ -193,10 +213,9 @@ async def post_memory(
     await redis.lpush(messages_key, *messages_json)  # type: ignore
 
     # Check if window size is exceeded
-    current_size = await redis.llen(messages_key)
+    current_size = await redis.llen(messages_key)  # type: ignore
     if current_size > settings.window_size:
         # Handle compaction in background
-        model_client = await get_model_client(settings.generation_model)
         background_tasks.add_task(
             handle_compaction,
             session_id,
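
Note: FastAPI's BackgroundTasks runs each queued callable (sync or async) only after the response has been sent, so the per-message extraction scheduled above does not delay the POST reply; it also means the messages_json written in the same request still reflects the messages as submitted, and any topics or entities added by extraction land later. A generic sketch of the pattern, with hypothetical names:

    from fastapi import BackgroundTasks, FastAPI

    app = FastAPI()


    async def annotate(text: str) -> None:
        # Hypothetical stand-in for handle_extraction; runs after the response.
        print(f"annotating: {text!r}")


    @app.post("/notes")
    async def create_note(text: str, background_tasks: BackgroundTasks) -> dict:
        background_tasks.add_task(annotate, text)
        return {"status": "ok"}  # sent to the client before annotate() executes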
File renamed without changes.

retrieval.py renamed to redis_memory_server/retrieval.py

Lines changed: 6 additions & 7 deletions
@@ -2,10 +2,10 @@
 
 from fastapi import APIRouter, HTTPException
 
-from config import settings
-from long_term_memory import search_messages
-from models import SearchPayload, SearchResults
-from utils import get_openai_client, get_redis_conn
+from redis_memory_server.config import settings
+from redis_memory_server.long_term_memory import search_messages
+from redis_memory_server.models import SearchPayload, SearchResults
+from redis_memory_server.utils import get_openai_client, get_redis_conn
 
 
 logger = logging.getLogger(__name__)
@@ -37,8 +37,7 @@ async def run_retrieval(
     client = await get_openai_client()
 
     try:
-        results = await search_messages(payload.text, session_id, client, redis)
-        return results
+        return await search_messages(payload.text, session_id, client, redis)
     except Exception as e:
         logger.error(f"Error in retrieval API: {e}")
-        raise HTTPException(status_code=500, detail="Internal server error")
+        raise HTTPException(status_code=500, detail="Internal server error") from e
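
Note: the added "from e" chains the original exception as __cause__, so the underlying failure from search_messages still shows up in the traceback instead of being hidden behind the generic 500 (the pattern linters such as ruff's B904 rule push toward). A small self-contained illustration:

    import json


    def parse_payload(payload: str) -> dict:
        try:
            return json.loads(payload)
        except ValueError as e:
            # "from e" records the JSON decode error as the new exception's
            # __cause__, so both appear in the traceback.
            raise RuntimeError("could not parse payload") from e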
