Feat: Add query optimization for vector search with configurable models #44

Merged 3 commits on Aug 12, 2025
Changes from 1 commit
35 changes: 29 additions & 6 deletions agent-memory-client/agent_memory_client/client.py
@@ -574,12 +574,13 @@ async def search_long_term_memory(
memory_type: MemoryType | dict[str, Any] | None = None,
limit: int = 10,
offset: int = 0,
optimize_query: bool = True,
) -> MemoryRecordResults:
"""
Search long-term memories using semantic search and filters.

Args:
text: Search query text for semantic similarity
text: Query for vector search - will be used for semantic similarity matching
session_id: Optional session ID filter
namespace: Optional namespace filter
topics: Optional topics filter
@@ -591,6 +592,7 @@ async def search_long_term_memory(
memory_type: Optional memory type filter
limit: Maximum number of results to return (default: 10)
offset: Offset for pagination (default: 0)
optimize_query: Whether to optimize the query for vector search using a fast model (default: True)

Returns:
MemoryRecordResults with matching memories and metadata
@@ -669,10 +671,14 @@ async def search_long_term_memory(
if distance_threshold is not None:
payload["distance_threshold"] = distance_threshold

# Add optimize_query as query parameter
params = {"optimize_query": str(optimize_query).lower()}

try:
response = await self._client.post(
"/v1/long-term-memory/search",
json=payload,
params=params,
)
response.raise_for_status()
return MemoryRecordResults(**response.json())
@@ -691,6 +697,7 @@ async def search_memory_tool(
max_results: int = 5,
min_relevance: float | None = None,
user_id: str | None = None,
optimize_query: bool = False,
) -> dict[str, Any]:
"""
Simplified long-term memory search designed for LLM tool use.
@@ -701,13 +708,14 @@ async def search_memory_tool(
searches long-term memory, not working memory.

Args:
query: The search query text
query: The query for vector search
topics: Optional list of topic strings to filter by
entities: Optional list of entity strings to filter by
memory_type: Optional memory type ("episodic", "semantic", "message")
max_results: Maximum results to return (default: 5)
min_relevance: Optional minimum relevance score (0.0-1.0)
user_id: Optional user ID to filter memories by
optimize_query: Whether to optimize the query for vector search (default: False - LLMs typically provide already optimized queries)

Returns:
Dict with 'memories' list and 'summary' for LLM consumption
@@ -759,6 +767,7 @@ async def search_memory_tool(
distance_threshold=distance_threshold,
limit=max_results,
user_id=user_id_filter,
optimize_query=optimize_query,
)

# Format for LLM consumption
@@ -828,13 +837,13 @@ async def handle_tool_calls(client, tool_calls):
"type": "function",
"function": {
"name": "search_memory",
"description": "Search long-term memory for relevant information based on a query. Use this when you need to recall past conversations, user preferences, or previously stored information. Note: This searches only long-term memory, not current working memory.",
"description": "Search long-term memory for relevant information using a query for vector search. Use this when you need to recall past conversations, user preferences, or previously stored information. Note: This searches only long-term memory, not current working memory.",
"parameters": {
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "The search query describing what information you're looking for",
"description": "The query for vector search describing what information you're looking for",
},
"topics": {
"type": "array",
@@ -868,6 +877,11 @@ async def handle_tool_calls(client, tool_calls):
"type": "string",
"description": "Optional user ID to filter memories by (e.g., 'user123')",
},
"optimize_query": {
"type": "boolean",
"default": False,
"description": "Whether to optimize the query for vector search (default: False - LLMs typically provide already optimized queries)",
},
},
"required": ["query"],
},
@@ -2138,20 +2152,22 @@ async def memory_prompt(
context_window_max: int | None = None,
long_term_search: dict[str, Any] | None = None,
user_id: str | None = None,
optimize_query: bool = True,
) -> dict[str, Any]:
"""
Hydrate a user query with memory context and return a prompt ready to send to an LLM.

NOTE: `long_term_search` uses the same filter options as `search_long_term_memories`.

Args:
query: The input text to find relevant context for
query: The query for vector search to find relevant context for
session_id: Optional session ID to include session messages
namespace: Optional namespace for the session
model_name: Optional model name to determine context window size
context_window_max: Optional direct specification of context window tokens
long_term_search: Optional search parameters for long-term memory
user_id: Optional user ID for the session
optimize_query: Whether to optimize the query for vector search using a fast model (default: True)

Returns:
Dict with messages hydrated with relevant memory context
@@ -2208,10 +2224,14 @@ async def memory_prompt(
}
payload["long_term_search"] = long_term_search

# Add optimize_query as query parameter
params = {"optimize_query": str(optimize_query).lower()}

try:
response = await self._client.post(
"/v1/memory/prompt",
json=payload,
params=params,
)
response.raise_for_status()
result = response.json()
@@ -2235,6 +2255,7 @@ async def hydrate_memory_prompt(
distance_threshold: float | None = None,
memory_type: dict[str, Any] | None = None,
limit: int = 10,
optimize_query: bool = True,
) -> dict[str, Any]:
"""
Hydrate a user query with long-term memory context using filters.
@@ -2243,7 +2264,7 @@ async def hydrate_memory_prompt(
long-term memory search with the specified filters.

Args:
query: The input text to find relevant context for
query: The query for vector search to find relevant context for
session_id: Optional session ID filter (as dict)
namespace: Optional namespace filter (as dict)
topics: Optional topics filter (as dict)
@@ -2254,6 +2275,7 @@ async def hydrate_memory_prompt(
distance_threshold: Optional distance threshold
memory_type: Optional memory type filter (as dict)
limit: Maximum number of long-term memories to include
optimize_query: Whether to optimize the query for vector search using a fast model (default: True)

Returns:
Dict with messages hydrated with relevant long-term memories
@@ -2285,6 +2307,7 @@ async def hydrate_memory_prompt(
return await self.memory_prompt(
query=query,
long_term_search=long_term_search,
optimize_query=optimize_query,
)

def _deep_merge_dicts(
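The changes above thread an optimize_query flag through the three client entry points. A minimal usage sketch follows; the client class name MemoryAPIClient and its constructor arguments are assumptions, since only the method signatures appear in this diff:

import asyncio

from agent_memory_client.client import MemoryAPIClient  # class name assumed, not shown in this diff


async def main() -> None:
    client = MemoryAPIClient(base_url="http://localhost:8000")  # constructor signature assumed

    # Interactive callers keep the default (True): the free-text query is
    # rewritten by the fast model before the vector search runs.
    results = await client.search_long_term_memory(
        text="can you tell me what the user said about their travel plans?",
        limit=5,
    )

    # Tool calls default to False because LLMs usually pass an already
    # optimized query; the flag can still be flipped per call.
    tool_results = await client.search_memory_tool(
        query="user travel plans summer 2025",
        max_results=5,
        optimize_query=False,
    )
    print(results, tool_results)


asyncio.run(main())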
8 changes: 7 additions & 1 deletion agent_memory_server/api.py
@@ -494,13 +494,15 @@ async def create_long_term_memory(
@router.post("/v1/long-term-memory/search", response_model=MemoryRecordResultsResponse)
async def search_long_term_memory(
payload: SearchRequest,
optimize_query: bool = True,
current_user: UserInfo = Depends(get_current_user),
):
"""
Run a semantic search on long-term memory with filtering options.

Args:
payload: Search payload with filter objects for precise queries
optimize_query: Whether to optimize the query for vector search using a fast model (default: True)

Returns:
List of search results
@@ -517,6 +519,7 @@ async def search_long_term_memory(
"distance_threshold": payload.distance_threshold,
"limit": payload.limit,
"offset": payload.offset,
"optimize_query": optimize_query,
**filters,
}

@@ -549,13 +552,14 @@ async def delete_long_term_memory(
@router.post("/v1/memory/prompt", response_model=MemoryPromptResponse)
async def memory_prompt(
params: MemoryPromptRequest,
optimize_query: bool = True,
current_user: UserInfo = Depends(get_current_user),
) -> MemoryPromptResponse:
"""
Hydrate a user query with memory context and return a prompt
ready to send to an LLM.

`query` is the input text that the caller of this API wants to use to find
`query` is the query for vector search that the caller of this API wants to use to find
relevant context. If `session_id` is provided and matches an existing
session, the resulting prompt will include those messages as the immediate
history of messages leading to a message containing `query`.
@@ -566,6 +570,7 @@ async def memory_prompt(

Args:
params: MemoryPromptRequest
optimize_query: Whether to optimize the query for vector search using a fast model (default: True)

Returns:
List of messages to send to an LLM, hydrated with relevant memory context
@@ -671,6 +676,7 @@ async def memory_prompt(
logger.debug(f"[memory_prompt] Search payload: {search_payload}")
long_term_memories = await search_long_term_memory(
search_payload,
optimize_query=optimize_query,
)

logger.debug(f"[memory_prompt] Long-term memories: {long_term_memories}")
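On the API side, optimize_query is exposed as a query-string parameter rather than a field in the JSON body, so raw HTTP callers can toggle it per request. A sketch with httpx; the payload fields and any required auth headers are assumptions inferred from the client code above:

import asyncio

import httpx


async def search_without_optimization() -> None:
    payload = {"text": "what did the user say about travel?", "limit": 10, "offset": 0}
    async with httpx.AsyncClient(base_url="http://localhost:8000") as http:
        # The boolean travels in the query string, mirroring what the client does.
        resp = await http.post(
            "/v1/long-term-memory/search",
            json=payload,
            params={"optimize_query": "false"},
        )
        resp.raise_for_status()
        print(resp.json())


asyncio.run(search_without_optimization())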
6 changes: 6 additions & 0 deletions agent_memory_server/config.py
@@ -56,6 +56,12 @@ class Settings(BaseSettings):
anthropic_api_base: str | None = None
generation_model: str = "gpt-4o"
embedding_model: str = "text-embedding-3-small"

# Model selection for query optimization
slow_model: str = "gpt-4o" # Slower, more capable model for complex tasks
fast_model: str = (
"gpt-4o-mini" # Faster, smaller model for quick tasks like query optimization
)
port: int = 8000
mcp_port: int = 9000

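Because Settings extends pydantic's BaseSettings, the new fast_model and slow_model fields should be overridable per deployment rather than hard-coded. Exactly which environment variable names are honored depends on the project's settings configuration, so the sketch below is an assumption:

import os

# Assumed: BaseSettings resolves these by field name (case-insensitive by default in pydantic).
os.environ["FAST_MODEL"] = "gpt-4o-mini"  # small model used for query optimization
os.environ["SLOW_MODEL"] = "gpt-4o"       # larger model reserved for complex tasks

from agent_memory_server.config import Settings  # import path taken from the diff

settings = Settings()
print(settings.fast_model, settings.slow_model)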
74 changes: 74 additions & 0 deletions agent_memory_server/llms.py
@@ -423,3 +423,77 @@ async def get_model_client(
raise ValueError(f"Unsupported model provider: {model_config.provider}")

return _model_clients[model_name]


async def optimize_query_for_vector_search(
query: str,
model_name: str | None = None,
) -> str:
"""
Optimize a user query for vector search using a fast model.

This function takes a natural language query and rewrites it to be more effective
for semantic similarity search. It uses a fast, small model to improve search
performance while maintaining query intent.

Args:
query: The original user query to optimize
model_name: Model to use for optimization (defaults to settings.fast_model)

Returns:
Optimized query string better suited for vector search
"""
if not query or not query.strip():
return query

# Use fast model from settings if not specified
effective_model = model_name or settings.fast_model

# Create optimization prompt
optimization_prompt = f"""Transform this natural language query into an optimized version for semantic search. The goal is to make it more effective for finding semantically similar content while preserving the original intent.

Guidelines:
- Keep the core meaning and intent
- Use more specific and descriptive terms
- Remove unnecessary words like "tell me", "I want to know", "can you"
- Focus on the key concepts and topics
- Make it concise but comprehensive

Original query: {query}

Optimized query:"""

try:
client = await get_model_client(effective_model)

response = await client.create_chat_completion(
model=effective_model,
prompt=optimization_prompt,
)

if response.choices and len(response.choices) > 0:
optimized = ""
if hasattr(response.choices[0], "message"):
optimized = response.choices[0].message.content
elif hasattr(response.choices[0], "text"):
optimized = response.choices[0].text
else:
optimized = str(response.choices[0])

# Clean up the response
optimized = optimized.strip()

# Fallback to original if optimization failed
if not optimized or len(optimized) < 2:
logger.warning(f"Query optimization failed for: {query}")
return query

logger.debug(f"Optimized query: '{query}' -> '{optimized}'")
return optimized

except Exception as e:
logger.warning(f"Failed to optimize query '{query}': {e}")
# Return original query if optimization fails
return query

return query
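The helper can also be called directly; it falls back to the original string whenever the input is empty or the model call fails. A usage sketch with illustrative example strings:

import asyncio

from agent_memory_server.llms import optimize_query_for_vector_search


async def demo() -> None:
    raw = "can you tell me what restaurants the user said they liked in Paris?"
    # Uses settings.fast_model unless a model name is passed explicitly.
    optimized = await optimize_query_for_vector_search(raw)
    print(f"{raw!r} -> {optimized!r}")

    # Empty or whitespace-only input is returned unchanged.
    assert await optimize_query_for_vector_search("  ") == "  "


asyncio.run(demo())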
13 changes: 10 additions & 3 deletions agent_memory_server/long_term_memory.py
@@ -28,6 +28,7 @@
AnthropicClientWrapper,
OpenAIClientWrapper,
get_model_client,
optimize_query_for_vector_search,
)
from agent_memory_server.models import (
ExtractedMemoryRecord,
@@ -718,13 +719,13 @@ async def search_long_term_memories(
memory_hash: MemoryHash | None = None,
limit: int = 10,
offset: int = 0,
optimize_query: bool = True,
) -> MemoryRecordResults:
"""
Search for long-term memories using the pluggable VectorStore adapter.

Args:
text: Search query text
redis: Redis client (kept for compatibility but may be unused depending on backend)
text: Query for vector search - will be used for semantic similarity matching
session_id: Optional session ID filter
user_id: Optional user ID filter
namespace: Optional namespace filter
@@ -738,16 +739,22 @@ async def search_long_term_memories(
memory_hash: Optional memory hash filter
limit: Maximum number of results
offset: Offset for pagination
optimize_query: Whether to optimize the query for vector search using a fast model (default: True)

Returns:
MemoryRecordResults containing matching memories
"""
# Optimize query for vector search if requested
search_query = text
if optimize_query and text:
search_query = await optimize_query_for_vector_search(text)

# Get the VectorStore adapter
adapter = await get_vectorstore_adapter()

# Delegate search to the adapter
return await adapter.search_memories(
query=text,
query=search_query,
session_id=session_id,
user_id=user_id,
namespace=namespace,
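On the server side the flag only decides whether the text is rewritten before it reaches the vector-store adapter. A sketch of calling the search function directly; filter arguments are omitted, and any additional parameters (such as a Redis client) are assumed to have defaults in this version:

import asyncio

from agent_memory_server.long_term_memory import search_long_term_memories


async def demo() -> None:
    # Default path: the raw text is rewritten by the fast model first.
    optimized_results = await search_long_term_memories(
        text="tell me what the user prefers for breakfast",
        limit=5,
    )

    # Opt out: the text is handed to the vector store verbatim.
    verbatim_results = await search_long_term_memories(
        text="user breakfast preferences",
        limit=5,
        optimize_query=False,
    )
    print(optimized_results, verbatim_results)


asyncio.run(demo())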