Skip to content

Commit 4f2d2f8

Browse files
Introduce dspy.Refine, dspy.BestOfN and fix thread safety for dspy.ChainOfThoughtWithHint (#1959)
* Add DSPy Refine
* Introduce BestOfN and make Refine & CoTWithHint thread safe
* Improve Refine implementation

---------

Co-authored-by: Omar Khattab <[email protected]>
Co-authored-by: Omar Khattab <[email protected]>
1 parent a24856c commit 4f2d2f8

File tree

6 files changed

+216
-469
lines changed

6 files changed

+216
-469
lines changed

dspy/predict/__init__.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,26 @@
11
from dspy.predict.aggregation import majority
2+
from dspy.predict.best_of_n import BestOfN
23
from dspy.predict.chain_of_thought import ChainOfThought
34
from dspy.predict.chain_of_thought_with_hint import ChainOfThoughtWithHint
45
from dspy.predict.knn import KNN
56
from dspy.predict.multi_chain_comparison import MultiChainComparison
67
from dspy.predict.predict import Predict
78
from dspy.predict.program_of_thought import ProgramOfThought
8-
from dspy.predict.react import ReAct
9+
from dspy.predict.react import ReAct, Tool
10+
from dspy.predict.refine import Refine
911
from dspy.predict.parallel import Parallel
1012

1113
# Names exported by `from dspy.predict import *`; each entry is imported at the top of this module.
__all__ = [
1214
"majority",
15+
"BestOfN",
1316
"ChainOfThought",
1417
"ChainOfThoughtWithHint",
1518
"KNN",
1619
"MultiChainComparison",
1720
"Predict",
1821
"ProgramOfThought",
1922
"ReAct",
23+
"Refine",
24+
"Tool",
2025
"Parallel",
2126
]

dspy/predict/best_of_n.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
import dspy
2+
3+
from .predict import Module
4+
from typing import Callable
5+
6+
7+
class BestOfN(Module):
    """Run a module several times at different temperatures and keep the best result.

    Each attempt runs a deep copy of ``module`` with a copy of its LM set to a
    different temperature. The prediction is scored with ``reward_fn`` and the
    highest-scoring prediction is returned. Attempts stop early as soon as a
    reward reaches ``threshold``.
    """

    def __init__(self, module, N: int, reward_fn: Callable, threshold: float):
        """Store the wrapped module and the selection settings.

        Args:
            module: The module to run; it must provide ``get_lm``, ``deepcopy``
                and ``set_lm`` and be callable with keyword arguments.
            N: Maximum number of attempts (one temperature per attempt).
            reward_fn: Called as ``reward_fn(kwargs, pred)``; its result is
                compared with ``>`` / ``>=`` to rank predictions.
            threshold: Reward at which to stop early. ``None`` disables early
                stopping.
        """
        self.module = module
        # Wrapped in a lambda to prevent this from becoming a parameter.
        self.reward_fn = lambda *args: reward_fn(*args)
        self.threshold = threshold
        self.N = N

    def forward(self, **kwargs):
        """Return the highest-reward prediction among up to ``N`` attempts.

        Args:
            **kwargs: Inputs forwarded unchanged to the wrapped module and
                passed (as a dict) to ``reward_fn``.

        Returns:
            The best prediction seen, or ``None`` if every attempt raised.
        """
        lm = self.module.get_lm() or dspy.settings.lm

        # Start from the LM's configured temperature, then sweep upward from 0.5.
        # dict.fromkeys drops duplicates while keeping order; cap at N attempts.
        temps = [lm.kwargs['temperature']] + [0.5 + i * (0.5 / self.N) for i in range(self.N)]
        temps = list(dict.fromkeys(temps))[:self.N]
        best_pred, best_trace, best_reward = None, None, -float("inf")

        for t in temps:
            lm_ = lm.copy(temperature=t)
            mod = self.module.deepcopy()
            mod.set_lm(lm_)

            try:
                # Run each attempt against a fresh trace list so that only the
                # winning attempt's trace is merged into the caller's trace.
                with dspy.context(trace=[]):
                    pred = mod(**kwargs)
                    trace = dspy.settings.trace.copy()

                    # NOTE: Not including the trace of reward_fn.
                    reward = self.reward_fn(kwargs, pred)

                if reward > best_reward:
                    best_reward, best_pred, best_trace = reward, pred, trace

                if self.threshold is not None and reward >= self.threshold:
                    break

            except Exception as e:
                # Best-effort: a failed attempt is reported and the next
                # temperature is tried.
                print(f"Attempt failed with temperature {t}: {e}")

        # best_trace stays None when every attempt failed; extending with None
        # would raise TypeError, so only merge when there is a trace to merge.
        outer_trace = dspy.settings.trace
        if best_trace and outer_trace is not None:
            outer_trace.extend(best_trace)
        return best_pred
dspy/predict/chain_of_thought_with_hint.py

Lines changed: 5 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,6 @@
11
import dspy
2-
32
from .predict import Module
43

5-
# TODO: FIXME: Insert this right before the *first* output field. Also rewrite this to use the new signature system.
6-
74
class ChainOfThoughtWithHint(Module):
85
def __init__(self, signature, rationale_type=None, **config):
96
self.signature = dspy.ensure_signature(signature)
@@ -19,16 +16,11 @@ def forward(self, **kwargs):
1916
kwargs[last_key] = str(kwargs[last_key]) + hint
2017

2118
# Run CoT then update the trace with original kwargs, i.e. without the hint.
22-
pred = self.module(**kwargs)
23-
this_trace = dspy.settings.trace[-1]
24-
dspy.settings.trace[-1] = (this_trace[0], original_kwargs, this_trace[2])
19+
with dspy.context(trace=[]):
20+
pred = self.module(**kwargs)
21+
this_trace = dspy.settings.trace[-1]
22+
23+
dspy.settings.trace.append((this_trace[0], original_kwargs, this_trace[2]))
2524
return pred
2625

2726
return self.module(**kwargs)
28-
29-
30-
"""
31-
TODO: In principle, we can update the field's prefix during forward too to fill any thing based on the input args.
32-
33-
IF the user didn't overwrite our default rationale_type.
34-
"""

dspy/predict/langchain.py

Lines changed: 0 additions & 192 deletions
This file was deleted.

0 commit comments

Comments
 (0)