Skip to content

Commit 6a5f3ce

Browse files
committed
Fix mypy issues
1 parent fdeb69e commit 6a5f3ce

File tree

13 files changed

+25
-273
lines changed

13 files changed

+25
-273
lines changed

agent-memory-client/agent_memory_client/client.py

Lines changed: 9 additions & 198 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
HealthCheckResponse,
3535
MemoryRecord,
3636
MemoryRecordResults,
37+
MemoryTypeEnum,
3738
ModelNameLiteral,
3839
SessionListResponse,
3940
WorkingMemory,
@@ -442,7 +443,7 @@ async def add_memories_to_working_memory(
442443
# Auto-generate IDs for memories that don't have them
443444
for memory in final_memories:
444445
if not memory.id:
445-
memory.id = str(ulid.new())
446+
memory.id = str(ulid.ULID())
446447

447448
# Create new working memory with the memories
448449
working_memory = WorkingMemory(
@@ -617,136 +618,10 @@ async def search_long_term_memory(
617618
exclude_none=True, mode="json"
618619
)
619620
if user_id:
620-
payload["user_id"] = user_id.model_dump(exclude_none=True)
621-
if memory_type:
622-
payload["memory_type"] = memory_type.model_dump(exclude_none=True)
623-
if distance_threshold is not None:
624-
payload["distance_threshold"] = distance_threshold
625-
626-
try:
627-
response = await self._client.post(
628-
"/v1/long-term-memory/search",
629-
json=payload,
630-
)
631-
response.raise_for_status()
632-
return MemoryRecordResults(**response.json())
633-
except httpx.HTTPStatusError as e:
634-
self._handle_http_error(e.response)
635-
raise
636-
637-
async def search_memories(
638-
self,
639-
text: str,
640-
session_id: SessionId | dict[str, Any] | None = None,
641-
namespace: Namespace | dict[str, Any] | None = None,
642-
topics: Topics | dict[str, Any] | None = None,
643-
entities: Entities | dict[str, Any] | None = None,
644-
created_at: CreatedAt | dict[str, Any] | None = None,
645-
last_accessed: LastAccessed | dict[str, Any] | None = None,
646-
user_id: UserId | dict[str, Any] | None = None,
647-
distance_threshold: float | None = None,
648-
memory_type: MemoryType | dict[str, Any] | None = None,
649-
limit: int = 10,
650-
offset: int = 0,
651-
) -> MemoryRecordResults:
652-
"""
653-
Search across all memory types (working memory and long-term memory).
654-
655-
This method searches both working memory (ephemeral, session-scoped) and
656-
long-term memory (persistent, indexed) to provide comprehensive results.
657-
658-
For working memory:
659-
- Uses simple text matching
660-
- Searches across all sessions (unless session_id filter is provided)
661-
- Returns memories that haven't been promoted to long-term storage
662-
663-
For long-term memory:
664-
- Uses semantic vector search
665-
- Includes promoted memories from working memory
666-
- Supports advanced filtering by topics, entities, etc.
667-
668-
Args:
669-
text: Search query text for semantic similarity
670-
session_id: Optional session ID filter
671-
namespace: Optional namespace filter
672-
topics: Optional topics filter
673-
entities: Optional entities filter
674-
created_at: Optional creation date filter
675-
last_accessed: Optional last accessed date filter
676-
user_id: Optional user ID filter
677-
distance_threshold: Optional distance threshold for search results
678-
memory_type: Optional memory type filter
679-
limit: Maximum number of results to return (default: 10)
680-
offset: Offset for pagination (default: 0)
681-
682-
Returns:
683-
MemoryRecordResults with matching memories from both memory types
684-
685-
Raises:
686-
MemoryServerError: If the request fails
687-
688-
Example:
689-
```python
690-
# Search for user preferences with topic filtering
691-
from .filters import Topics
692-
693-
results = await client.search_memories(
694-
text="user prefers dark mode",
695-
topics=Topics(any=["preferences", "ui"]),
696-
limit=5
697-
)
698-
699-
for memory in results.memories:
700-
print(f"Found: {memory.text}")
701-
```
702-
"""
703-
# Convert dictionary filters to their proper filter objects if needed
704-
if isinstance(session_id, dict):
705-
session_id = SessionId(**session_id)
706-
if isinstance(namespace, dict):
707-
namespace = Namespace(**namespace)
708-
if isinstance(topics, dict):
709-
topics = Topics(**topics)
710-
if isinstance(entities, dict):
711-
entities = Entities(**entities)
712-
if isinstance(created_at, dict):
713-
created_at = CreatedAt(**created_at)
714-
if isinstance(last_accessed, dict):
715-
last_accessed = LastAccessed(**last_accessed)
716-
if isinstance(user_id, dict):
717-
user_id = UserId(**user_id)
718-
if isinstance(memory_type, dict):
719-
memory_type = MemoryType(**memory_type)
720-
721-
# Apply default namespace if needed and no namespace filter specified
722-
if namespace is None and self.config.default_namespace is not None:
723-
namespace = Namespace(eq=self.config.default_namespace)
724-
725-
payload = {
726-
"text": text,
727-
"limit": limit,
728-
"offset": offset,
729-
}
730-
731-
# Add filters if provided
732-
if session_id:
733-
payload["session_id"] = session_id.model_dump(exclude_none=True)
734-
if namespace:
735-
payload["namespace"] = namespace.model_dump(exclude_none=True)
736-
if topics:
737-
payload["topics"] = topics.model_dump(exclude_none=True)
738-
if entities:
739-
payload["entities"] = entities.model_dump(exclude_none=True)
740-
if created_at:
741-
payload["created_at"] = created_at.model_dump(
742-
exclude_none=True, mode="json"
743-
)
744-
if last_accessed:
745-
payload["last_accessed"] = last_accessed.model_dump(
746-
exclude_none=True, mode="json"
747-
)
748-
if user_id:
749-
payload["user_id"] = user_id.model_dump(exclude_none=True)
621+
if isinstance(user_id, dict):
622+
payload["user_id"] = user_id
623+
else:
624+
payload["user_id"] = user_id.model_dump(exclude_none=True)
750625
if memory_type:
751626
payload["memory_type"] = memory_type.model_dump(exclude_none=True)
752627
if distance_threshold is not None:
@@ -1076,7 +951,7 @@ async def add_memory_tool(
1076951
# Create memory record
1077952
memory = ClientMemoryRecord(
1078953
text=text,
1079-
memory_type=memory_type,
954+
memory_type=MemoryTypeEnum(memory_type),
1080955
topics=topics,
1081956
entities=entities,
1082957
namespace=namespace or self.config.default_namespace,
@@ -1111,7 +986,7 @@ async def update_memory_data_tool(
1111986
self,
1112987
session_id: str,
1113988
data: dict[str, Any],
1114-
merge_strategy: str = "merge",
989+
merge_strategy: Literal["replace", "merge", "deep_merge"] = "merge",
1115990
namespace: str | None = None,
1116991
user_id: str | None = None,
1117992
) -> dict[str, Any]:
@@ -1997,70 +1872,6 @@ async def search_all_long_term_memories(
19971872

19981873
offset += batch_size
19991874

2000-
async def search_all_memories(
2001-
self,
2002-
text: str,
2003-
session_id: SessionId | dict[str, Any] | None = None,
2004-
namespace: Namespace | dict[str, Any] | None = None,
2005-
topics: Topics | dict[str, Any] | None = None,
2006-
entities: Entities | dict[str, Any] | None = None,
2007-
created_at: CreatedAt | dict[str, Any] | None = None,
2008-
last_accessed: LastAccessed | dict[str, Any] | None = None,
2009-
user_id: UserId | dict[str, Any] | None = None,
2010-
distance_threshold: float | None = None,
2011-
memory_type: MemoryType | dict[str, Any] | None = None,
2012-
batch_size: int = 50,
2013-
) -> AsyncIterator[MemoryRecord]:
2014-
"""
2015-
Auto-paginating version of unified memory search.
2016-
2017-
Searches both working memory and long-term memory with automatic pagination.
2018-
2019-
Args:
2020-
text: Search query text
2021-
session_id: Optional session ID filter
2022-
namespace: Optional namespace filter
2023-
topics: Optional topics filter
2024-
entities: Optional entities filter
2025-
created_at: Optional creation date filter
2026-
last_accessed: Optional last accessed date filter
2027-
user_id: Optional user ID filter
2028-
distance_threshold: Optional distance threshold
2029-
memory_type: Optional memory type filter
2030-
batch_size: Number of results to fetch per API call
2031-
2032-
Yields:
2033-
Individual memory records from all result pages
2034-
"""
2035-
offset = 0
2036-
while True:
2037-
results = await self.search_memories(
2038-
text=text,
2039-
session_id=session_id,
2040-
namespace=namespace,
2041-
topics=topics,
2042-
entities=entities,
2043-
created_at=created_at,
2044-
last_accessed=last_accessed,
2045-
user_id=user_id,
2046-
distance_threshold=distance_threshold,
2047-
memory_type=memory_type,
2048-
limit=batch_size,
2049-
offset=offset,
2050-
)
2051-
2052-
if not results.memories:
2053-
break
2054-
2055-
for memory in results.memories:
2056-
yield memory
2057-
2058-
# If we got fewer results than batch_size, we've reached the end
2059-
if len(results.memories) < batch_size:
2060-
break
2061-
2062-
offset += batch_size
2063-
20641875
def validate_memory_record(self, memory: ClientMemoryRecord | MemoryRecord) -> None:
20651876
"""
20661877
Validate memory record before sending to server.
@@ -2237,7 +2048,7 @@ async def append_messages_to_working_memory(
22372048
converted_existing_messages.append(msg)
22382049
else:
22392050
# Fallback for any other message type - convert to string content
2240-
converted_existing_messages.append(
2051+
converted_existing_messages.append( # type: ignore
22412052
{"role": "user", "content": str(msg)}
22422053
)
22432054

agent-memory-client/agent_memory_client/models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ class ClientMemoryRecord(MemoryRecord):
122122
"""A memory record with a client-provided ID"""
123123

124124
id: str = Field(
125-
default_factory=lambda: str(ulid.new()),
125+
default_factory=lambda: str(ulid.ULID()),
126126
description="Client-provided ID generated by the client (ULID)",
127127
)
128128

agent-memory-client/tests/test_basic.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,6 @@ def test_enhanced_methods():
8383

8484
# Test pagination
8585
assert hasattr(client, "search_all_long_term_memories")
86-
assert hasattr(client, "search_all_memories")
8786

8887
# Test enhanced convenience methods
8988
assert hasattr(client, "update_working_memory_data")

agent-memory-client/tests/test_client.py

Lines changed: 0 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -297,35 +297,6 @@ async def test_search_all_long_term_memories(self, enhanced_test_client):
297297
# Should have made 3 API calls
298298
assert mock_search.call_count == 3
299299

300-
@pytest.mark.asyncio
301-
async def test_search_all_memories(self, enhanced_test_client):
302-
"""Test auto-paginating unified memory search."""
303-
# Similar test for unified search
304-
response = MemoryRecordResults(
305-
total=25,
306-
memories=[
307-
MemoryRecordResult(
308-
id=f"memory-{i}",
309-
text=f"Memory text {i}",
310-
dist=0.1,
311-
)
312-
for i in range(25)
313-
],
314-
next_offset=None,
315-
)
316-
317-
with patch.object(enhanced_test_client, "search_memories") as mock_search:
318-
mock_search.return_value = response
319-
320-
all_memories = []
321-
async for memory in enhanced_test_client.search_all_memories(
322-
text="test query", batch_size=50
323-
):
324-
all_memories.append(memory)
325-
326-
assert len(all_memories) == 25
327-
assert mock_search.call_count == 1
328-
329300

330301
class TestClientSideValidation:
331302
"""Tests for client-side validation methods."""

agent_memory_server/api.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -339,7 +339,7 @@ async def put_working_memory(
339339

340340
memories = [
341341
MemoryRecord(
342-
id=str(ulid.new()),
342+
id=str(ulid.ULID()),
343343
session_id=session_id,
344344
text=f"{msg.role}: {msg.content}",
345345
namespace=updated_memory.namespace,

agent_memory_server/extraction.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -333,7 +333,7 @@ async def extract_discrete_memories(
333333
if discrete_memories:
334334
long_term_memories = [
335335
MemoryRecord(
336-
id_=str(ulid.new()),
336+
id_=str(ulid.ULID()),
337337
text=new_memory["text"],
338338
memory_type=new_memory.get("type", "episodic"),
339339
topics=new_memory.get("topics", []),

agent_memory_server/long_term_memory.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -244,7 +244,7 @@ async def merge_memories_with_llm(memories: list[dict], llm_client: Any = None)
244244
# Create the merged memory
245245
merged_memory = {
246246
"text": merged_text.strip(),
247-
"id_": str(ulid.new()),
247+
"id_": str(ulid.ULID()),
248248
"user_id": user_id,
249249
"session_id": session_id,
250250
"namespace": namespace,
@@ -664,7 +664,7 @@ async def index_long_term_memories(
664664
async with redis.pipeline(transaction=False) as pipe:
665665
for idx, vector in enumerate(embeddings):
666666
memory = processed_memories[idx]
667-
id_ = memory.id if memory.id else str(ulid.new())
667+
id_ = memory.id if memory.id else str(ulid.ULID())
668668
key = Keys.memory_key(id_, memory.namespace)
669669

670670
# Generate memory hash for the memory
@@ -1426,7 +1426,7 @@ async def deduplicate_by_semantic_search(
14261426

14271427
# Convert back to LongTermMemory
14281428
merged_memory_obj = MemoryRecord(
1429-
id=memory.id or str(ulid.new()),
1429+
id=memory.id or str(ulid.ULID()),
14301430
text=merged_memory["text"],
14311431
user_id=merged_memory["user_id"],
14321432
session_id=merged_memory["session_id"],
@@ -1646,7 +1646,7 @@ async def extract_memories_from_messages(
16461646

16471647
# Create a new memory record from the extraction
16481648
extracted_memory = MemoryRecord(
1649-
id=str(ulid.new()), # Server-generated ID
1649+
id=str(ulid.ULID()), # Server-generated ID
16501650
text=memory_data["text"],
16511651
memory_type=memory_data.get("type", "semantic"),
16521652
topics=memory_data.get("topics", []),

agent_memory_server/mcp.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -690,7 +690,7 @@ async def set_working_memory(
690690
# Handle both MemoryRecord objects and dict inputs
691691
if isinstance(memory, MemoryRecord):
692692
# Already a MemoryRecord object, ensure it has an ID
693-
memory_id = memory.id or str(ulid.new())
693+
memory_id = memory.id or str(ulid.ULID())
694694
processed_memory = memory.model_copy(
695695
update={
696696
"id": memory_id,
@@ -701,7 +701,7 @@ async def set_working_memory(
701701
# Dictionary input, convert to MemoryRecord
702702
memory_dict = dict(memory)
703703
if not memory_dict.get("id"):
704-
memory_dict["id"] = str(ulid.new())
704+
memory_dict["id"] = str(ulid.ULID())
705705
memory_dict["persisted_at"] = None
706706
processed_memory = MemoryRecord(**memory_dict)
707707

agent_memory_server/migrations.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ async def migrate_add_discrete_memory_extracted_2(redis: Redis | None = None) ->
9898
id_ = await redis.hget(name=key, key="id_") # type: ignore
9999
if not id_:
100100
logger.info("Updating memory with no ID to set ID")
101-
await redis.hset(name=key, key="id_", value=str(ulid.new())) # type: ignore
101+
await redis.hset(name=key, key="id_", value=str(ulid.ULID())) # type: ignore
102102
# extracted: bytes | None = await redis.hget(
103103
# name=key, key="discrete_memory_extracted"
104104
# ) # type: ignore
@@ -126,7 +126,7 @@ async def migrate_add_memory_type_3(redis: Redis | None = None) -> None:
126126
id_ = await redis.hget(name=key, key="id_") # type: ignore
127127
if not id_:
128128
logger.info("Updating memory with no ID to set ID")
129-
await redis.hset(name=key, key="id_", value=str(ulid.new())) # type: ignore
129+
await redis.hset(name=key, key="id_", value=str(ulid.ULID())) # type: ignore
130130
memory_type: bytes | None = await redis.hget(name=key, key="memory_type") # type: ignore
131131
if not memory_type:
132132
await redis.hset(name=key, key="memory_type", value="message") # type: ignore

0 commit comments

Comments
 (0)