Skip to content

Commit 74b19c8

Browse files
authored
Switch settings from contextvar to thread local storage (for Colab) (#1860)
1 parent c4f1f95 commit 74b19c8

File tree

2 files changed

+56
-33
lines changed

2 files changed

+56
-33
lines changed

dsp/utils/settings.py

Lines changed: 29 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
import copy
22
import threading
3-
43
from contextlib import contextmanager
5-
from contextvars import ContextVar
64
from dsp.utils.utils import dotdict
75

86
DEFAULT_CONFIG = dotdict(
@@ -31,8 +29,14 @@
3129
# Global base configuration
3230
main_thread_config = copy.deepcopy(DEFAULT_CONFIG)
3331

34-
# Initialize the context variable with an empty dict as default
35-
dspy_ctx_overrides = ContextVar('dspy_ctx_overrides', default=dotdict())
32+
33+
class ThreadLocalOverrides(threading.local):
34+
def __init__(self):
35+
self.overrides = dotdict() # Initialize thread-local overrides
36+
37+
38+
# Create the thread-local storage
39+
thread_local_overrides = ThreadLocalOverrides()
3640

3741

3842
class Settings:
@@ -53,7 +57,7 @@ def __new__(cls):
5357
return cls._instance
5458

5559
def __getattr__(self, name):
56-
overrides = dspy_ctx_overrides.get()
60+
overrides = getattr(thread_local_overrides, 'overrides', dotdict())
5761
if name in overrides:
5862
return overrides[name]
5963
elif name in main_thread_config:
@@ -76,7 +80,7 @@ def __setitem__(self, key, value):
7680
self.__setattr__(key, value)
7781

7882
def __contains__(self, key):
79-
overrides = dspy_ctx_overrides.get()
83+
overrides = getattr(thread_local_overrides, 'overrides', dotdict())
8084
return key in overrides or key in main_thread_config
8185

8286
def get(self, key, default=None):
@@ -86,45 +90,49 @@ def get(self, key, default=None):
8690
return default
8791

8892
def copy(self):
89-
overrides = dspy_ctx_overrides.get()
93+
overrides = getattr(thread_local_overrides, 'overrides', dotdict())
9094
return dotdict({**main_thread_config, **overrides})
9195

9296
@property
9397
def config(self):
9498
config = self.copy()
95-
del config['lock']
99+
if 'lock' in config:
100+
del config['lock']
96101
return config
97102

98103
# Configuration methods
99104

100-
def configure(self, return_token=False, **kwargs):
105+
def configure(self, **kwargs):
101106
global main_thread_config
102-
overrides = dspy_ctx_overrides.get()
103-
new_overrides = dotdict({**copy.deepcopy(DEFAULT_CONFIG), **main_thread_config, **overrides, **kwargs})
104-
token = dspy_ctx_overrides.set(new_overrides)
107+
108+
# Get or initialize thread-local overrides
109+
overrides = getattr(thread_local_overrides, 'overrides', dotdict())
110+
thread_local_overrides.overrides = dotdict(
111+
{**copy.deepcopy(DEFAULT_CONFIG), **main_thread_config, **overrides, **kwargs}
112+
)
105113

106114
# Update main_thread_config, in the main thread only
107115
if threading.current_thread() is threading.main_thread():
108-
main_thread_config = new_overrides
109-
110-
if return_token:
111-
return token
116+
main_thread_config = thread_local_overrides.overrides
112117

113118
@contextmanager
114119
def context(self, **kwargs):
115120
"""Context manager for temporary configuration changes."""
116-
token = self.configure(return_token=True, **kwargs)
121+
global main_thread_config
122+
original_overrides = getattr(thread_local_overrides, 'overrides', dotdict()).copy()
123+
original_main_thread_config = main_thread_config.copy()
124+
125+
self.configure(**kwargs)
117126
try:
118127
yield
119128
finally:
120-
dspy_ctx_overrides.reset(token)
129+
thread_local_overrides.overrides = original_overrides
121130

122131
if threading.current_thread() is threading.main_thread():
123-
global main_thread_config
124-
main_thread_config = dotdict({**copy.deepcopy(DEFAULT_CONFIG), **dspy_ctx_overrides.get()})
132+
main_thread_config = original_main_thread_config
125133

126134
def __repr__(self):
127-
overrides = dspy_ctx_overrides.get()
135+
overrides = getattr(thread_local_overrides, 'overrides', dotdict())
128136
combined_config = {**main_thread_config, **overrides}
129137
return repr(combined_config)
130138

dspy/utils/parallelizer.py

Lines changed: 27 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,11 @@
55
import threading
66
import traceback
77
import contextlib
8-
9-
from contextvars import copy_context
108
from tqdm.contrib.logging import logging_redirect_tqdm
119
from concurrent.futures import ThreadPoolExecutor, as_completed
1210

1311
logger = logging.getLogger(__name__)
1412

15-
1613
class ParallelExecutor:
1714
def __init__(
1815
self,
@@ -80,10 +77,16 @@ def _execute_isolated_single_thread(self, function, data):
8077
if self.cancel_jobs.is_set():
8178
break
8279

83-
# Create an isolated context for each task
84-
task_ctx = copy_context()
85-
result = task_ctx.run(function, item)
86-
results.append(result)
80+
# Create an isolated context for each task using thread-local overrides
81+
from dsp.utils.settings import thread_local_overrides
82+
original_overrides = thread_local_overrides.overrides
83+
thread_local_overrides.overrides = thread_local_overrides.overrides.copy()
84+
85+
try:
86+
result = function(item)
87+
results.append(result)
88+
finally:
89+
thread_local_overrides.overrides = original_overrides
8790

8891
if self.compare_results:
8992
# Assumes score is the last element of the result tuple
@@ -137,18 +140,30 @@ def interrupt_handler(sig, frame):
137140
# If not in the main thread, skip setting signal handlers
138141
yield
139142

140-
def cancellable_function(index_item):
143+
def cancellable_function(parent_overrides, index_item):
141144
index, item = index_item
142145
if self.cancel_jobs.is_set():
143146
return index, job_cancelled
144-
return index, function(item)
147+
148+
# Create an isolated context for each task using thread-local overrides
149+
from dsp.utils.settings import thread_local_overrides
150+
original_overrides = thread_local_overrides.overrides
151+
thread_local_overrides.overrides = parent_overrides.copy()
152+
153+
try:
154+
return index, function(item)
155+
finally:
156+
thread_local_overrides.overrides = original_overrides
145157

146158
with ThreadPoolExecutor(max_workers=self.num_threads) as executor, interrupt_handler_manager():
159+
# Capture the parent thread's overrides
160+
from dsp.utils.settings import thread_local_overrides
161+
parent_overrides = thread_local_overrides.overrides.copy()
162+
147163
futures = {}
148164
for pair in enumerate(data):
149-
# Capture the context for each task
150-
task_ctx = copy_context()
151-
future = executor.submit(task_ctx.run, cancellable_function, pair)
165+
# Pass the parent thread's overrides to each thread
166+
future = executor.submit(cancellable_function, parent_overrides, pair)
152167
futures[future] = pair
153168

154169
pbar = tqdm.tqdm(

0 commit comments

Comments
 (0)