diff --git a/agent-memory-client/agent_memory_client/__init__.py b/agent-memory-client/agent_memory_client/__init__.py index 7647c8b..afd6729 100644 --- a/agent-memory-client/agent_memory_client/__init__.py +++ b/agent-memory-client/agent_memory_client/__init__.py @@ -5,7 +5,7 @@ memory management capabilities for AI agents and applications. """ -__version__ = "0.9.2" +__version__ = "0.10.0" from .client import MemoryAPIClient, MemoryClientConfig, create_memory_client from .exceptions import ( diff --git a/agent-memory-client/agent_memory_client/client.py b/agent-memory-client/agent_memory_client/client.py index 2eb3ca6..6d58ba6 100644 --- a/agent-memory-client/agent_memory_client/client.py +++ b/agent-memory-client/agent_memory_client/client.py @@ -574,12 +574,13 @@ async def search_long_term_memory( memory_type: MemoryType | dict[str, Any] | None = None, limit: int = 10, offset: int = 0, + optimize_query: bool = True, ) -> MemoryRecordResults: """ Search long-term memories using semantic search and filters. Args: - text: Search query text for semantic similarity + text: Query for vector search - will be used for semantic similarity matching session_id: Optional session ID filter namespace: Optional namespace filter topics: Optional topics filter @@ -591,6 +592,7 @@ async def search_long_term_memory( memory_type: Optional memory type filter limit: Maximum number of results to return (default: 10) offset: Offset for pagination (default: 0) + optimize_query: Whether to optimize the query for vector search using a fast model (default: True) Returns: MemoryRecordResults with matching memories and metadata @@ -669,10 +671,14 @@ async def search_long_term_memory( if distance_threshold is not None: payload["distance_threshold"] = distance_threshold + # Add optimize_query as query parameter + params = {"optimize_query": str(optimize_query).lower()} + try: response = await self._client.post( "/v1/long-term-memory/search", json=payload, + params=params, ) response.raise_for_status() return MemoryRecordResults(**response.json()) @@ -691,6 +697,7 @@ async def search_memory_tool( max_results: int = 5, min_relevance: float | None = None, user_id: str | None = None, + optimize_query: bool = False, ) -> dict[str, Any]: """ Simplified long-term memory search designed for LLM tool use. @@ -701,13 +708,14 @@ async def search_memory_tool( searches long-term memory, not working memory. Args: - query: The search query text + query: The query for vector search topics: Optional list of topic strings to filter by entities: Optional list of entity strings to filter by memory_type: Optional memory type ("episodic", "semantic", "message") max_results: Maximum results to return (default: 5) min_relevance: Optional minimum relevance score (0.0-1.0) user_id: Optional user ID to filter memories by + optimize_query: Whether to optimize the query for vector search (default: False - LLMs typically provide already optimized queries) Returns: Dict with 'memories' list and 'summary' for LLM consumption @@ -759,6 +767,7 @@ async def search_memory_tool( distance_threshold=distance_threshold, limit=max_results, user_id=user_id_filter, + optimize_query=optimize_query, ) # Format for LLM consumption @@ -828,13 +837,13 @@ async def handle_tool_calls(client, tool_calls): "type": "function", "function": { "name": "search_memory", - "description": "Search long-term memory for relevant information based on a query. Use this when you need to recall past conversations, user preferences, or previously stored information. 
Note: This searches only long-term memory, not current working memory.", + "description": "Search long-term memory for relevant information using a query for vector search. Use this when you need to recall past conversations, user preferences, or previously stored information. Note: This searches only long-term memory, not current working memory.", "parameters": { "type": "object", "properties": { "query": { "type": "string", - "description": "The search query describing what information you're looking for", + "description": "The query for vector search describing what information you're looking for", }, "topics": { "type": "array", @@ -868,6 +877,11 @@ async def handle_tool_calls(client, tool_calls): "type": "string", "description": "Optional user ID to filter memories by (e.g., 'user123')", }, + "optimize_query": { + "type": "boolean", + "default": False, + "description": "Whether to optimize the query for vector search (default: False - LLMs typically provide already optimized queries)", + }, }, "required": ["query"], }, @@ -2138,6 +2152,7 @@ async def memory_prompt( context_window_max: int | None = None, long_term_search: dict[str, Any] | None = None, user_id: str | None = None, + optimize_query: bool = True, ) -> dict[str, Any]: """ Hydrate a user query with memory context and return a prompt ready to send to an LLM. @@ -2145,13 +2160,14 @@ async def memory_prompt( NOTE: `long_term_search` uses the same filter options as `search_long_term_memories`. Args: - query: The input text to find relevant context for + query: The query for vector search to find relevant context for session_id: Optional session ID to include session messages namespace: Optional namespace for the session model_name: Optional model name to determine context window size context_window_max: Optional direct specification of context window tokens long_term_search: Optional search parameters for long-term memory user_id: Optional user ID for the session + optimize_query: Whether to optimize the query for vector search using a fast model (default: True) Returns: Dict with messages hydrated with relevant memory context @@ -2208,10 +2224,14 @@ async def memory_prompt( } payload["long_term_search"] = long_term_search + # Add optimize_query as query parameter + params = {"optimize_query": str(optimize_query).lower()} + try: response = await self._client.post( "/v1/memory/prompt", json=payload, + params=params, ) response.raise_for_status() result = response.json() @@ -2235,6 +2255,7 @@ async def hydrate_memory_prompt( distance_threshold: float | None = None, memory_type: dict[str, Any] | None = None, limit: int = 10, + optimize_query: bool = True, ) -> dict[str, Any]: """ Hydrate a user query with long-term memory context using filters. @@ -2243,7 +2264,7 @@ async def hydrate_memory_prompt( long-term memory search with the specified filters. 
Args:
-        query: The input text to find relevant context for
+        query: The query for vector search to find relevant context for
         session_id: Optional session ID filter (as dict)
         namespace: Optional namespace filter (as dict)
         topics: Optional topics filter (as dict)
@@ -2254,6 +2275,7 @@ async def hydrate_memory_prompt(
         distance_threshold: Optional distance threshold
         memory_type: Optional memory type filter (as dict)
         limit: Maximum number of long-term memories to include
+        optimize_query: Whether to optimize the query for vector search using a fast model (default: True)

     Returns:
         Dict with messages hydrated with relevant long-term memories
@@ -2285,6 +2307,7 @@ async def hydrate_memory_prompt(
         return await self.memory_prompt(
             query=query,
             long_term_search=long_term_search,
+            optimize_query=optimize_query,
         )

     def _deep_merge_dicts(
diff --git a/agent_memory_server/api.py b/agent_memory_server/api.py
index a16efad..a0c454e 100644
--- a/agent_memory_server/api.py
+++ b/agent_memory_server/api.py
@@ -494,6 +494,7 @@ async def create_long_term_memory(
 @router.post("/v1/long-term-memory/search", response_model=MemoryRecordResultsResponse)
 async def search_long_term_memory(
     payload: SearchRequest,
+    optimize_query: bool = True,
     current_user: UserInfo = Depends(get_current_user),
 ):
     """
@@ -501,6 +502,7 @@ async def search_long_term_memory(

     Args:
         payload: Search payload with filter objects for precise queries
+        optimize_query: Whether to optimize the query for vector search using a fast model (default: True)

     Returns:
         List of search results
@@ -517,6 +519,7 @@ async def search_long_term_memory(
         "distance_threshold": payload.distance_threshold,
         "limit": payload.limit,
         "offset": payload.offset,
+        "optimize_query": optimize_query,
         **filters,
     }

@@ -549,13 +552,14 @@ async def delete_long_term_memory(
 @router.post("/v1/memory/prompt", response_model=MemoryPromptResponse)
 async def memory_prompt(
     params: MemoryPromptRequest,
+    optimize_query: bool = True,
     current_user: UserInfo = Depends(get_current_user),
 ) -> MemoryPromptResponse:
     """
     Hydrate a user query with memory context and return a prompt ready to send to an LLM.

-    `query` is the input text that the caller of this API wants to use to find
+    `query` is the text that the caller of this API wants to use, via vector search, to find
     relevant context. If `session_id` is provided and matches an existing
     session, the resulting prompt will include those messages as the immediate
     history of messages leading to a message containing `query`.
@@ -566,6 +570,7 @@ async def memory_prompt( Args: params: MemoryPromptRequest + optimize_query: Whether to optimize the query for vector search using a fast model (default: True) Returns: List of messages to send to an LLM, hydrated with relevant memory context @@ -671,6 +676,7 @@ async def memory_prompt( logger.debug(f"[memory_prompt] Search payload: {search_payload}") long_term_memories = await search_long_term_memory( search_payload, + optimize_query=optimize_query, ) logger.debug(f"[memory_prompt] Long-term memories: {long_term_memories}") diff --git a/agent_memory_server/config.py b/agent_memory_server/config.py index 35bba92..b9f9e50 100644 --- a/agent_memory_server/config.py +++ b/agent_memory_server/config.py @@ -56,6 +56,12 @@ class Settings(BaseSettings): anthropic_api_base: str | None = None generation_model: str = "gpt-4o" embedding_model: str = "text-embedding-3-small" + + # Model selection for query optimization + slow_model: str = "gpt-4o" # Slower, more capable model for complex tasks + fast_model: str = ( + "gpt-4o-mini" # Faster, smaller model for quick tasks like query optimization + ) port: int = 8000 mcp_port: int = 9000 @@ -124,6 +130,21 @@ class Settings(BaseSettings): 0.7 # Fraction of context window that triggers summarization ) + # Query optimization settings + query_optimization_prompt_template: str = """Transform this natural language query into an optimized version for semantic search. The goal is to make it more effective for finding semantically similar content while preserving the original intent. + +Guidelines: +- Keep the core meaning and intent +- Use more specific and descriptive terms +- Remove unnecessary words like "tell me", "I want to know", "can you" +- Focus on the key concepts and topics +- Make it concise but comprehensive + +Original query: {query} + +Optimized query:""" + min_optimized_query_length: int = 2 + # Other Application settings log_level: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] = "INFO" default_mcp_user_id: str | None = None diff --git a/agent_memory_server/llms.py b/agent_memory_server/llms.py index 18537a7..8653026 100644 --- a/agent_memory_server/llms.py +++ b/agent_memory_server/llms.py @@ -423,3 +423,72 @@ async def get_model_client( raise ValueError(f"Unsupported model provider: {model_config.provider}") return _model_clients[model_name] + + +async def optimize_query_for_vector_search( + query: str, + model_name: str | None = None, +) -> str: + """ + Optimize a user query for vector search using a fast model. + + This function takes a natural language query and rewrites it to be more effective + for semantic similarity search. It uses a fast, small model to improve search + performance while maintaining query intent. 
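+
+    Example (illustrative only; actual output depends on the configured fast model):
+        "Can you tell me about my UI preferences for dark mode?"
+        -> "user interface preferences dark mode"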
+ + Args: + query: The original user query to optimize + model_name: Model to use for optimization (defaults to settings.fast_model) + + Returns: + Optimized query string better suited for vector search + """ + if not query or not query.strip(): + return query + + # Use fast model from settings if not specified + effective_model = model_name or settings.fast_model + + # Create optimization prompt from config template + optimization_prompt = settings.query_optimization_prompt_template.format( + query=query + ) + + try: + client = await get_model_client(effective_model) + + response = await client.create_chat_completion( + model=effective_model, + prompt=optimization_prompt, + ) + + if ( + hasattr(response, "choices") + and response.choices + and len(response.choices) > 0 + ): + optimized = "" + if hasattr(response.choices[0], "message"): + optimized = response.choices[0].message.content + elif hasattr(response.choices[0], "text"): + optimized = response.choices[0].text + else: + optimized = str(response.choices[0]) + + # Clean up the response + optimized = optimized.strip() + + # Fallback to original if optimization failed + if not optimized or len(optimized) < settings.min_optimized_query_length: + logger.warning(f"Query optimization failed for: {query}") + return query + + logger.debug(f"Optimized query: '{query}' -> '{optimized}'") + return optimized + + except Exception as e: + logger.warning(f"Failed to optimize query '{query}': {e}") + # Return original query if optimization fails + return query + + return query diff --git a/agent_memory_server/long_term_memory.py b/agent_memory_server/long_term_memory.py index 1f60144..66d6aec 100644 --- a/agent_memory_server/long_term_memory.py +++ b/agent_memory_server/long_term_memory.py @@ -28,6 +28,7 @@ AnthropicClientWrapper, OpenAIClientWrapper, get_model_client, + optimize_query_for_vector_search, ) from agent_memory_server.models import ( ExtractedMemoryRecord, @@ -718,13 +719,13 @@ async def search_long_term_memories( memory_hash: MemoryHash | None = None, limit: int = 10, offset: int = 0, + optimize_query: bool = True, ) -> MemoryRecordResults: """ Search for long-term memories using the pluggable VectorStore adapter. 
     Args:
-        text: Search query text
-        redis: Redis client (kept for compatibility but may be unused depending on backend)
+        text: Query for vector search - will be used for semantic similarity matching
         session_id: Optional session ID filter
         user_id: Optional user ID filter
         namespace: Optional namespace filter
@@ -738,16 +739,22 @@
         memory_hash: Optional memory hash filter
         limit: Maximum number of results
         offset: Offset for pagination
+        optimize_query: Whether to optimize the query for vector search using a fast model (default: True)

     Returns:
         MemoryRecordResults containing matching memories
     """
+    # Optimize query for vector search if requested
+    search_query = text
+    if optimize_query and text:
+        search_query = await optimize_query_for_vector_search(text)
+
     # Get the VectorStore adapter
     adapter = await get_vectorstore_adapter()

     # Delegate search to the adapter
     return await adapter.search_memories(
-        query=text,
+        query=search_query,
         session_id=session_id,
         user_id=user_id,
         namespace=namespace,
diff --git a/agent_memory_server/mcp.py b/agent_memory_server/mcp.py
index c5fc264..18a50f7 100644
--- a/agent_memory_server/mcp.py
+++ b/agent_memory_server/mcp.py
@@ -330,13 +330,14 @@ async def search_long_term_memory(
     distance_threshold: float | None = None,
     limit: int = 10,
     offset: int = 0,
+    optimize_query: bool = False,
 ) -> MemoryRecordResults:
     """
-    Search for memories related to a text query.
+    Search for memories related to a query using vector search.

     Finds memories based on a combination of semantic similarity and input filters.

-    This tool performs a semantic search on stored memories using the query text and filters
+    This tool performs a semantic search on stored memories using the vector-search query and filters
     in the payload. Results are ranked by relevance.

     DATETIME INPUT FORMAT:
@@ -413,7 +414,7 @@
     ```

     Args:
-        text: The semantic search query text (required). Use empty string "" to get all memories for a user.
+        text: The query for vector search (required). Use empty string "" to get all memories for a user.
         session_id: Filter by session ID
         namespace: Filter by namespace
         topics: Filter by topics
@@ -425,6 +426,7 @@
         distance_threshold: Distance threshold for semantic search
         limit: Maximum number of results
         offset: Offset for pagination
+        optimize_query: Whether to optimize the query for vector search (default: False - LLMs typically provide already optimized queries)

     Returns:
         MemoryRecordResults containing matched memories sorted by relevance
@@ -449,7 +451,9 @@
         limit=limit,
         offset=offset,
     )
-    results = await core_search_long_term_memory(payload)
+    results = await core_search_long_term_memory(
+        payload, optimize_query=optimize_query
+    )
     results = MemoryRecordResults(
         total=results.total,
         memories=results.memories,
@@ -485,18 +489,19 @@ async def memory_prompt(
     distance_threshold: float | None = None,
     limit: int = 10,
     offset: int = 0,
+    optimize_query: bool = False,
 ) -> MemoryPromptResponse:
     """
-    Hydrate a user query with relevant session history and long-term memories.
+    Hydrate a user query with relevant session history and long-term memories retrieved by vector search.

-    This tool enriches the user's query by retrieving:
+    This tool enriches the query by retrieving:
     1. Context from the current conversation session
     2. Relevant long-term memories related to the query

     The tool returns both the relevant memories AND the user's query in a format
     ready for generating comprehensive responses.

-    The function uses the query field from the payload as the user's query,
+    The function uses the query field as the vector-search query,
     and any filters to retrieve relevant memories.

     DATETIME INPUT FORMAT:
@@ -561,7 +566,7 @@
     ```

     Args:
-        - query: The user's query
+        - query: The query for vector search
        - session_id: Add conversation history from a working memory session
        - namespace: Filter session and long-term memory namespace
        - topics: Search for long-term memories matching topics
@@ -572,6 +577,7 @@
        - distance_threshold: Distance threshold for semantic search
        - limit: Maximum number of long-term memory results
        - offset: Offset for pagination of long-term memory results
+       - optimize_query: Whether to optimize the query for vector search (default: False - LLMs typically provide already optimized queries)

     Returns:
        A list of messages, including memory context and the user's query
@@ -611,7 +617,10 @@
     if search_payload is not None:
         _params["long_term_search"] = search_payload

-    return await core_memory_prompt(params=MemoryPromptRequest(query=query, **_params))
+    return await core_memory_prompt(
+        params=MemoryPromptRequest(query=query, **_params),
+        optimize_query=optimize_query,
+    )


 @mcp_app.tool()
diff --git a/tests/test_api.py b/tests/test_api.py
index f7fb129..e7dabae 100644
--- a/tests/test_api.py
+++ b/tests/test_api.py
@@ -361,15 +361,92 @@ async def test_search(self, mock_search, client):
         assert data["total"] == 2
         assert len(data["memories"]) == 2

-        # Check first result
-        assert data["memories"][0]["id"] == "1"
-        assert data["memories"][0]["text"] == "User: Hello, world!"
-        assert data["memories"][0]["dist"] == 0.25
+    @patch("agent_memory_server.api.long_term_memory.search_long_term_memories")
+    @pytest.mark.asyncio
+    async def test_search_with_optimize_query_true(self, mock_search, client):
+        """Test search endpoint with optimize_query=True (default)."""
+        mock_search.return_value = MemoryRecordResultsResponse(
+            total=1,
+            memories=[
+                MemoryRecordResult(id="1", text="Optimized result", dist=0.1),
+            ],
+            next_offset=None,
+        )

-        # Check second result
-        assert data["memories"][1]["id"] == "2"
-        assert data["memories"][1]["text"] == "Assistant: Hi there!"
- assert data["memories"][1]["dist"] == 0.75 + payload = {"text": "tell me about my preferences"} + + # Call endpoint without optimize_query parameter (should default to True) + response = await client.post("/v1/long-term-memory/search", json=payload) + + assert response.status_code == 200 + + # Verify search was called with optimize_query=True (default) + mock_search.assert_called_once() + call_kwargs = mock_search.call_args.kwargs + assert call_kwargs.get("optimize_query") is True + + @patch("agent_memory_server.api.long_term_memory.search_long_term_memories") + @pytest.mark.asyncio + async def test_search_with_optimize_query_false(self, mock_search, client): + """Test search endpoint with optimize_query=False.""" + mock_search.return_value = MemoryRecordResultsResponse( + total=1, + memories=[ + MemoryRecordResult(id="1", text="Non-optimized result", dist=0.1), + ], + next_offset=None, + ) + + payload = {"text": "tell me about my preferences"} + + # Call endpoint with optimize_query=False as query parameter + response = await client.post( + "/v1/long-term-memory/search", + json=payload, + params={"optimize_query": "false"}, + ) + + assert response.status_code == 200 + + # Verify search was called with optimize_query=False + mock_search.assert_called_once() + call_kwargs = mock_search.call_args.kwargs + assert call_kwargs.get("optimize_query") is False + + @patch("agent_memory_server.api.long_term_memory.search_long_term_memories") + @pytest.mark.asyncio + async def test_search_with_optimize_query_explicit_true(self, mock_search, client): + """Test search endpoint with explicit optimize_query=True.""" + mock_search.return_value = MemoryRecordResultsResponse( + total=1, + memories=[ + MemoryRecordResult(id="1", text="Optimized result", dist=0.1), + ], + next_offset=None, + ) + + payload = {"text": "what are my UI settings"} + + # Call endpoint with explicit optimize_query=True + response = await client.post( + "/v1/long-term-memory/search", + json=payload, + params={"optimize_query": "true"}, + ) + + assert response.status_code == 200 + data = response.json() + + # Verify search was called with optimize_query=True + mock_search.assert_called_once() + call_kwargs = mock_search.call_args.kwargs + assert call_kwargs.get("optimize_query") is True + + # Check response structure + assert "memories" in data + assert len(data["memories"]) == 1 + assert data["memories"][0]["id"] == "1" + assert data["memories"][0]["text"] == "Optimized result" @pytest.mark.requires_api_keys @@ -639,6 +716,89 @@ async def test_memory_prompt_with_model_name( # Verify the working memory function was called mock_get_working_memory.assert_called_once() + @patch("agent_memory_server.api.long_term_memory.search_long_term_memories") + @patch("agent_memory_server.api.working_memory.get_working_memory") + @pytest.mark.asyncio + async def test_memory_prompt_with_optimize_query_default_true( + self, mock_get_working_memory, mock_search, client + ): + """Test memory prompt endpoint with default optimize_query=True.""" + # Mock working memory + mock_get_working_memory.return_value = WorkingMemoryResponse( + session_id="test-session", + messages=[ + MemoryMessage(role="user", content="Hello"), + MemoryMessage(role="assistant", content="Hi there"), + ], + memories=[], + context=None, + ) + + # Mock search for long-term memory + mock_search.return_value = MemoryRecordResultsResponse( + total=1, + memories=[ + MemoryRecordResult(id="1", text="User preferences about UI", dist=0.1), + ], + next_offset=None, + ) + + payload = { + 
"query": "what are my preferences?", + "session": {"session_id": "test-session"}, + "long_term_search": {"text": "preferences"}, + } + + # Call endpoint without optimize_query parameter (should default to True) + response = await client.post("/v1/memory/prompt", json=payload) + + assert response.status_code == 200 + + # Verify search was called with optimize_query=True (default) + mock_search.assert_called_once() + # The search is called indirectly through the API's search_long_term_memory function + # which should have optimize_query=True by default + + @patch("agent_memory_server.api.long_term_memory.search_long_term_memories") + @patch("agent_memory_server.api.working_memory.get_working_memory") + @pytest.mark.asyncio + async def test_memory_prompt_with_optimize_query_false( + self, mock_get_working_memory, mock_search, client + ): + """Test memory prompt endpoint with optimize_query=False.""" + # Mock working memory + mock_get_working_memory.return_value = WorkingMemoryResponse( + session_id="test-session", + messages=[ + MemoryMessage(role="user", content="Hello"), + MemoryMessage(role="assistant", content="Hi there"), + ], + memories=[], + context=None, + ) + + # Mock search for long-term memory + mock_search.return_value = MemoryRecordResultsResponse( + total=1, + memories=[ + MemoryRecordResult(id="1", text="User preferences about UI", dist=0.1), + ], + next_offset=None, + ) + + payload = { + "query": "what are my preferences?", + "session": {"session_id": "test-session"}, + "long_term_search": {"text": "preferences"}, + } + + # Call endpoint with optimize_query=False as query parameter + response = await client.post( + "/v1/memory/prompt", json=payload, params={"optimize_query": "false"} + ) + + assert response.status_code == 200 + @pytest.mark.requires_api_keys class TestLongTermMemoryEndpoint: diff --git a/tests/test_client_api.py b/tests/test_client_api.py index 8235652..63df23c 100644 --- a/tests/test_client_api.py +++ b/tests/test_client_api.py @@ -487,3 +487,189 @@ async def test_memory_prompt_integration(memory_test_client: MemoryAPIClient): assert any("favorite color is blue" in text for text in message_texts) # And the query itself assert query in message_texts[-1] + + +@pytest.mark.asyncio +async def test_search_long_term_memory_with_optimize_query_default_true( + memory_test_client: MemoryAPIClient, +): + """Test that client search_long_term_memory uses optimize_query=True by default.""" + with patch( + "agent_memory_server.long_term_memory.search_long_term_memories" + ) as mock_search: + mock_search.return_value = MemoryRecordResultsResponse( + total=1, + memories=[ + MemoryRecordResult( + id="test-1", + text="User preferences about UI", + memory_type=MemoryTypeEnum.SEMANTIC, + dist=0.1, + ) + ], + next_offset=None, + ) + + # Call search without optimize_query parameter (should default to True) + results = await memory_test_client.search_long_term_memory( + text="tell me about my preferences" + ) + + # Verify search was called with optimize_query=True (default) + mock_search.assert_called_once() + call_kwargs = mock_search.call_args.kwargs + assert call_kwargs.get("optimize_query") is True + + # Verify results + assert results.total == 1 + assert len(results.memories) == 1 + + +@pytest.mark.asyncio +async def test_search_long_term_memory_with_optimize_query_false_explicit( + memory_test_client: MemoryAPIClient, +): + """Test that client search_long_term_memory can use optimize_query=False when explicitly set.""" + with patch( + 
"agent_memory_server.long_term_memory.search_long_term_memories" + ) as mock_search: + mock_search.return_value = MemoryRecordResultsResponse( + total=1, + memories=[ + MemoryRecordResult( + id="test-1", + text="User preferences about UI", + memory_type=MemoryTypeEnum.SEMANTIC, + dist=0.1, + ) + ], + next_offset=None, + ) + + # Call search with explicit optimize_query=False + await memory_test_client.search_long_term_memory( + text="tell me about my preferences", optimize_query=False + ) + + # Verify search was called with optimize_query=False + mock_search.assert_called_once() + call_kwargs = mock_search.call_args.kwargs + assert call_kwargs.get("optimize_query") is False + + +@pytest.mark.asyncio +async def test_search_memory_tool_with_optimize_query_false_default( + memory_test_client: MemoryAPIClient, +): + """Test that client search_memory_tool uses optimize_query=False by default (for LLM tool use).""" + with patch( + "agent_memory_server.long_term_memory.search_long_term_memories" + ) as mock_search: + mock_search.return_value = MemoryRecordResultsResponse( + total=1, + memories=[ + MemoryRecordResult( + id="test-1", + text="User preferences about UI", + memory_type=MemoryTypeEnum.SEMANTIC, + dist=0.1, + ) + ], + next_offset=None, + ) + + # Call search_memory_tool without optimize_query parameter (should default to False for LLM tools) + results = await memory_test_client.search_memory_tool( + query="tell me about my preferences" + ) + + # Verify search was called with optimize_query=False (default for LLM tools) + mock_search.assert_called_once() + call_kwargs = mock_search.call_args.kwargs + assert call_kwargs.get("optimize_query") is False + + # Verify results format is suitable for LLM consumption + assert "memories" in results + assert "summary" in results + + +@pytest.mark.asyncio +async def test_search_memory_tool_with_optimize_query_true_explicit( + memory_test_client: MemoryAPIClient, +): + """Test that client search_memory_tool can use optimize_query=True when explicitly set.""" + with patch( + "agent_memory_server.long_term_memory.search_long_term_memories" + ) as mock_search: + mock_search.return_value = MemoryRecordResultsResponse( + total=1, + memories=[ + MemoryRecordResult( + id="test-1", + text="User preferences about UI", + memory_type=MemoryTypeEnum.SEMANTIC, + dist=0.1, + ) + ], + next_offset=None, + ) + + # Call search_memory_tool with explicit optimize_query=True + await memory_test_client.search_memory_tool( + query="tell me about my preferences", optimize_query=True + ) + + # Verify search was called with optimize_query=True + mock_search.assert_called_once() + call_kwargs = mock_search.call_args.kwargs + assert call_kwargs.get("optimize_query") is True + + +@pytest.mark.asyncio +async def test_memory_prompt_with_optimize_query_default_true( + memory_test_client: MemoryAPIClient, +): + """Test that client memory_prompt uses optimize_query=True by default.""" + with patch( + "agent_memory_server.long_term_memory.search_long_term_memories" + ) as mock_search: + mock_search.return_value = MemoryRecordResultsResponse( + total=0, memories=[], next_offset=None + ) + + # Call memory_prompt without optimize_query parameter (should default to True) + result = await memory_test_client.memory_prompt( + query="what are my preferences?", long_term_search={"text": "preferences"} + ) + + # Verify search was called with optimize_query=True (default) + mock_search.assert_called_once() + call_kwargs = mock_search.call_args.kwargs + assert call_kwargs.get("optimize_query") is 
True + assert result is not None + + +@pytest.mark.asyncio +async def test_memory_prompt_with_optimize_query_false_explicit( + memory_test_client: MemoryAPIClient, +): + """Test that client memory_prompt can use optimize_query=False when explicitly set.""" + with patch( + "agent_memory_server.long_term_memory.search_long_term_memories" + ) as mock_search: + mock_search.return_value = MemoryRecordResultsResponse( + total=0, memories=[], next_offset=None + ) + + # Call memory_prompt with explicit optimize_query=False + result = await memory_test_client.memory_prompt( + query="what are my preferences?", + long_term_search={"text": "preferences"}, + optimize_query=False, + ) + + # Verify search was called with optimize_query=False + mock_search.assert_called_once() + call_kwargs = mock_search.call_args.kwargs + assert call_kwargs.get("optimize_query") is False + assert result is not None diff --git a/tests/test_llms.py b/tests/test_llms.py index 29dea80..42a8a52 100644 --- a/tests/test_llms.py +++ b/tests/test_llms.py @@ -9,6 +9,7 @@ OpenAIClientWrapper, get_model_client, get_model_config, + optimize_query_for_vector_search, ) @@ -143,3 +144,190 @@ async def test_get_model_client(): mock_anthropic.return_value = "anthropic-client" client = await get_model_client("claude-3-sonnet-20240229") assert client == "anthropic-client" + + +@pytest.mark.asyncio +class TestQueryOptimization: + """Test query optimization functionality.""" + + @patch("agent_memory_server.llms.get_model_client") + async def test_optimize_query_success(self, mock_get_client): + """Test successful query optimization.""" + # Mock the model client and response + mock_client = AsyncMock() + mock_response = MagicMock() + mock_response.choices = [MagicMock()] + mock_response.choices[ + 0 + ].message.content = "user interface preferences dark mode" + mock_client.create_chat_completion.return_value = mock_response + mock_get_client.return_value = mock_client + + result = await optimize_query_for_vector_search( + "Can you tell me about my UI preferences for dark mode?" 
+ ) + + assert result == "user interface preferences dark mode" + mock_get_client.assert_called_once() + mock_client.create_chat_completion.assert_called_once() + + @patch("agent_memory_server.llms.get_model_client") + async def test_optimize_query_with_custom_model(self, mock_get_client): + """Test query optimization with custom model.""" + mock_client = AsyncMock() + mock_response = MagicMock() + mock_response.choices = [MagicMock()] + mock_response.choices[0].message.content = "optimized query" + mock_client.create_chat_completion.return_value = mock_response + mock_get_client.return_value = mock_client + + result = await optimize_query_for_vector_search( + "original query", model_name="custom-model" + ) + + assert result == "optimized query" + mock_client.create_chat_completion.assert_called_once() + # Verify the model name was passed to create_chat_completion + call_kwargs = mock_client.create_chat_completion.call_args[1] + assert call_kwargs["model"] == "custom-model" + + @patch("agent_memory_server.llms.settings") + @patch("agent_memory_server.llms.get_model_client") + async def test_optimize_query_uses_fast_model_default( + self, mock_get_client, mock_settings + ): + """Test that optimization uses fast_model by default.""" + mock_settings.fast_model = "gpt-4o-mini" + mock_client = AsyncMock() + mock_response = MagicMock() + mock_response.choices = [MagicMock()] + mock_response.choices[0].message.content = "optimized" + mock_client.create_chat_completion.return_value = mock_response + mock_get_client.return_value = mock_client + + await optimize_query_for_vector_search("test query") + + mock_get_client.assert_called_once_with("gpt-4o-mini") + + @patch("agent_memory_server.llms.get_model_client") + async def test_optimize_query_empty_input(self, mock_get_client): + """Test optimization with empty or None input.""" + # Test empty string + result = await optimize_query_for_vector_search("") + assert result == "" + mock_get_client.assert_not_called() + + # Test None + result = await optimize_query_for_vector_search(None) + assert result is None + mock_get_client.assert_not_called() + + # Test whitespace only + result = await optimize_query_for_vector_search(" ") + assert result == " " + mock_get_client.assert_not_called() + + @patch("agent_memory_server.llms.get_model_client") + async def test_optimize_query_client_error_fallback(self, mock_get_client): + """Test fallback to original query when client fails.""" + mock_get_client.side_effect = Exception("Model client error") + + original_query = "What are my preferences?" + result = await optimize_query_for_vector_search(original_query) + + assert result == original_query + mock_get_client.assert_called_once() + + @patch("agent_memory_server.llms.get_model_client") + async def test_optimize_query_empty_response_fallback(self, mock_get_client): + """Test fallback when model returns empty response.""" + mock_client = AsyncMock() + mock_response = MagicMock() + mock_response.choices = [MagicMock()] + mock_response.choices[0].message.content = "" # Empty response + mock_client.create_chat_completion.return_value = mock_response + mock_get_client.return_value = mock_client + + original_query = "What are my preferences?" 
+ result = await optimize_query_for_vector_search(original_query) + + assert result == original_query + + @patch("agent_memory_server.llms.get_model_client") + async def test_optimize_query_short_response_fallback(self, mock_get_client): + """Test fallback when model returns very short response.""" + mock_client = AsyncMock() + mock_response = MagicMock() + mock_response.choices = [MagicMock()] + mock_response.choices[0].message.content = "a" # Too short + mock_client.create_chat_completion.return_value = mock_response + mock_get_client.return_value = mock_client + + original_query = "What are my preferences?" + result = await optimize_query_for_vector_search(original_query) + + assert result == original_query + + @patch("agent_memory_server.llms.get_model_client") + async def test_optimize_query_no_choices_fallback(self, mock_get_client): + """Test fallback when model response has no choices.""" + mock_client = AsyncMock() + mock_response = MagicMock() + mock_response.choices = [] # No choices + mock_client.create_chat_completion.return_value = mock_response + mock_get_client.return_value = mock_client + + original_query = "What are my preferences?" + result = await optimize_query_for_vector_search(original_query) + + assert result == original_query + + @patch("agent_memory_server.llms.get_model_client") + async def test_optimize_query_different_response_formats(self, mock_get_client): + """Test handling different response formats (text vs message).""" + mock_client = AsyncMock() + mock_get_client.return_value = mock_client + + # Test with 'text' attribute + mock_response = MagicMock() + mock_response.choices = [MagicMock()] + del mock_response.choices[0].message # Remove message attribute + mock_response.choices[0].text = "optimized via text" + mock_client.create_chat_completion.return_value = mock_response + + result = await optimize_query_for_vector_search("test query") + assert result == "optimized via text" + + @patch("agent_memory_server.llms.get_model_client") + async def test_optimize_query_strips_whitespace(self, mock_get_client): + """Test that optimization strips whitespace from response.""" + mock_client = AsyncMock() + mock_response = MagicMock() + mock_response.choices = [MagicMock()] + mock_response.choices[0].message.content = " optimized query \n" + mock_client.create_chat_completion.return_value = mock_response + mock_get_client.return_value = mock_client + + result = await optimize_query_for_vector_search("test query") + assert result == "optimized query" + + async def test_optimize_query_prompt_format(self): + """Test that the optimization prompt is correctly formatted.""" + with patch("agent_memory_server.llms.get_model_client") as mock_get_client: + mock_client = AsyncMock() + mock_response = MagicMock() + mock_response.choices = [MagicMock()] + mock_response.choices[0].message.content = "optimized" + mock_client.create_chat_completion.return_value = mock_response + mock_get_client.return_value = mock_client + + test_query = "Can you tell me about user preferences?" 
+ await optimize_query_for_vector_search(test_query) + + # Check that the prompt contains our test query + call_args = mock_client.create_chat_completion.call_args + prompt = call_args[1]["prompt"] + assert test_query in prompt + assert "semantic search" in prompt + assert "Guidelines:" in prompt + assert "Optimized query:" in prompt diff --git a/tests/test_long_term_memory.py b/tests/test_long_term_memory.py index 5c3d806..947d2f2 100644 --- a/tests/test_long_term_memory.py +++ b/tests/test_long_term_memory.py @@ -112,6 +112,7 @@ async def test_search_memories(self, mock_openai_client, mock_async_redis_client results = await search_long_term_memories( query, session_id=session_id, + optimize_query=False, # Disable query optimization for this unit test ) # Check that the adapter search_memories was called with the right arguments @@ -882,3 +883,183 @@ async def test_deduplicate_by_id_with_user_id_real_redis_error( # Re-raise to see the full traceback raise + + +@pytest.mark.asyncio +class TestSearchQueryOptimization: + """Test query optimization in search_long_term_memories function.""" + + @patch("agent_memory_server.long_term_memory.get_vectorstore_adapter") + @patch("agent_memory_server.long_term_memory.optimize_query_for_vector_search") + async def test_search_with_query_optimization_enabled( + self, mock_optimize, mock_get_adapter + ): + """Test that query optimization is applied when optimize_query=True.""" + # Mock the vectorstore adapter + mock_adapter = AsyncMock() + mock_adapter.search_memories.return_value = MemoryRecordResults( + total=1, + memories=[ + MemoryRecordResult( + id="test-id", + text="Test memory", + memory_type=MemoryTypeEnum.SEMANTIC, + dist=0.1, + ) + ], + ) + mock_get_adapter.return_value = mock_adapter + + # Mock query optimization + mock_optimize.return_value = "optimized search query" + + # Call search with optimization enabled + result = await search_long_term_memories( + text="tell me about my preferences", optimize_query=True, limit=10 + ) + + # Verify optimization was called + mock_optimize.assert_called_once_with("tell me about my preferences") + + # Verify adapter was called with optimized query + mock_adapter.search_memories.assert_called_once() + call_kwargs = mock_adapter.search_memories.call_args[1] + assert call_kwargs["query"] == "optimized search query" + + # Verify results + assert result.total == 1 + assert len(result.memories) == 1 + + @patch("agent_memory_server.long_term_memory.get_vectorstore_adapter") + @patch("agent_memory_server.long_term_memory.optimize_query_for_vector_search") + async def test_search_with_query_optimization_disabled( + self, mock_optimize, mock_get_adapter + ): + """Test that query optimization is skipped when optimize_query=False.""" + # Mock the vectorstore adapter + mock_adapter = AsyncMock() + mock_adapter.search_memories.return_value = MemoryRecordResults( + total=1, + memories=[ + MemoryRecordResult( + id="test-id", + text="Test memory", + memory_type=MemoryTypeEnum.SEMANTIC, + dist=0.1, + ) + ], + ) + mock_get_adapter.return_value = mock_adapter + + # Call search with optimization disabled + result = await search_long_term_memories( + text="tell me about my preferences", optimize_query=False, limit=10 + ) + + # Verify optimization was NOT called + mock_optimize.assert_not_called() + + # Verify adapter was called with original query + mock_adapter.search_memories.assert_called_once() + call_kwargs = mock_adapter.search_memories.call_args[1] + assert call_kwargs["query"] == "tell me about my preferences" + + # 
Verify results + assert result.total == 1 + assert len(result.memories) == 1 + + @patch("agent_memory_server.long_term_memory.get_vectorstore_adapter") + @patch("agent_memory_server.long_term_memory.optimize_query_for_vector_search") + async def test_search_with_empty_query_skips_optimization( + self, mock_optimize, mock_get_adapter + ): + """Test that empty queries skip optimization.""" + # Mock the vectorstore adapter + mock_adapter = AsyncMock() + mock_adapter.search_memories.return_value = MemoryRecordResults( + total=0, memories=[] + ) + mock_get_adapter.return_value = mock_adapter + + # Call search with empty query + await search_long_term_memories(text="", optimize_query=True, limit=10) + + # Verify optimization was NOT called for empty query + mock_optimize.assert_not_called() + + # Verify adapter was called with empty query + mock_adapter.search_memories.assert_called_once() + call_kwargs = mock_adapter.search_memories.call_args[1] + assert call_kwargs["query"] == "" + + @patch("agent_memory_server.long_term_memory.get_vectorstore_adapter") + @patch("agent_memory_server.long_term_memory.optimize_query_for_vector_search") + async def test_search_optimization_failure_fallback( + self, mock_optimize, mock_get_adapter + ): + """Test that search continues with original query if optimization fails.""" + # Mock the vectorstore adapter + mock_adapter = AsyncMock() + mock_adapter.search_memories.return_value = MemoryRecordResults( + total=0, memories=[] + ) + mock_get_adapter.return_value = mock_adapter + + # Mock optimization to return original query (simulating internal error handling) + mock_optimize.return_value = ( + "test query" # Returns original query after internal error handling + ) + + # Call search - this should not raise an exception + await search_long_term_memories( + text="test query", optimize_query=True, limit=10 + ) + + # Verify optimization was attempted + mock_optimize.assert_called_once_with("test query") + + # Verify search proceeded with the query (original after fallback) + mock_adapter.search_memories.assert_called_once() + call_kwargs = mock_adapter.search_memories.call_args[1] + assert call_kwargs["query"] == "test query" + + @patch("agent_memory_server.long_term_memory.get_vectorstore_adapter") + @patch("agent_memory_server.long_term_memory.optimize_query_for_vector_search") + async def test_search_passes_all_parameters_correctly( + self, mock_optimize, mock_get_adapter + ): + """Test that all search parameters are passed correctly to the adapter.""" + # Mock the vectorstore adapter + mock_adapter = AsyncMock() + mock_adapter.search_memories.return_value = MemoryRecordResults( + total=0, memories=[] + ) + mock_get_adapter.return_value = mock_adapter + + # Mock query optimization + mock_optimize.return_value = "optimized query" + + # Create filter objects for testing + session_filter = SessionId(eq="test-session") + + # Call search with various parameters + await search_long_term_memories( + text="test query", + session_id=session_filter, + limit=20, + offset=10, + distance_threshold=0.3, + optimize_query=True, + ) + + # Verify optimization was called + mock_optimize.assert_called_once_with("test query") + + # Verify all parameters were passed to adapter + mock_adapter.search_memories.assert_called_once() + call_kwargs = mock_adapter.search_memories.call_args[1] + assert call_kwargs["query"] == "optimized query" + assert call_kwargs["session_id"] == session_filter + assert call_kwargs["limit"] == 20 + assert call_kwargs["offset"] == 10 + assert 
call_kwargs["distance_threshold"] == 0.3 diff --git a/tests/test_mcp.py b/tests/test_mcp.py index b56ff6e..95b84a6 100644 --- a/tests/test_mcp.py +++ b/tests/test_mcp.py @@ -180,7 +180,7 @@ async def test_default_namespace_injection(self, monkeypatch): # Capture injected namespace injected = {} - async def fake_core_search(payload): + async def fake_core_search(payload, optimize_query=False): injected["namespace"] = payload.namespace.eq if payload.namespace else None # Return a dummy result with total>0 to skip fake fallback return MemoryRecordResults( @@ -231,7 +231,9 @@ async def test_memory_prompt_parameter_passing(self, session, monkeypatch): # Capture the parameters passed to core_memory_prompt captured_params = {} - async def mock_core_memory_prompt(params: MemoryPromptRequest): + async def mock_core_memory_prompt( + params: MemoryPromptRequest, optimize_query: bool = False + ): captured_params["query"] = params.query captured_params["session"] = params.session captured_params["long_term_search"] = params.long_term_search @@ -468,3 +470,123 @@ async def test_mcp_lenient_memory_record_defaults(self, session, mcp_test_setup) extracted_memory.discrete_memory_extracted == "t" ), f"ExtractedMemoryRecord should default to 't', got '{extracted_memory.discrete_memory_extracted}'" assert extracted_memory.memory_type.value == "semantic" + + @pytest.mark.asyncio + async def test_search_long_term_memory_with_optimize_query_false_default( + self, session, mcp_test_setup + ): + """Test that MCP search_long_term_memory uses optimize_query=False by default.""" + async with client_session(mcp_app._mcp_server) as client: + with mock.patch( + "agent_memory_server.mcp.core_search_long_term_memory" + ) as mock_search: + mock_search.return_value = MemoryRecordResults(total=0, memories=[]) + + # Call search without optimize_query parameter + await client.call_tool( + "search_long_term_memory", {"text": "tell me about my preferences"} + ) + + # Verify search was called with optimize_query=False (MCP default) + mock_search.assert_called_once() + call_args = mock_search.call_args + # Check the SearchRequest object passed to mock_search + call_args[0][0] # First positional argument + # The optimize_query parameter should be passed separately + optimize_query = call_args[1]["optimize_query"] + assert optimize_query is False + + @pytest.mark.asyncio + async def test_search_long_term_memory_with_optimize_query_true_explicit( + self, session, mcp_test_setup + ): + """Test that MCP search_long_term_memory can use optimize_query=True when explicitly set.""" + async with client_session(mcp_app._mcp_server) as client: + with mock.patch( + "agent_memory_server.mcp.core_search_long_term_memory" + ) as mock_search: + mock_search.return_value = MemoryRecordResults(total=0, memories=[]) + + # Call search with explicit optimize_query=True + await client.call_tool( + "search_long_term_memory", + {"text": "tell me about my preferences", "optimize_query": True}, + ) + + # Verify search was called with optimize_query=True + mock_search.assert_called_once() + call_args = mock_search.call_args + optimize_query = call_args[1]["optimize_query"] + assert optimize_query is True + + @pytest.mark.asyncio + async def test_search_long_term_memory_with_optimize_query_false_explicit( + self, session, mcp_test_setup + ): + """Test that MCP search_long_term_memory can use optimize_query=False when explicitly set.""" + async with client_session(mcp_app._mcp_server) as client: + with mock.patch( + 
"agent_memory_server.mcp.core_search_long_term_memory" + ) as mock_search: + mock_search.return_value = MemoryRecordResults(total=0, memories=[]) + + # Call search with explicit optimize_query=False + await client.call_tool( + "search_long_term_memory", + {"text": "what are my UI preferences", "optimize_query": False}, + ) + + # Verify search was called with optimize_query=False + mock_search.assert_called_once() + call_args = mock_search.call_args + optimize_query = call_args[1]["optimize_query"] + assert optimize_query is False + + @pytest.mark.asyncio + async def test_memory_prompt_with_optimize_query_false_default( + self, session, mcp_test_setup + ): + """Test that MCP memory_prompt uses optimize_query=False by default.""" + async with client_session(mcp_app._mcp_server) as client: + with mock.patch( + "agent_memory_server.mcp.core_memory_prompt" + ) as mock_prompt: + mock_prompt.return_value = MemoryPromptResponse( + messages=[SystemMessage(content="Test response")] + ) + + # Call memory prompt without optimize_query parameter + await client.call_tool( + "memory_prompt", {"query": "what are my preferences?"} + ) + + # Verify memory_prompt was called with optimize_query=False (MCP default) + mock_prompt.assert_called_once() + call_args = mock_prompt.call_args + optimize_query = call_args[1]["optimize_query"] + assert optimize_query is False + + @pytest.mark.asyncio + async def test_memory_prompt_with_optimize_query_true_explicit( + self, session, mcp_test_setup + ): + """Test that MCP memory_prompt can use optimize_query=True when explicitly set.""" + async with client_session(mcp_app._mcp_server) as client: + with mock.patch( + "agent_memory_server.mcp.core_memory_prompt" + ) as mock_prompt: + mock_prompt.return_value = MemoryPromptResponse( + messages=[SystemMessage(content="Test response")] + ) + + # Call memory prompt with explicit optimize_query=True + await client.call_tool( + "memory_prompt", + {"query": "what are my preferences?", "optimize_query": True}, + ) + + # Verify memory_prompt was called with optimize_query=True + mock_prompt.assert_called_once() + call_args = mock_prompt.call_args + optimize_query = call_args[1]["optimize_query"] + assert optimize_query is True diff --git a/tests/test_query_optimization_errors.py b/tests/test_query_optimization_errors.py new file mode 100644 index 0000000..f5ef916 --- /dev/null +++ b/tests/test_query_optimization_errors.py @@ -0,0 +1,219 @@ +""" +Test error handling and edge cases for query optimization feature. +""" + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from agent_memory_server.llms import optimize_query_for_vector_search +from agent_memory_server.long_term_memory import search_long_term_memories +from agent_memory_server.models import MemoryRecordResults + + +@pytest.mark.asyncio +class TestQueryOptimizationErrorHandling: + """Test error handling scenarios for query optimization.""" + + VERY_LONG_QUERY_REPEAT_COUNT = 1000 + + @patch("agent_memory_server.llms.get_model_client") + async def test_optimization_with_network_timeout(self, mock_get_client): + """Test graceful fallback when model API times out.""" + # Simulate network timeout + mock_client = AsyncMock() + mock_client.create_chat_completion.side_effect = TimeoutError( + "Request timed out" + ) + mock_get_client.return_value = mock_client + + original_query = "Can you tell me about my settings?" 
+ result = await optimize_query_for_vector_search(original_query) + + # Should fall back to original query + assert result == original_query + mock_get_client.assert_called_once() + + @patch("agent_memory_server.llms.get_model_client") + async def test_optimization_with_invalid_api_key(self, mock_get_client): + """Test fallback when API key is invalid.""" + # Simulate authentication error + mock_get_client.side_effect = Exception("Invalid API key") + + original_query = "What are my preferences?" + result = await optimize_query_for_vector_search(original_query) + + # Should fall back to original query + assert result == original_query + + @patch("agent_memory_server.llms.get_model_client") + async def test_optimization_with_malformed_response(self, mock_get_client): + """Test handling of malformed model responses.""" + mock_client = AsyncMock() + mock_response = MagicMock() + # Malformed response - no choices attribute + if hasattr(mock_response, "choices"): + del mock_response.choices + mock_client.create_chat_completion.return_value = mock_response + mock_get_client.return_value = mock_client + + original_query = "Find my user settings" + # The function should handle AttributeError gracefully and fall back + try: + result = await optimize_query_for_vector_search(original_query) + except AttributeError: + pytest.fail( + "optimize_query_for_vector_search did not handle missing choices attribute gracefully" + ) + + # Should fall back to original query + assert result == original_query + + @patch("agent_memory_server.llms.get_model_client") + async def test_optimization_with_none_response(self, mock_get_client): + """Test handling when model returns None.""" + mock_client = AsyncMock() + mock_client.create_chat_completion.return_value = None + mock_get_client.return_value = mock_client + + original_query = "Show my preferences" + result = await optimize_query_for_vector_search(original_query) + + # Should fall back to original query + assert result == original_query + + @patch("agent_memory_server.llms.get_model_client") + async def test_optimization_with_unicode_query(self, mock_get_client): + """Test optimization with unicode and special characters.""" + mock_client = AsyncMock() + mock_response = MagicMock() + mock_response.choices = [MagicMock()] + mock_response.choices[0].message.content = "préférences utilisateur émojis 🎉" + mock_client.create_chat_completion.return_value = mock_response + mock_get_client.return_value = mock_client + + unicode_query = "Mes préférences avec émojis 🎉 et caractères spéciaux" + result = await optimize_query_for_vector_search(unicode_query) + + assert result == "préférences utilisateur émojis 🎉" + mock_get_client.assert_called_once() + + @patch("agent_memory_server.llms.get_model_client") + async def test_optimization_with_very_long_query(self, mock_get_client): + """Test optimization with extremely long queries.""" + mock_client = AsyncMock() + mock_response = MagicMock() + mock_response.choices = [MagicMock()] + mock_response.choices[0].message.content = "long query optimized" + mock_client.create_chat_completion.return_value = mock_response + mock_get_client.return_value = mock_client + + # Create a very long query (10,000 characters) + long_query = ( + "Tell me about " + + "preferences " * self.VERY_LONG_QUERY_REPEAT_COUNT + + "settings" + ) + result = await optimize_query_for_vector_search(long_query) + + assert result == "long query optimized" + mock_get_client.assert_called_once() + + @patch("agent_memory_server.llms.get_model_client") + async def 
test_optimization_preserves_query_intent(self, mock_get_client): + """Test that optimization preserves the core intent of queries.""" + mock_client = AsyncMock() + mock_response = MagicMock() + mock_response.choices = [MagicMock()] + # Mock an optimization that maintains intent + mock_response.choices[0].message.content = "user interface dark mode settings" + mock_client.create_chat_completion.return_value = mock_response + mock_get_client.return_value = mock_client + + original_query = ( + "Can you please tell me about my dark mode settings for the UI?" + ) + result = await optimize_query_for_vector_search(original_query) + + assert result == "user interface dark mode settings" + # Verify the prompt includes the original query + call_args = mock_client.create_chat_completion.call_args + prompt = call_args[1]["prompt"] + assert original_query in prompt + + @patch("agent_memory_server.long_term_memory.get_vectorstore_adapter") + @patch("agent_memory_server.long_term_memory.optimize_query_for_vector_search") + async def test_search_continues_when_optimization_fails( + self, mock_optimize, mock_get_adapter + ): + """Test that search continues even if optimization completely fails.""" + # Mock optimization to return original query (simulating internal error handling) + mock_optimize.return_value = ( + "test query" # The function handles errors internally + ) + + # Mock the vectorstore adapter + mock_adapter = AsyncMock() + mock_adapter.search_memories.return_value = MemoryRecordResults( + total=0, memories=[] + ) + mock_get_adapter.return_value = mock_adapter + + # This should not raise an exception + await search_long_term_memories( + text="test query", optimize_query=True, limit=10 + ) + + # Verify optimization was attempted + mock_optimize.assert_called_once() + # Verify search still proceeded + mock_adapter.search_memories.assert_called_once() + + @patch("agent_memory_server.llms.get_model_client") + async def test_optimization_handles_special_characters_in_response( + self, mock_get_client + ): + """Test handling of special characters and formatting in model responses.""" + mock_client = AsyncMock() + mock_response = MagicMock() + mock_response.choices = [MagicMock()] + # Response with various formatting that should be cleaned + mock_response.choices[ + 0 + ].message.content = "\n\n **user preferences settings** \n\n" + mock_client.create_chat_completion.return_value = mock_response + mock_get_client.return_value = mock_client + + result = await optimize_query_for_vector_search("What are my settings?") + + # Should strip whitespace but preserve the content + assert result == "**user preferences settings**" + + @patch("agent_memory_server.llms.get_model_client") + async def test_optimization_with_model_rate_limit(self, mock_get_client): + """Test fallback when model API is rate limited.""" + # Simulate rate limit error + mock_get_client.side_effect = Exception("Rate limit exceeded") + + original_query = "Find my account settings" + result = await optimize_query_for_vector_search(original_query) + + # Should fall back to original query + assert result == original_query + + @patch("agent_memory_server.llms.settings") + @patch("agent_memory_server.llms.get_model_client") + async def test_optimization_with_invalid_model_name( + self, mock_get_client, mock_settings + ): + """Test handling of invalid/unavailable model names.""" + # Set an invalid model name + mock_settings.fast_model = "invalid-model-name" + mock_get_client.side_effect = Exception("Model not found") + + original_query = "Show 
user preferences" + result = await optimize_query_for_vector_search(original_query) + + # Should fall back to original query + assert result == original_query + mock_get_client.assert_called_once_with("invalid-model-name")