11import copy
22import threading
3-
43from contextlib import contextmanager
5- from contextvars import ContextVar
64from dsp .utils .utils import dotdict
75
86DEFAULT_CONFIG = dotdict (
3129# Global base configuration
3230main_thread_config = copy .deepcopy (DEFAULT_CONFIG )
3331
34- # Initialize the context variable with an empty dict as default
35- dspy_ctx_overrides = ContextVar ('dspy_ctx_overrides' , default = dotdict ())
32+
33+ class ThreadLocalOverrides (threading .local ):
34+ def __init__ (self ):
35+ self .overrides = dotdict () # Initialize thread-local overrides
36+
37+
38+ # Create the thread-local storage
39+ thread_local_overrides = ThreadLocalOverrides ()
3640
3741
3842class Settings :
@@ -53,7 +57,7 @@ def __new__(cls):
5357 return cls ._instance
5458
5559 def __getattr__ (self , name ):
56- overrides = dspy_ctx_overrides . get ( )
60+ overrides = getattr ( thread_local_overrides , 'overrides' , dotdict () )
5761 if name in overrides :
5862 return overrides [name ]
5963 elif name in main_thread_config :
@@ -76,7 +80,7 @@ def __setitem__(self, key, value):
7680 self .__setattr__ (key , value )
7781
7882 def __contains__ (self , key ):
79- overrides = dspy_ctx_overrides . get ( )
83+ overrides = getattr ( thread_local_overrides , 'overrides' , dotdict () )
8084 return key in overrides or key in main_thread_config
8185
8286 def get (self , key , default = None ):
@@ -86,45 +90,49 @@ def get(self, key, default=None):
8690 return default
8791
8892 def copy (self ):
89- overrides = dspy_ctx_overrides . get ( )
93+ overrides = getattr ( thread_local_overrides , 'overrides' , dotdict () )
9094 return dotdict ({** main_thread_config , ** overrides })
9195
9296 @property
9397 def config (self ):
9498 config = self .copy ()
95- del config ['lock' ]
99+ if 'lock' in config :
100+ del config ['lock' ]
96101 return config
97102
98103 # Configuration methods
99104
100- def configure (self , return_token = False , ** kwargs ):
105+ def configure (self , ** kwargs ):
101106 global main_thread_config
102- overrides = dspy_ctx_overrides .get ()
103- new_overrides = dotdict ({** copy .deepcopy (DEFAULT_CONFIG ), ** main_thread_config , ** overrides , ** kwargs })
104- token = dspy_ctx_overrides .set (new_overrides )
107+
108+ # Get or initialize thread-local overrides
109+ overrides = getattr (thread_local_overrides , 'overrides' , dotdict ())
110+ thread_local_overrides .overrides = dotdict (
111+ {** copy .deepcopy (DEFAULT_CONFIG ), ** main_thread_config , ** overrides , ** kwargs }
112+ )
105113
106114 # Update main_thread_config, in the main thread only
107115 if threading .current_thread () is threading .main_thread ():
108- main_thread_config = new_overrides
109-
110- if return_token :
111- return token
116+ main_thread_config = thread_local_overrides .overrides
112117
113118 @contextmanager
114119 def context (self , ** kwargs ):
115120 """Context manager for temporary configuration changes."""
116- token = self .configure (return_token = True , ** kwargs )
121+ global main_thread_config
122+ original_overrides = getattr (thread_local_overrides , 'overrides' , dotdict ()).copy ()
123+ original_main_thread_config = main_thread_config .copy ()
124+
125+ self .configure (** kwargs )
117126 try :
118127 yield
119128 finally :
120- dspy_ctx_overrides . reset ( token )
129+ thread_local_overrides . overrides = original_overrides
121130
122131 if threading .current_thread () is threading .main_thread ():
123- global main_thread_config
124- main_thread_config = dotdict ({** copy .deepcopy (DEFAULT_CONFIG ), ** dspy_ctx_overrides .get ()})
132+ main_thread_config = original_main_thread_config
125133
126134 def __repr__ (self ):
127- overrides = dspy_ctx_overrides . get ( )
135+ overrides = getattr ( thread_local_overrides , 'overrides' , dotdict () )
128136 combined_config = {** main_thread_config , ** overrides }
129137 return repr (combined_config )
130138
0 commit comments