Skip to content

Commit ee4863a

Browse files
committed
Add support for the memory prompt in client SDK
1 parent f4c295a commit ee4863a

File tree

4 files changed

+413
-69
lines changed

4 files changed

+413
-69
lines changed

agent_memory_server/api.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -278,7 +278,7 @@ async def memory_prompt(params: MemoryPromptRequest) -> MemoryPromptResponse:
278278
SystemMessage(
279279
content=TextContent(
280280
type="text",
281-
text=f"## A summary of the conversation so far\n{session_memory.context}",
281+
text=f"## A summary of the conversation so far:\n{session_memory.context}",
282282
),
283283
)
284284
)

agent_memory_server/client/api.py

Lines changed: 200 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,12 @@
2525
HealthCheckResponse,
2626
LongTermMemory,
2727
LongTermMemoryResults,
28+
MemoryPromptRequest,
29+
MemoryPromptResponse,
2830
SearchRequest,
2931
SessionListResponse,
3032
SessionMemory,
33+
SessionMemoryRequest,
3134
SessionMemoryResponse,
3235
)
3336

@@ -129,8 +132,8 @@ async def list_sessions(
129132
SessionListResponse containing session IDs and total count
130133
"""
131134
params = {
132-
"limit": limit,
133-
"offset": offset,
135+
"limit": str(limit),
136+
"offset": str(offset),
134137
}
135138
if namespace is not None:
136139
params["namespace"] = namespace
@@ -343,6 +346,201 @@ async def search_long_term_memory(
343346
response.raise_for_status()
344347
return LongTermMemoryResults(**response.json())
345348

349+
async def memory_prompt(
350+
self,
351+
query: str,
352+
session_id: str | None = None,
353+
namespace: str | None = None,
354+
window_size: int | None = None,
355+
model_name: ModelNameLiteral | None = None,
356+
context_window_max: int | None = None,
357+
long_term_search: SearchRequest | None = None,
358+
) -> MemoryPromptResponse:
359+
"""
360+
Hydrate a user query with memory context and return a prompt
361+
ready to send to an LLM.
362+
363+
This method can retrieve relevant session history and long-term memories
364+
to provide context for the query.
365+
366+
Args:
367+
query: The user's query text
368+
session_id: Optional session ID to retrieve history from
369+
namespace: Optional namespace for session and long-term memories
370+
window_size: Optional number of messages to include from session history
371+
model_name: Optional model name to determine context window size
372+
context_window_max: Optional direct specification of context window max tokens
373+
long_term_search: Optional SearchRequest for specific long-term memory filtering
374+
375+
Returns:
376+
MemoryPromptResponse containing a list of messages with context
377+
378+
Raises:
379+
httpx.HTTPStatusError: If the request fails or if neither session_id nor long_term_search is provided
380+
"""
381+
# Prepare the request payload
382+
session_params = None
383+
if session_id is not None:
384+
session_params = SessionMemoryRequest(
385+
session_id=session_id,
386+
namespace=namespace or self.config.default_namespace,
387+
window_size=window_size or 12, # Default from settings
388+
model_name=model_name,
389+
context_window_max=context_window_max,
390+
)
391+
392+
# If no explicit long_term_search is provided but we have a query, create a basic one
393+
if long_term_search is None and query:
394+
# Use default namespace from config if none provided
395+
_namespace = None
396+
if namespace is not None:
397+
_namespace = Namespace(eq=namespace)
398+
elif self.config.default_namespace is not None:
399+
_namespace = Namespace(eq=self.config.default_namespace)
400+
401+
long_term_search = SearchRequest(
402+
text=query,
403+
namespace=_namespace,
404+
)
405+
406+
# Create the request payload
407+
payload = MemoryPromptRequest(
408+
query=query,
409+
session=session_params,
410+
long_term_search=long_term_search,
411+
)
412+
413+
# Make the API call
414+
response = await self._client.post(
415+
"/memory-prompt", json=payload.model_dump(exclude_none=True)
416+
)
417+
response.raise_for_status()
418+
data = response.json()
419+
return MemoryPromptResponse(**data)
420+
421+
async def hydrate_memory_prompt(
422+
self,
423+
query: str,
424+
session_id: SessionId | dict[str, Any] | None = None,
425+
namespace: Namespace | dict[str, Any] | None = None,
426+
topics: Topics | dict[str, Any] | None = None,
427+
entities: Entities | dict[str, Any] | None = None,
428+
created_at: CreatedAt | dict[str, Any] | None = None,
429+
last_accessed: LastAccessed | dict[str, Any] | None = None,
430+
user_id: UserId | dict[str, Any] | None = None,
431+
distance_threshold: float | None = None,
432+
memory_type: MemoryType | dict[str, Any] | None = None,
433+
limit: int = 10,
434+
offset: int = 0,
435+
window_size: int = 12,
436+
model_name: ModelNameLiteral | None = None,
437+
context_window_max: int | None = None,
438+
) -> MemoryPromptResponse:
439+
"""
440+
Hydrate a user query with relevant session history and long-term memories.
441+
442+
This method enriches the user's query by retrieving:
443+
1. Context from the conversation session (if session_id is provided)
444+
2. Relevant long-term memories related to the query
445+
446+
Args:
447+
query: The user's query text
448+
session_id: Optional filter for session ID
449+
namespace: Optional filter for namespace
450+
topics: Optional filter for topics in long-term memories
451+
entities: Optional filter for entities in long-term memories
452+
created_at: Optional filter for creation date
453+
last_accessed: Optional filter for last access date
454+
user_id: Optional filter for user ID
455+
distance_threshold: Optional distance threshold for semantic search
456+
memory_type: Optional filter for memory type
457+
limit: Maximum number of long-term memory results (default: 10)
458+
offset: Offset for pagination (default: 0)
459+
window_size: Number of messages to include from session history (default: 12)
460+
model_name: Optional model name to determine context window size
461+
context_window_max: Optional direct specification of context window max tokens
462+
463+
Returns:
464+
MemoryPromptResponse containing a list of messages with context
465+
466+
Raises:
467+
httpx.HTTPStatusError: If the request fails
468+
"""
469+
# Convert dictionary filters to their proper filter objects if needed
470+
if isinstance(session_id, dict):
471+
session_id = SessionId(**session_id)
472+
if isinstance(namespace, dict):
473+
namespace = Namespace(**namespace)
474+
if isinstance(topics, dict):
475+
topics = Topics(**topics)
476+
if isinstance(entities, dict):
477+
entities = Entities(**entities)
478+
if isinstance(created_at, dict):
479+
created_at = CreatedAt(**created_at)
480+
if isinstance(last_accessed, dict):
481+
last_accessed = LastAccessed(**last_accessed)
482+
if isinstance(user_id, dict):
483+
user_id = UserId(**user_id)
484+
if isinstance(memory_type, dict):
485+
memory_type = MemoryType(**memory_type)
486+
487+
# Apply default namespace if needed and no namespace filter specified
488+
if namespace is None and self.config.default_namespace is not None:
489+
namespace = Namespace(eq=self.config.default_namespace)
490+
491+
# Extract session_id value if it exists
492+
session_params = None
493+
_session_id = None
494+
if session_id and hasattr(session_id, "eq") and session_id.eq:
495+
_session_id = session_id.eq
496+
497+
if _session_id:
498+
# Get namespace value if it exists
499+
_namespace = None
500+
if namespace and hasattr(namespace, "eq"):
501+
_namespace = namespace.eq
502+
elif self.config.default_namespace:
503+
_namespace = self.config.default_namespace
504+
505+
session_params = SessionMemoryRequest(
506+
session_id=_session_id,
507+
namespace=_namespace,
508+
window_size=window_size,
509+
model_name=model_name,
510+
context_window_max=context_window_max,
511+
)
512+
513+
# Create search request for long-term memory
514+
search_payload = SearchRequest(
515+
text=query,
516+
session_id=session_id,
517+
namespace=namespace,
518+
topics=topics,
519+
entities=entities,
520+
created_at=created_at,
521+
last_accessed=last_accessed,
522+
user_id=user_id,
523+
distance_threshold=distance_threshold,
524+
memory_type=memory_type,
525+
limit=limit,
526+
offset=offset,
527+
)
528+
529+
# Create the request payload
530+
payload = MemoryPromptRequest(
531+
query=query,
532+
session=session_params,
533+
long_term_search=search_payload,
534+
)
535+
536+
# Make the API call
537+
response = await self._client.post(
538+
"/memory-prompt", json=payload.model_dump(exclude_none=True)
539+
)
540+
response.raise_for_status()
541+
data = response.json()
542+
return MemoryPromptResponse(**data)
543+
346544

347545
# Helper function to create a memory client
348546
async def create_memory_client(

agent_memory_server/models.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -304,10 +304,6 @@ class MemoryPromptRequest(BaseModel):
304304
long_term_search: SearchRequest | None = None
305305

306306

307-
class MemoryPromptResponse(BaseModel):
308-
messages: list[base.Message]
309-
310-
311307
class SystemMessage(base.Message):
312308
"""A system message"""
313309

@@ -317,4 +313,8 @@ class SystemMessage(base.Message):
317313
class UserMessage(base.Message):
318314
"""A user message"""
319315

320-
role: Literal["system"] = "system"
316+
role: Literal["user"] = "user"
317+
318+
319+
class MemoryPromptResponse(BaseModel):
320+
messages: list[base.Message | SystemMessage]

0 commit comments

Comments
 (0)