Skip to content

Commit 7dcb08a

Browse files
committed
Fix mypy errors
1 parent d15f1dc commit 7dcb08a

File tree

4 files changed

+37
-47
lines changed

4 files changed

+37
-47
lines changed

agent-memory-client/agent_memory_client/client.py

Lines changed: 29 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,10 @@
88
import contextlib
99
import re
1010
from collections.abc import AsyncIterator
11-
from datetime import datetime
12-
from typing import Any, Literal
11+
from typing import TYPE_CHECKING, Any, Literal
12+
13+
if TYPE_CHECKING:
14+
from typing_extensions import Self
1315

1416
import httpx
1517
import ulid
@@ -71,15 +73,15 @@ def __init__(self, config: MemoryClientConfig):
7173
timeout=config.timeout,
7274
)
7375

74-
async def close(self):
76+
async def close(self) -> None:
7577
"""Close the underlying HTTP client."""
7678
await self._client.aclose()
7779

78-
async def __aenter__(self):
80+
async def __aenter__(self) -> "Self":
7981
"""Support using the client as an async context manager."""
8082
return self
8183

82-
async def __aexit__(self, exc_type, exc_val, exc_tb):
84+
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
8385
"""Close the client when exiting the context manager."""
8486
await self.close()
8587

@@ -176,13 +178,13 @@ async def get_session_memory(
176178
params["namespace"] = self.config.default_namespace
177179

178180
if window_size is not None:
179-
params["window_size"] = window_size
181+
params["window_size"] = str(window_size)
180182

181183
if model_name is not None:
182184
params["model_name"] = model_name
183185

184186
if context_window_max is not None:
185-
params["context_window_max"] = context_window_max
187+
params["context_window_max"] = str(context_window_max)
186188

187189
try:
188190
response = await self._client.get(
@@ -861,31 +863,11 @@ def validate_memory_record(self, memory: ClientMemoryRecord | MemoryRecord) -> N
861863
if memory.id and not self._is_valid_ulid(memory.id):
862864
raise MemoryValidationError(f"Invalid ID format: {memory.id}")
863865

864-
if (
865-
hasattr(memory, "created_at")
866-
and memory.created_at
867-
and not isinstance(memory.created_at, datetime)
868-
):
869-
try:
870-
datetime.fromisoformat(str(memory.created_at))
871-
except ValueError as e:
872-
raise MemoryValidationError(
873-
f"Invalid created_at format: {memory.created_at}"
874-
) from e
875-
876-
if (
877-
hasattr(memory, "last_accessed")
878-
and memory.last_accessed
879-
and not isinstance(memory.last_accessed, datetime)
880-
):
881-
try:
882-
datetime.fromisoformat(str(memory.last_accessed))
883-
except ValueError as e:
884-
raise MemoryValidationError(
885-
f"Invalid last_accessed format: {memory.last_accessed}"
886-
) from e
866+
# created_at is validated by Pydantic
887867

888-
def validate_search_filters(self, **filters) -> None:
868+
# last_accessed is validated by Pydantic
869+
870+
def validate_search_filters(self, **filters: Any) -> None:
889871
"""Validate search filter parameters before API call."""
890872
valid_filter_keys = {
891873
"session_id",
@@ -1022,7 +1004,10 @@ async def append_messages_to_working_memory(
10221004
{"role": msg.role, "content": msg.content}
10231005
)
10241006
else:
1025-
converted_existing_messages.append(msg)
1007+
# Fallback for any other message type
1008+
converted_existing_messages.append(
1009+
{"role": "user", "content": str(msg)}
1010+
)
10261011

10271012
# Convert new messages to dict format if they're objects
10281013
new_messages = []
@@ -1074,21 +1059,21 @@ async def memory_prompt(
10741059
Returns:
10751060
Dict with messages hydrated with relevant memory context
10761061
"""
1077-
payload = {"query": query}
1062+
payload: dict[str, Any] = {"query": query}
10781063

10791064
# Add session parameters if provided
10801065
if session_id is not None:
1081-
session_params = {"session_id": session_id}
1066+
session_params: dict[str, Any] = {"session_id": session_id}
10821067
if namespace is not None:
10831068
session_params["namespace"] = namespace
10841069
elif self.config.default_namespace is not None:
10851070
session_params["namespace"] = self.config.default_namespace
10861071
if window_size is not None:
1087-
session_params["window_size"] = window_size
1072+
session_params["window_size"] = str(window_size)
10881073
if model_name is not None:
10891074
session_params["model_name"] = model_name
10901075
if context_window_max is not None:
1091-
session_params["context_window_max"] = context_window_max
1076+
session_params["context_window_max"] = str(context_window_max)
10921077
payload["session"] = session_params
10931078

10941079
# Add long-term search parameters if provided
@@ -1101,7 +1086,10 @@ async def memory_prompt(
11011086
json=payload,
11021087
)
11031088
response.raise_for_status()
1104-
return response.json()
1089+
result = response.json()
1090+
if isinstance(result, dict):
1091+
return result
1092+
return {"response": result}
11051093
except httpx.HTTPStatusError as e:
11061094
self._handle_http_error(e.response)
11071095
raise
@@ -1143,7 +1131,7 @@ async def hydrate_memory_prompt(
11431131
Dict with messages hydrated with relevant long-term memories
11441132
"""
11451133
# Build long-term search parameters
1146-
long_term_search = {"limit": limit}
1134+
long_term_search: dict[str, Any] = {"limit": limit}
11471135

11481136
if session_id is not None:
11491137
long_term_search["session_id"] = session_id
@@ -1171,7 +1159,9 @@ async def hydrate_memory_prompt(
11711159
long_term_search=long_term_search,
11721160
)
11731161

1174-
def _deep_merge_dicts(self, base: dict, updates: dict) -> dict:
1162+
def _deep_merge_dicts(
1163+
self, base: dict[str, Any], updates: dict[str, Any]
1164+
) -> dict[str, Any]:
11751165
"""Recursively merge two dictionaries."""
11761166
result = base.copy()
11771167
for key, value in updates.items():

agent-memory-client/agent_memory_client/models.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
For full model definitions, see the main agent_memory_server package.
66
"""
77

8-
from datetime import UTC, datetime
8+
from datetime import datetime, timezone
99
from enum import Enum
1010
from typing import Any, Literal
1111

@@ -73,16 +73,16 @@ class MemoryRecord(BaseModel):
7373
description="Optional namespace for the memory record",
7474
)
7575
last_accessed: datetime = Field(
76-
default_factory=lambda: datetime.now(UTC),
76+
default_factory=lambda: datetime.now(timezone.utc),
7777
description="Datetime when the memory was last accessed",
7878
)
7979
created_at: datetime = Field(
80-
default_factory=lambda: datetime.now(UTC),
80+
default_factory=lambda: datetime.now(timezone.utc),
8181
description="Datetime when the memory was created",
8282
)
8383
updated_at: datetime = Field(
8484
description="Datetime when the memory was last updated",
85-
default_factory=lambda: datetime.now(UTC),
85+
default_factory=lambda: datetime.now(timezone.utc),
8686
)
8787
topics: list[str] | None = Field(
8888
default=None,
@@ -127,7 +127,7 @@ class ClientMemoryRecord(MemoryRecord):
127127
)
128128

129129

130-
JSONTypes = str | float | int | bool | list | dict
130+
JSONTypes = str | float | int | bool | list[Any] | dict[str, Any]
131131

132132

133133
class WorkingMemory(BaseModel):
@@ -176,7 +176,7 @@ class WorkingMemory(BaseModel):
176176
description="TTL for the working memory in seconds",
177177
)
178178
last_accessed: datetime = Field(
179-
default_factory=lambda: datetime.now(UTC),
179+
default_factory=lambda: datetime.now(timezone.utc),
180180
description="Datetime when the working memory was last accessed",
181181
)
182182

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ license = { text = "MIT" }
1212
authors = [{ name = "Andrew Brookins", email = "[email protected]" }]
1313
dependencies = [
1414
"accelerate>=1.6.0",
15-
"agent-memory-client @ git+https://github.com/username/agent-memory-client@main",
15+
"agent-memory-client",
1616
"anthropic>=0.15.0",
1717
"bertopic<0.17.0,>=0.16.4",
1818
"fastapi>=0.115.11",
@@ -132,6 +132,7 @@ dev = [
132132
"testcontainers>=3.7.0",
133133
"pre-commit>=3.6.0",
134134
"freezegun>=1.2.0",
135+
"-e ./agent-memory-client",
135136
]
136137

137138
[tool.ruff.lint.per-file-ignores]

uv.lock

Lines changed: 0 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)