Skip to content

Commit 5e65ad2

Browse files
committed
Tests passing
1 parent f2c4e06 commit 5e65ad2

File tree

7 files changed

+98
-136
lines changed

7 files changed

+98
-136
lines changed

long_term_memory.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ async def search_messages(
9191
.paging(0, limit)
9292
.dialect(2)
9393
)
94-
print(params)
94+
9595
raw_results = await redis_conn.ft(REDIS_INDEX_NAME).search(
9696
q,
9797
query_params=params, # type: ignore

memory.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
AckResponse,
1414
GetSessionsQuery,
1515
)
16-
from reducer import handle_compaction
16+
from reducers import handle_compaction
1717
from long_term_memory import index_messages
1818
from utils import Keys, get_openai_client, get_redis_conn
1919
from config import settings

reducer.py renamed to reducers.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
logger = logging.getLogger(__name__)
88

99

10-
async def incremental_summarization(
10+
async def _incremental_summary(
1111
model: str,
1212
openai_client: OpenAIClientWrapper,
1313
context: Optional[str],
@@ -136,7 +136,7 @@ async def handle_compaction(
136136
return
137137

138138
# Generate new summary
139-
summary, _ = await incremental_summarization(
139+
summary, _ = await _incremental_summary(
140140
model,
141141
openai_client,
142142
context.decode("utf-8") if isinstance(context, bytes) else context,

tests/test_api.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,14 @@
1-
from unittest.mock import AsyncMock, MagicMock, Mock, patch
1+
from unittest.mock import AsyncMock, MagicMock, patch
22

3-
from httpx import ASGITransport, AsyncClient
43
import numpy as np
54
import pytest
6-
from fastapi import BackgroundTasks, FastAPI
75

86
from config import Settings
9-
from healthcheck import router as health_router
10-
from memory import router as memory_router
7+
from long_term_memory import index_messages
118
from models import (
129
RedisearchResult,
1310
)
14-
from reducer import handle_compaction
15-
from long_term_memory import index_messages
11+
from reducers import handle_compaction
1612

1713

1814
@pytest.fixture
@@ -115,7 +111,7 @@ async def test_post_memory(self, client):
115111
data = response.json()
116112
assert "status" in data
117113
assert data["status"] == "ok"
118-
114+
119115
@pytest.mark.requires_api_keys
120116
@pytest.mark.asyncio
121117
async def test_post_memory_stores_in_long_term_memory(self, client):

tests/test_long_term_memory.py

Lines changed: 28 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from re import M
2-
from unittest.mock import AsyncMock
2+
from unittest.mock import AsyncMock, MagicMock
33
import pytest
44
import numpy as np
55
from redis.commands.search.document import Document
@@ -56,15 +56,18 @@ async def test_index_messages(
5656
async def test_search_messages(self, mock_openai_client, mock_async_redis_client):
5757
"""Test searching messages"""
5858
# Set up the mock embedding response
59-
mock_openai_client.create_embedding.return_value = [[0.1, 0.2, 0.3, 0.4]]
59+
mock_openai_client.create_embedding.return_value = np.array(
60+
[0.1, 0.2, 0.3, 0.4], dtype=np.float32
61+
)
6062

6163
class MockResult:
6264
def __init__(self, docs):
6365
self.total = len(docs)
6466
self.docs = docs
6567

66-
mock_async_redis_client.ft = AsyncMock()
67-
mock_async_redis_client.ft.search.return_value = MockResult(
68+
# Create a proper mock structure for Redis ft().search()
69+
mock_search = AsyncMock()
70+
mock_search.return_value = MockResult(
6871
[
6972
Document(
7073
id=b"doc1",
@@ -81,6 +84,13 @@ def __init__(self, docs):
8184
]
8285
)
8386

87+
# Create a mock FT object that has a search method
88+
mock_ft = MagicMock()
89+
mock_ft.search = mock_search
90+
91+
# Setup the ft method to return our mock_ft object
92+
mock_async_redis_client.ft = MagicMock(return_value=mock_ft)
93+
8494
# Call search_messages
8595
query = "What is the meaning of life?"
8696
session_id = "test-session"
@@ -91,12 +101,17 @@ def __init__(self, docs):
91101
# Check that create_embedding was called with the right arguments
92102
mock_openai_client.create_embedding.assert_called_with([query])
93103

94-
# Check that redis.execute_command was called with the right arguments
95-
mock_async_redis_client.ft.search.assert_called_once()
96-
args = mock_async_redis_client.ft.search.call_args[0]
97-
98104
# Check that the index name is correct
99-
assert args[1] == REDIS_INDEX_NAME
105+
assert mock_async_redis_client.ft.call_count == 1
106+
assert mock_async_redis_client.ft.call_args[0][0] == REDIS_INDEX_NAME
107+
108+
# Check that search was called with the right arguments
109+
assert mock_ft.search.call_count == 1
110+
args = mock_ft.search.call_args[0]
111+
assert (
112+
args[0]._query_string
113+
== "@session:{test-session}=>[KNN 10 @vector $vec AS dist]"
114+
)
100115

101116
# Check that the results are parsed correctly
102117
assert len(results.docs) == 2
@@ -109,6 +124,7 @@ def __init__(self, docs):
109124
assert results.docs[1].dist == 0.75
110125

111126

127+
@pytest.mark.requires_api_keys
112128
class TestLongTermMemoryIntegration:
113129
@pytest.mark.asyncio
114130
async def test_search_messages(self, memory_messages, async_redis_client):
@@ -149,7 +165,9 @@ async def test_search_messages_with_distance_threshold(
149165
limit=2,
150166
)
151167

152-
assert results.total == 2
168+
assert results.total == 4
169+
assert len(results.docs) == 2
170+
153171
assert results.docs[0].role == "user"
154172
assert results.docs[0].content == "What is the capital of France?"
155173
assert results.docs[1].role == "assistant"

tests/test_models.py

Lines changed: 8 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import pytest
22
import os
3+
import numpy as np
34
from unittest.mock import patch, MagicMock, AsyncMock
4-
from typing import List
55

66
from models import (
77
MemoryMessage,
@@ -11,7 +11,6 @@
1111
RedisearchResult,
1212
OpenAIClientType,
1313
OpenAIClientWrapper,
14-
parse_redisearch_response,
1514
)
1615

1716

@@ -81,49 +80,6 @@ def test_redisearch_result(self):
8180
assert result.dist == 0.75
8281

8382

84-
class TestRedisearchParsing:
85-
def test_parse_redisearch_empty_response(self):
86-
"""Test parsing an empty redisearch response"""
87-
# Test with empty response
88-
assert parse_redisearch_response([]) == []
89-
assert parse_redisearch_response([0]) == []
90-
91-
def test_parse_redisearch_response(self):
92-
"""Test parsing a redisearch response"""
93-
# Create a mock response similar to what Redis would return
94-
mock_response = [
95-
2, # Number of results
96-
b"doc1", # Document ID
97-
[ # Document fields
98-
b"role",
99-
b"user",
100-
b"content",
101-
b"Hello, world!",
102-
b"dist",
103-
b"0.25",
104-
],
105-
b"doc2", # Document ID
106-
[ # Document fields
107-
b"role",
108-
b"assistant",
109-
b"content",
110-
b"Hi there!",
111-
b"dist",
112-
b"0.75",
113-
],
114-
]
115-
116-
results = parse_redisearch_response(mock_response)
117-
118-
assert len(results) == 2
119-
assert results[0].role == "user"
120-
assert results[0].content == "Hello, world!"
121-
assert results[0].dist == 0.25
122-
assert results[1].role == "assistant"
123-
assert results[1].content == "Hi there!"
124-
assert results[1].dist == 0.75
125-
126-
12783
@pytest.mark.asyncio
12884
class TestOpenAIClientWrapper:
12985
@patch.dict(
@@ -171,8 +127,13 @@ async def test_create_embedding(self, mock_init):
171127

172128
# Verify embeddings were created correctly
173129
assert len(embeddings) == 2
174-
assert embeddings[0] == [0.1, 0.2, 0.3]
175-
assert embeddings[1] == [0.4, 0.5, 0.6]
130+
# Convert NumPy array to list or use np.array_equal for comparison
131+
assert np.array_equal(
132+
embeddings[0], np.array([0.1, 0.2, 0.3], dtype=np.float32)
133+
)
134+
assert np.array_equal(
135+
embeddings[1], np.array([0.4, 0.5, 0.6], dtype=np.float32)
136+
)
176137

177138
# Verify the client was called with correct parameters
178139
client.embedding_client.embeddings.create.assert_called_with(

0 commit comments

Comments
 (0)