Commit f0c26ce

Simplify reducers

1 parent b5c6936

File tree: 8 files changed (+115, -72)

memory.py (35 additions, 33 deletions)
@@ -13,7 +13,7 @@
     MemoryMessagesAndContext,
     MemoryResponse,
 )
-from reducers import handle_compaction
+from summarization import handle_compaction
 from utils import Keys, get_model_client, get_openai_client, get_redis_conn


@@ -24,7 +24,7 @@

 @router.get("/sessions/", response_model=list[str])
 async def get_sessions(
-    pagination: GetSessionsQuery = Depends(GetSessionsQuery),
+    pagination: GetSessionsQuery = Depends(),
 ):
     """
     Get a list of session IDs, with optional pagination
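Note: `Depends()` with no argument lets FastAPI infer the dependency from the parameter's type annotation, so the class name is not repeated. This pattern is also why `B008` (function call in argument default) is added to the ruff ignore list in pyproject.toml below. A minimal sketch of the pattern, with hypothetical `page`/`size` fields rather than the repo's actual `GetSessionsQuery`:

    from fastapi import Depends, FastAPI
    from pydantic import BaseModel

    app = FastAPI()


    class GetSessionsQuery(BaseModel):
        # Hypothetical pagination fields, for illustration only
        page: int = 1
        size: int = 20


    @app.get("/sessions/")
    async def get_sessions(pagination: GetSessionsQuery = Depends()):
        # FastAPI builds GetSessionsQuery from the annotation and fills it
        # from query parameters, e.g. GET /sessions/?page=2&size=50
        return {"page": pagination.page, "size": pagination.size}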
@@ -55,14 +55,11 @@ async def get_sessions(
         session_ids = await redis.zrange(sessions_key, start, end)

         # Convert from bytes to strings if needed
-        session_ids = [
-            s.decode("utf-8") if isinstance(s, bytes) else s for s in session_ids
-        ]
+        return [s.decode("utf-8") if isinstance(s, bytes) else s for s in session_ids]

-        return session_ids
     except Exception as e:
         logger.error(f"Error getting sessions: {e}")
-        raise HTTPException(status_code=500, detail="Internal server error")
+        raise HTTPException(status_code=500, detail="Internal server error") from e


 @router.get("/sessions/{session_id}/memory", response_model=MemoryResponse)
@@ -72,7 +69,6 @@ async def get_memory(session_id: str):

     Args:
         session_id: The session ID
-        request: FastAPI request

     Returns:
         Memory response with messages and context
@@ -103,8 +99,19 @@
                 msg_raw = msg_raw.decode("utf-8")

             # Parse JSON
-            msg = json.loads(msg_raw)
-            memory_messages.append(MemoryMessage(**msg))
+            msg_dict = json.loads(msg_raw)
+
+            # Convert comma-separated strings back to lists for topics and entities
+            if "topics" in msg_dict:
+                msg_dict["topics"] = (
+                    msg_dict["topics"].split(",") if msg_dict["topics"] else []
+                )
+            if "entities" in msg_dict:
+                msg_dict["entities"] = (
+                    msg_dict["entities"].split(",") if msg_dict["entities"] else []
+                )
+
+            memory_messages.append(MemoryMessage(**msg_dict))

         # Extract context and tokens
         context = None
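Note: this split on read mirrors the `",".join(...)` on write in `post_memory` below; RediSearch TAG fields store list-like values as a single delimited string. A sketch of the round trip (the helper names are mine, not the repo's), which assumes topic and entity strings never contain commas:

    def to_tag_string(values: list[str]) -> str:
        # Encode a list as one comma-delimited TAG value
        return ",".join(values) if values else ""


    def from_tag_string(raw: str) -> list[str]:
        # An empty string must map back to [], not [""]
        return raw.split(",") if raw else []


    assert from_tag_string(to_tag_string(["redis", "memory"])) == ["redis", "memory"]
    assert from_tag_string(to_tag_string([])) == []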
@@ -128,14 +135,15 @@
             tokens = int(tokens_str)

         # Build response
-        response = MemoryResponse(
-            messages=memory_messages, context=context, tokens=tokens
+        return MemoryResponse(
+            messages=memory_messages,
+            context=context,
+            tokens=tokens,
         )

-        return response
     except Exception as e:
         logger.error(f"Error getting memory for session {session_id}: {e}")
-        raise HTTPException(status_code=500, detail="Internal server error")
+        raise HTTPException(status_code=500, detail="Internal server error") from e


 @router.post("/sessions/{session_id}/memory", response_model=AckResponse)
@@ -172,22 +180,22 @@ async def post_memory(
         current_time = int(time.time())
         await redis.zadd(sessions_key, {session_id: current_time})

-        # Add messages to session list
-        # TODO: Don't need a pipeline here, lpush takes multiple values.
-        pipe = redis.pipeline()
+        # Convert messages to JSON, handling topics and entities
+        messages_json = []
         for msg in memory_messages.messages:
-            # Convert to dict and serialize
-            msg_json = json.dumps(msg.model_dump())
-            pipe.lpush(messages_key, msg_json)
+            msg_dict = msg.model_dump()
+            # Convert lists to comma-separated strings for TAG fields
+            msg_dict["topics"] = ",".join(msg.topics) if msg.topics else ""
+            msg_dict["entities"] = ",".join(msg.entities) if msg.entities else ""
+            messages_json.append(json.dumps(msg_dict))

-        # Execute pipeline
-        await pipe.execute()
+        # Add messages to list
+        await redis.lpush(messages_key, *messages_json)  # type: ignore

         # Check if window size is exceeded
         current_size = await redis.llen(messages_key)
         if current_size > settings.window_size:
             # Handle compaction in background
-            # Get the appropriate client for the generation model
             model_client = await get_model_client(settings.generation_model)
             background_tasks.add_task(
                 handle_compaction,
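Note: this resolves the old TODO. LPUSH is variadic, so pushing several values to one key takes a single command and a single round trip, and Redis applies it atomically; the pipeline was pure overhead. A sketch assuming `redis.asyncio` and a reachable local Redis:

    import asyncio
    import json

    from redis.asyncio import Redis


    async def push_messages(redis: Redis, key: str, payloads: list[dict]) -> None:
        # One LPUSH with many values replaces a pipeline of single pushes
        await redis.lpush(key, *(json.dumps(p) for p in payloads))


    # Assumes a local Redis on the default port
    asyncio.run(push_messages(Redis(), "messages:demo", [{"role": "user", "content": "hi"}]))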
@@ -198,26 +206,21 @@
                 redis,
             )

-        # If long-term memory is enabled, index messages.
-        #
-        # TODO: Add support for custom policies around when to index and/or
-        # avoid re-indexing duplicate content.
+        # If long-term memory is enabled, index messages
         if settings.long_term_memory:
-            # For embeddings, we always use OpenAI models since Anthropic doesn't support embeddings
             embedding_client = await get_openai_client()
-
             background_tasks.add_task(
                 index_messages,
                 memory_messages.messages,
                 session_id,
-                embedding_client,  # Explicitly use OpenAI client for embeddings
+                embedding_client,
                 redis,
             )

         return AckResponse(status="ok")
     except Exception as e:
         logger.error(f"Error adding messages for session {session_id}: {e}")
-        raise HTTPException(status_code=500, detail="Internal server error")
+        raise HTTPException(status_code=500, detail="Internal server error") from e


 @router.delete("/sessions/{session_id}/memory", response_model=AckResponse)
@@ -252,5 +255,4 @@ async def delete_memory(
         return AckResponse(status="ok")
     except Exception as e:
         logger.error(f"Error deleting memory for session {session_id}: {e}")
-        raise
-        raise HTTPException(status_code=500, detail="Internal server error")
+        raise HTTPException(status_code=500, detail="Internal server error") from e

models.py (6 additions, 0 deletions)
@@ -18,6 +18,12 @@ class MemoryMessage(BaseModel):

     role: str
     content: str
+    topics: list[str] = Field(
+        default_factory=list, description="List of topics associated with this message"
+    )
+    entities: list[str] = Field(
+        default_factory=list, description="List of entities mentioned in this message"
+    )


 class MemoryMessagesAndContext(BaseModel):
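Note: `Field(default_factory=list)` gives each instance its own fresh list and attaches a description to the generated JSON schema. A quick check of the resulting behavior:

    from pydantic import BaseModel, Field


    class MemoryMessage(BaseModel):
        role: str
        content: str
        topics: list[str] = Field(default_factory=list)
        entities: list[str] = Field(default_factory=list)


    a = MemoryMessage(role="user", content="hi")
    b = MemoryMessage(role="user", content="hello", topics=["greeting"])

    a.topics.append("smalltalk")  # mutating one instance...
    assert b.topics == ["greeting"]  # ...does not leak into another
    assert MemoryMessage(role="user", content="again").topics == []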

pyproject.toml (1 addition, 1 deletion)
@@ -21,7 +21,7 @@ target-version = "py312"
 # Enable various rules
 select = ["E", "F", "B", "I", "N", "UP", "C4", "RET", "SIM", "TID"]
 # Exclude COM812 which conflicts with the formatter
-ignore = ["COM812", "E501"]
+ignore = ["COM812", "E501", "B008"]

 # Allow unused variables when underscore-prefixed
 dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$"

reducers.py renamed to summarization.py (20 additions, 10 deletions)
@@ -1,10 +1,12 @@
1+
import json
12
import logging
23

34
import tiktoken
45
from redis.asyncio import Redis
56

67
from models import (
78
AnthropicClientWrapper,
9+
MemoryMessage,
810
OpenAIClientWrapper,
911
get_model_config,
1012
)
@@ -36,7 +38,7 @@ async def _incremental_summary(
     messages_joined = "\n".join(messages)
     prev_summary = context or ""

-    # Prompt template for progressive summarization (from langchain)
+    # Prompt template for progressive summarization
     progressive_prompt = f"""
 Progressively summarize the lines of conversation provided, adding onto the previous summary returning a new summary. If the lines are meaningless just return NONE

@@ -110,9 +112,20 @@ async def handle_compaction(
     pipe.get(context_key)
     results = await pipe.execute()

-    messages = results[0]
+    messages_raw = results[0]
     context = results[1]

+    # Parse messages
+    messages = []
+    for msg_raw in messages_raw:
+        if isinstance(msg_raw, bytes):
+            msg_raw = msg_raw.decode("utf-8")
+        msg_dict = json.loads(msg_raw)
+        messages.append(MemoryMessage(**msg_dict))
+
+    # Get context string
+    context_str = context.decode("utf-8") if isinstance(context, bytes) else context
+
     # Get model configuration for token limits
     model_config = get_model_config(model)

@@ -124,22 +137,19 @@
     buffer_tokens = 230
     max_message_tokens = max_tokens - summary_max_tokens - buffer_tokens

-    # Initialize encoding (currently uses OpenAI's tokenizer, but could be extended for different models)
+    # Initialize encoding
     encoding = tiktoken.get_encoding("cl100k_base")

     # Check token count of messages
     total_tokens = 0
     messages_to_summarize = []

     for msg in messages:
-        # Decode message if needed
-        if isinstance(msg, bytes):
-            msg = msg.decode("utf-8")
-
-        msg_tokens = len(encoding.encode(msg))
+        msg_str = json.dumps(msg.model_dump())
+        msg_tokens = len(encoding.encode(msg_str))
         if total_tokens + msg_tokens <= max_message_tokens:
             total_tokens += msg_tokens
-            messages_to_summarize.append(msg)
+            messages_to_summarize.append(msg_str)

     # Skip if no messages to summarize
     if not messages_to_summarize:
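Note: since `messages` now holds parsed `MemoryMessage` objects, each is re-serialized before counting, and the loop greedily keeps serialized messages while they still fit under `max_message_tokens`. A self-contained sketch of the same budgeting idea (the 500-token limit is illustrative, not the repo's computed value):

    import json

    import tiktoken

    encoding = tiktoken.get_encoding("cl100k_base")
    max_message_tokens = 500  # illustrative budget

    total_tokens = 0
    messages_to_summarize: list[str] = []
    for msg in [{"role": "user", "content": f"message {i}"} for i in range(100)]:
        msg_str = json.dumps(msg)
        msg_tokens = len(encoding.encode(msg_str))
        # Greedy fill: keep a message only if it still fits the budget
        if total_tokens + msg_tokens <= max_message_tokens:
            total_tokens += msg_tokens
            messages_to_summarize.append(msg_str)

    print(f"kept {len(messages_to_summarize)} messages in {total_tokens} tokens")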
@@ -150,7 +160,7 @@
     summary, _ = await _incremental_summary(
         model,
         client,
-        context.decode("utf-8") if isinstance(context, bytes) else context,
+        context_str,
         messages_to_summarize,
     )

tests/test_api.py (1 addition, 1 deletion)
@@ -9,7 +9,7 @@
     RedisearchResult,
     SearchResults,
 )
-from reducers import handle_compaction
+from summarization import handle_compaction


 @pytest.fixture

tests/test_models.py (27 additions, 1 deletion)
@@ -23,10 +23,36 @@ def test_memory_message(self):
         msg = MemoryMessage(role="user", content="Hello, world!")
         assert msg.role == "user"
         assert msg.content == "Hello, world!"
+        assert msg.topics == []  # Check default empty list
+        assert msg.entities == []  # Check default empty list

         # Test serialization
         data = msg.model_dump()
-        assert data == {"role": "user", "content": "Hello, world!"}
+        assert data == {
+            "role": "user",
+            "content": "Hello, world!",
+            "topics": [],
+            "entities": [],
+        }
+
+        # Test with topics and entities
+        msg_with_metadata = MemoryMessage(
+            role="user",
+            content="Hello, world!",
+            topics=["greeting", "general"],
+            entities=["world"],
+        )
+        assert msg_with_metadata.topics == ["greeting", "general"]
+        assert msg_with_metadata.entities == ["world"]
+
+        # Test serialization with metadata
+        data = msg_with_metadata.model_dump()
+        assert data == {
+            "role": "user",
+            "content": "Hello, world!",
+            "topics": ["greeting", "general"],
+            "entities": ["world"],
+        }

     def test_memory_messages_and_context(self):
         """Test MemoryMessagesAndContext model"""
