Skip to content

Commit 2e42988

Browse files
authored
Fix asyncify and relax (warning) on settings read from unrecognized thread (#1813)
1 parent 997b1d8 commit 2e42988

File tree

2 files changed

+33
-2
lines changed

2 files changed

+33
-2
lines changed

dsp/utils/settings.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
from contextlib import contextmanager
55
from dsp.utils.utils import dotdict
6+
from functools import lru_cache
67

78
DEFAULT_CONFIG = dotdict(
89
lm=None,
@@ -27,6 +28,12 @@
2728
async_max_workers=8,
2829
)
2930

31+
@lru_cache(maxsize=None)
32+
def warn_once(msg: str):
33+
import logging
34+
logger = logging.getLogger(__name__)
35+
logger.warning(msg)
36+
3037

3138
class Settings:
3239
"""DSP configuration settings."""
@@ -59,7 +66,11 @@ def config(self):
5966
thread_id = threading.get_ident()
6067
# if thread_id not in self.stack_by_thread:
6168
# self.stack_by_thread[thread_id] = [self.main_stack[-1].copy()]
62-
return self.stack_by_thread[thread_id][-1]
69+
try:
70+
return self.stack_by_thread[thread_id][-1]
71+
except Exception:
72+
warn_once("Warning: You seem to be creating DSPy threads in an unsupported way.")
73+
return self.main_stack[-1]
6374

6475
def __getattr__(self, name):
6576
if hasattr(self.config, name):
@@ -74,6 +85,8 @@ def __append(self, config):
7485
thread_id = threading.get_ident()
7586
# if thread_id not in self.stack_by_thread:
7687
# self.stack_by_thread[thread_id] = [self.main_stack[-1].copy()]
88+
89+
assert thread_id in self.stack_by_thread, "Error: You seem to be creating DSPy threads in an unsupported way."
7790
self.stack_by_thread[thread_id].append(config)
7891

7992
def __pop(self):

dspy/utils/asyncify.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,4 +24,22 @@ def get_limiter():
2424

2525

2626
def asyncify(program):
27-
return asyncer.asyncify(program, abandon_on_cancel=True, limiter=get_limiter())
27+
import dspy
28+
import threading
29+
30+
assert threading.get_ident() == dspy.settings.main_tid, "asyncify can only be called from the main thread"
31+
32+
def wrapped(*args, **kwargs):
33+
thread_stacks = dspy.settings.stack_by_thread
34+
current_thread_id = threading.get_ident()
35+
creating_new_thread = current_thread_id not in thread_stacks
36+
37+
assert creating_new_thread
38+
thread_stacks[current_thread_id] = list(dspy.settings.main_stack)
39+
40+
try:
41+
return program(*args, **kwargs)
42+
finally:
43+
del thread_stacks[threading.get_ident()]
44+
45+
return asyncer.asyncify(wrapped, abandon_on_cancel=True, limiter=get_limiter())

0 commit comments

Comments
 (0)