11import logging
2+ import threading
23from collections .abc import Callable
34from datetime import UTC , datetime , timedelta
45from enum import Enum
@@ -85,8 +86,12 @@ class MemoryMessage(BaseModel):
8586 # Track message IDs that have been warned (in-memory, per-worker)
8687 # Used to rate-limit deprecation warnings
8788 _warned_message_ids : ClassVar [set [str ]] = set ()
89+ _warned_message_ids_lock : ClassVar [threading .Lock ] = threading .Lock ()
8890 _max_warned_ids : ClassVar [int ] = 10000 # Prevent unbounded growth
8991
92+ # Thread-local storage for passing created_at_provided flag from validator to model_post_init
93+ _created_at_provided_thread_local : ClassVar [threading .local ] = threading .local ()
94+
9095 role : str
9196 content : str
9297 id : str = Field (
@@ -109,11 +114,15 @@ class MemoryMessage(BaseModel):
109114 # Used for deprecation header in API responses
110115 _created_at_was_provided : bool = PrivateAttr (default = False )
111116
112- def __init__ (self , ** data ):
113- # Check if created_at was provided before calling super().__init__
114- created_at_provided = "created_at" in data and data ["created_at" ] is not None
115- super ().__init__ (** data )
116- self ._created_at_was_provided = created_at_provided
117+ def model_post_init (self , __context : Any ) -> None :
118+ """Set _created_at_was_provided from thread-local storage after model is constructed."""
119+ # Retrieve the flag from thread-local storage (set by validator)
120+ self ._created_at_was_provided = getattr (
121+ self ._created_at_provided_thread_local , "value" , False
122+ )
123+ # Clean up thread-local storage
124+ if hasattr (self ._created_at_provided_thread_local , "value" ):
125+ del self ._created_at_provided_thread_local .value
117126
118127 @model_validator (mode = "before" )
119128 @classmethod
@@ -130,22 +139,30 @@ def validate_created_at(cls, data: Any) -> Any:
130139
131140 created_at_provided = "created_at" in data and data ["created_at" ] is not None
132141
142+ # Store in thread-local for model_post_init to pick up
143+ cls ._created_at_provided_thread_local .value = created_at_provided
144+
133145 if not created_at_provided :
134146 # Handle missing created_at
135147 if settings .require_message_timestamps :
136148 raise ValueError (
137149 "created_at is required for messages. "
138150 "Please provide the timestamp when the message was created."
139151 )
140- # Rate-limit warnings by message ID
152+ # Rate-limit warnings by message ID (thread-safe)
141153 msg_id = data .get ("id" , "unknown" )
142154
143- if msg_id not in cls ._warned_message_ids :
144- # Prevent unbounded memory growth
145- if len (cls ._warned_message_ids ) >= cls ._max_warned_ids :
146- cls ._warned_message_ids .clear ()
147- cls ._warned_message_ids .add (msg_id )
155+ with cls ._warned_message_ids_lock :
156+ if msg_id not in cls ._warned_message_ids :
157+ # Prevent unbounded memory growth
158+ if len (cls ._warned_message_ids ) >= cls ._max_warned_ids :
159+ cls ._warned_message_ids .clear ()
160+ cls ._warned_message_ids .add (msg_id )
161+ should_warn = True
162+ else :
163+ should_warn = False
148164
165+ if should_warn :
149166 logger .warning (
150167 "MemoryMessage created without explicit created_at timestamp. "
151168 "This will become required in a future version. "
0 commit comments