Skip to content

Commit 43ad96b

Browse files
committed
defensive lock
1 parent a7b0585 commit 43ad96b

File tree

3 files changed

+44
-19
lines changed

3 files changed

+44
-19
lines changed

agent-memory-client/agent_memory_client/models.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
"""
77

88
import logging
9+
import threading
910
from datetime import datetime, timedelta, timezone
1011
from enum import Enum
1112
from typing import Any, ClassVar, Literal
@@ -68,6 +69,7 @@ class MemoryMessage(BaseModel):
6869
# Track message IDs that have been warned (in-memory, per-process)
6970
# Used to rate-limit deprecation warnings
7071
_warned_message_ids: ClassVar[set[str]] = set()
72+
_warned_message_ids_lock: ClassVar[threading.Lock] = threading.Lock()
7173
_max_warned_ids: ClassVar[int] = 10000 # Prevent unbounded growth
7274

7375
# Default tolerance for future timestamp validation (5 minutes)
@@ -106,15 +108,20 @@ def validate_created_at(cls, data: Any) -> Any:
106108
created_at_provided = "created_at" in data and data["created_at"] is not None
107109

108110
if not created_at_provided:
109-
# Rate-limit warnings by message ID
111+
# Rate-limit warnings by message ID (thread-safe)
110112
msg_id = data.get("id", "unknown")
111113

112-
if msg_id not in cls._warned_message_ids:
113-
# Prevent unbounded memory growth
114-
if len(cls._warned_message_ids) >= cls._max_warned_ids:
115-
cls._warned_message_ids.clear()
116-
cls._warned_message_ids.add(msg_id)
117-
114+
with cls._warned_message_ids_lock:
115+
if msg_id not in cls._warned_message_ids:
116+
# Prevent unbounded memory growth
117+
if len(cls._warned_message_ids) >= cls._max_warned_ids:
118+
cls._warned_message_ids.clear()
119+
cls._warned_message_ids.add(msg_id)
120+
should_warn = True
121+
else:
122+
should_warn = False
123+
124+
if should_warn:
118125
logger.warning(
119126
"MemoryMessage created without explicit created_at timestamp. "
120127
"This will become required in a future version. "

agent_memory_server/models.py

Lines changed: 28 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import logging
2+
import threading
23
from collections.abc import Callable
34
from datetime import UTC, datetime, timedelta
45
from enum import Enum
@@ -85,8 +86,12 @@ class MemoryMessage(BaseModel):
8586
# Track message IDs that have been warned (in-memory, per-worker)
8687
# Used to rate-limit deprecation warnings
8788
_warned_message_ids: ClassVar[set[str]] = set()
89+
_warned_message_ids_lock: ClassVar[threading.Lock] = threading.Lock()
8890
_max_warned_ids: ClassVar[int] = 10000 # Prevent unbounded growth
8991

92+
# Thread-local storage for passing created_at_provided flag from validator to model_post_init
93+
_created_at_provided_thread_local: ClassVar[threading.local] = threading.local()
94+
9095
role: str
9196
content: str
9297
id: str = Field(
@@ -109,11 +114,15 @@ class MemoryMessage(BaseModel):
109114
# Used for deprecation header in API responses
110115
_created_at_was_provided: bool = PrivateAttr(default=False)
111116

112-
def __init__(self, **data):
113-
# Check if created_at was provided before calling super().__init__
114-
created_at_provided = "created_at" in data and data["created_at"] is not None
115-
super().__init__(**data)
116-
self._created_at_was_provided = created_at_provided
117+
def model_post_init(self, __context: Any) -> None:
118+
"""Set _created_at_was_provided from thread-local storage after model is constructed."""
119+
# Retrieve the flag from thread-local storage (set by validator)
120+
self._created_at_was_provided = getattr(
121+
self._created_at_provided_thread_local, "value", False
122+
)
123+
# Clean up thread-local storage
124+
if hasattr(self._created_at_provided_thread_local, "value"):
125+
del self._created_at_provided_thread_local.value
117126

118127
@model_validator(mode="before")
119128
@classmethod
@@ -130,22 +139,30 @@ def validate_created_at(cls, data: Any) -> Any:
130139

131140
created_at_provided = "created_at" in data and data["created_at"] is not None
132141

142+
# Store in thread-local for model_post_init to pick up
143+
cls._created_at_provided_thread_local.value = created_at_provided
144+
133145
if not created_at_provided:
134146
# Handle missing created_at
135147
if settings.require_message_timestamps:
136148
raise ValueError(
137149
"created_at is required for messages. "
138150
"Please provide the timestamp when the message was created."
139151
)
140-
# Rate-limit warnings by message ID
152+
# Rate-limit warnings by message ID (thread-safe)
141153
msg_id = data.get("id", "unknown")
142154

143-
if msg_id not in cls._warned_message_ids:
144-
# Prevent unbounded memory growth
145-
if len(cls._warned_message_ids) >= cls._max_warned_ids:
146-
cls._warned_message_ids.clear()
147-
cls._warned_message_ids.add(msg_id)
155+
with cls._warned_message_ids_lock:
156+
if msg_id not in cls._warned_message_ids:
157+
# Prevent unbounded memory growth
158+
if len(cls._warned_message_ids) >= cls._max_warned_ids:
159+
cls._warned_message_ids.clear()
160+
cls._warned_message_ids.add(msg_id)
161+
should_warn = True
162+
else:
163+
should_warn = False
148164

165+
if should_warn:
149166
logger.warning(
150167
"MemoryMessage created without explicit created_at timestamp. "
151168
"This will become required in a future version. "

tests/test_models.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -256,7 +256,8 @@ def test_message_with_future_timestamp_rejected(self):
256256
created_at=future_time,
257257
)
258258

259-
assert "cannot be in the future" in str(exc_info.value)
259+
assert "cannot be more than" in str(exc_info.value)
260+
assert "seconds in the future" in str(exc_info.value)
260261

261262
def test_message_with_near_future_timestamp_allowed(self):
262263
"""Test that timestamps within tolerance are allowed"""

0 commit comments

Comments
 (0)