11import logging
22import threading
33from collections .abc import Callable
4+ from contextvars import ContextVar
45from datetime import UTC , datetime , timedelta
56from enum import Enum
67from typing import Any , ClassVar , Literal
@@ -89,8 +90,11 @@ class MemoryMessage(BaseModel):
8990 _warned_message_ids_lock : ClassVar [threading .Lock ] = threading .Lock ()
9091 _max_warned_ids : ClassVar [int ] = 10000 # Prevent unbounded growth
9192
92- # Thread-local storage for passing created_at_provided flag from validator to model_post_init
93- _created_at_provided_thread_local : ClassVar [threading .local ] = threading .local ()
93+ # ContextVar for passing created_at_provided flag from validator to model_post_init
94+ # ContextVar is async-safe (works correctly with coroutines on the same thread)
95+ _created_at_provided_context : ClassVar [ContextVar [bool ]] = ContextVar (
96+ "created_at_provided" , default = False
97+ )
9498
9599 role : str
96100 content : str
@@ -115,14 +119,11 @@ class MemoryMessage(BaseModel):
115119 _created_at_was_provided : bool = PrivateAttr (default = False )
116120
117121 def model_post_init (self , __context : Any ) -> None :
118- """Set _created_at_was_provided from thread-local storage after model is constructed."""
119- # Retrieve the flag from thread-local storage (set by validator)
120- self ._created_at_was_provided = getattr (
121- self ._created_at_provided_thread_local , "value" , False
122- )
123- # Clean up thread-local storage
124- if hasattr (self ._created_at_provided_thread_local , "value" ):
125- del self ._created_at_provided_thread_local .value
122+ """Set _created_at_was_provided from ContextVar after model is constructed."""
123+ # Retrieve the flag from ContextVar (set by validator)
124+ self ._created_at_was_provided = self ._created_at_provided_context .get ()
125+ # Reset ContextVar to default for next use
126+ self ._created_at_provided_context .set (False )
126127
127128 @model_validator (mode = "before" )
128129 @classmethod
@@ -139,8 +140,8 @@ def validate_created_at(cls, data: Any) -> Any:
139140
140141 created_at_provided = "created_at" in data and data ["created_at" ] is not None
141142
142- # Store in thread-local for model_post_init to pick up
143- cls ._created_at_provided_thread_local . value = created_at_provided
143+ # Store in ContextVar for model_post_init to pick up (async-safe)
144+ cls ._created_at_provided_context . set ( created_at_provided )
144145
145146 if not created_at_provided :
146147 # Handle missing created_at
0 commit comments