Skip to content

Commit 738ea87

Browse files
author
Michael Jones
committed
feat(dspy): update DummyLM to use ChatAdapter's format_field to format outputs
1 parent 53b24c8 commit 738ea87

15 files changed

+138
-890
lines changed

dspy/utils/dummies.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
from dsp.modules import LM as DSPLM
99
from dsp.utils.utils import dotdict
10-
from dspy.adapters.chat_adapter import field_header_pattern
10+
from dspy.adapters.chat_adapter import field_header_pattern, format_fields
1111
from dspy.clients.lm import LM
1212

1313

@@ -98,7 +98,7 @@ def get_convo(self, index) -> str:
9898

9999

100100
class DummyLM(LM):
101-
def __init__(self, answers: Union[list[str], dict[str, str]], follow_examples: bool = False):
101+
def __init__(self, answers: Union[list[dict[str, str]], dict[str, dict[str, str]]], follow_examples: bool = False):
102102
super().__init__("dummy", "chat", 0.0, 1000, True)
103103
self.answers = answers
104104
if isinstance(answers, list):
@@ -133,10 +133,13 @@ def __call__(self, prompt=None, messages=None, **kwargs):
133133
outputs.append(self._use_example(messages))
134134
elif isinstance(self.answers, dict):
135135
outputs.append(
136-
next((v for k, v in self.answers.items() if k in messages[-1]["content"]), "No more responses")
136+
next(
137+
(format_fields(v) for k, v in self.answers.items() if k in messages[-1]["content"]),
138+
"No more responses",
139+
)
137140
)
138141
else:
139-
outputs.append(next(self.answers, "No more responses"))
142+
outputs.append(format_fields(next(self.answers, {"answer": "No more responses"})))
140143

141144
# Logging, with removed api key & where `cost` is None on cache hit.
142145
kwargs = {k: v for k, v in kwargs.items() if not k.startswith("api_")}

tests/evaluate/test_evaluate.py

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,8 @@ def test_evaluate_call():
3737
dspy.settings.configure(
3838
lm=DummyLM(
3939
{
40-
"What is 1+1?": "[[ ## answer ## ]]\n2",
41-
"What is 2+2?": "[[ ## answer ## ]]\n4",
40+
"What is 1+1?": {"answer": "2"},
41+
"What is 2+2?": {"answer": "4"},
4242
}
4343
)
4444
)
@@ -55,9 +55,7 @@ def test_evaluate_call():
5555

5656

5757
def test_multithread_evaluate_call():
58-
dspy.settings.configure(
59-
lm=DummyLM({"What is 1+1?": "[[ ## answer ## ]]\n2", "What is 2+2?": "[[ ## answer ## ]]\n4"})
60-
)
58+
dspy.settings.configure(lm=DummyLM({"What is 1+1?": {"answer": "2"}, "What is 2+2?": {"answer": "4"}}))
6159
devset = [new_example("What is 1+1?", "2"), new_example("What is 2+2?", "4")]
6260
program = Predict("question -> answer")
6361
assert program(question="What is 1+1?").answer == "2"
@@ -80,9 +78,7 @@ def __call__(self, *args, **kwargs):
8078
time.sleep(1)
8179
return super().__call__(*args, **kwargs)
8280

83-
dspy.settings.configure(
84-
lm=SlowLM({"What is 1+1?": "[[ ## answer ## ]]\n2", "What is 2+2?": "[[ ## answer ## ]]\n4"})
85-
)
81+
dspy.settings.configure(lm=SlowLM({"What is 1+1?": {"answer": "2"}, "What is 2+2?": {"answer": "4"}}))
8682

8783
devset = [new_example("What is 1+1?", "2"), new_example("What is 2+2?", "4")]
8884
program = Predict("question -> answer")
@@ -112,9 +108,7 @@ def sleep_then_interrupt():
112108

113109

114110
def test_evaluate_call_bad():
115-
dspy.settings.configure(
116-
lm=DummyLM({"What is 1+1?": "[[ ## answer ## ]]\n0", "What is 2+2?": "[[ ## answer ## ]]\n0"})
117-
)
111+
dspy.settings.configure(lm=DummyLM({"What is 1+1?": {"answer": "0"}, "What is 2+2?": {"answer": "0"}}))
118112
devset = [new_example("What is 1+1?", "2"), new_example("What is 2+2?", "4")]
119113
program = Predict("question -> answer")
120114
ev = Evaluate(

0 commit comments

Comments
 (0)