Skip to content

Commit dbd2325

Browse files
committed
add contextVar
1 parent 8629fc3 commit dbd2325

File tree

1 file changed

+13
-12
lines changed

1 file changed

+13
-12
lines changed

agent_memory_server/models.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import logging
22
import threading
33
from collections.abc import Callable
4+
from contextvars import ContextVar
45
from datetime import UTC, datetime, timedelta
56
from enum import Enum
67
from typing import Any, ClassVar, Literal
@@ -89,8 +90,11 @@ class MemoryMessage(BaseModel):
8990
_warned_message_ids_lock: ClassVar[threading.Lock] = threading.Lock()
9091
_max_warned_ids: ClassVar[int] = 10000 # Prevent unbounded growth
9192

92-
# Thread-local storage for passing created_at_provided flag from validator to model_post_init
93-
_created_at_provided_thread_local: ClassVar[threading.local] = threading.local()
93+
# ContextVar for passing created_at_provided flag from validator to model_post_init
94+
# ContextVar is async-safe (works correctly with coroutines on the same thread)
95+
_created_at_provided_context: ClassVar[ContextVar[bool]] = ContextVar(
96+
"created_at_provided", default=False
97+
)
9498

9599
role: str
96100
content: str
@@ -115,14 +119,11 @@ class MemoryMessage(BaseModel):
115119
_created_at_was_provided: bool = PrivateAttr(default=False)
116120

117121
def model_post_init(self, __context: Any) -> None:
118-
"""Set _created_at_was_provided from thread-local storage after model is constructed."""
119-
# Retrieve the flag from thread-local storage (set by validator)
120-
self._created_at_was_provided = getattr(
121-
self._created_at_provided_thread_local, "value", False
122-
)
123-
# Clean up thread-local storage
124-
if hasattr(self._created_at_provided_thread_local, "value"):
125-
del self._created_at_provided_thread_local.value
122+
"""Set _created_at_was_provided from ContextVar after model is constructed."""
123+
# Retrieve the flag from ContextVar (set by validator)
124+
self._created_at_was_provided = self._created_at_provided_context.get()
125+
# Reset ContextVar to default for next use
126+
self._created_at_provided_context.set(False)
126127

127128
@model_validator(mode="before")
128129
@classmethod
@@ -139,8 +140,8 @@ def validate_created_at(cls, data: Any) -> Any:
139140

140141
created_at_provided = "created_at" in data and data["created_at"] is not None
141142

142-
# Store in thread-local for model_post_init to pick up
143-
cls._created_at_provided_thread_local.value = created_at_provided
143+
# Store in ContextVar for model_post_init to pick up (async-safe)
144+
cls._created_at_provided_context.set(created_at_provided)
144145

145146
if not created_at_provided:
146147
# Handle missing created_at

0 commit comments

Comments
 (0)