11from re import M
2- from unittest .mock import AsyncMock
2+ from unittest .mock import AsyncMock , MagicMock
33import pytest
44import numpy as np
55from redis .commands .search .document import Document
@@ -56,15 +56,18 @@ async def test_index_messages(
5656 async def test_search_messages (self , mock_openai_client , mock_async_redis_client ):
5757 """Test searching messages"""
5858 # Set up the mock embedding response
59- mock_openai_client .create_embedding .return_value = [[0.1 , 0.2 , 0.3 , 0.4 ]]
59+ mock_openai_client .create_embedding .return_value = np .array (
60+ [0.1 , 0.2 , 0.3 , 0.4 ], dtype = np .float32
61+ )
6062
6163 class MockResult :
6264 def __init__ (self , docs ):
6365 self .total = len (docs )
6466 self .docs = docs
6567
66- mock_async_redis_client .ft = AsyncMock ()
67- mock_async_redis_client .ft .search .return_value = MockResult (
68+ # Create a proper mock structure for Redis ft().search()
69+ mock_search = AsyncMock ()
70+ mock_search .return_value = MockResult (
6871 [
6972 Document (
7073 id = b"doc1" ,
@@ -81,6 +84,13 @@ def __init__(self, docs):
8184 ]
8285 )
8386
87+ # Create a mock FT object that has a search method
88+ mock_ft = MagicMock ()
89+ mock_ft .search = mock_search
90+
91+ # Setup the ft method to return our mock_ft object
92+ mock_async_redis_client .ft = MagicMock (return_value = mock_ft )
93+
8494 # Call search_messages
8595 query = "What is the meaning of life?"
8696 session_id = "test-session"
@@ -91,12 +101,17 @@ def __init__(self, docs):
91101 # Check that create_embedding was called with the right arguments
92102 mock_openai_client .create_embedding .assert_called_with ([query ])
93103
94- # Check that redis.execute_command was called with the right arguments
95- mock_async_redis_client .ft .search .assert_called_once ()
96- args = mock_async_redis_client .ft .search .call_args [0 ]
97-
98104 # Check that the index name is correct
99- assert args [1 ] == REDIS_INDEX_NAME
105+ assert mock_async_redis_client .ft .call_count == 1
106+ assert mock_async_redis_client .ft .call_args [0 ][0 ] == REDIS_INDEX_NAME
107+
108+ # Check that search was called with the right arguments
109+ assert mock_ft .search .call_count == 1
110+ args = mock_ft .search .call_args [0 ]
111+ assert (
112+ args [0 ]._query_string
113+ == "@session:{test-session}=>[KNN 10 @vector $vec AS dist]"
114+ )
100115
101116 # Check that the results are parsed correctly
102117 assert len (results .docs ) == 2
@@ -109,6 +124,7 @@ def __init__(self, docs):
109124 assert results .docs [1 ].dist == 0.75
110125
111126
127+ @pytest .mark .requires_api_keys
112128class TestLongTermMemoryIntegration :
113129 @pytest .mark .asyncio
114130 async def test_search_messages (self , memory_messages , async_redis_client ):
@@ -149,7 +165,9 @@ async def test_search_messages_with_distance_threshold(
149165 limit = 2 ,
150166 )
151167
152- assert results .total == 2
168+ assert results .total == 4
169+ assert len (results .docs ) == 2
170+
153171 assert results .docs [0 ].role == "user"
154172 assert results .docs [0 ].content == "What is the capital of France?"
155173 assert results .docs [1 ].role == "assistant"
0 commit comments