Skip to content

Commit a68f2d9

Browse files
authored
Fix error in Evaluate with display_table=True with outputs that cannot be converted to dict (#1682)
* fix Signed-off-by: dbczumar <[email protected]> * fix Signed-off-by: dbczumar <[email protected]> * fix Signed-off-by: dbczumar <[email protected]> * fix Signed-off-by: dbczumar <[email protected]> * fix Signed-off-by: dbczumar <[email protected]> * fix Signed-off-by: dbczumar <[email protected]> --------- Signed-off-by: dbczumar <[email protected]>
1 parent 56dec59 commit a68f2d9

File tree

2 files changed

+45
-7
lines changed

2 files changed

+45
-7
lines changed

dspy/evaluate/evaluate.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -226,8 +226,20 @@ def wrapped_program(example_idx, example):
226226
if return_outputs: # Handle the return_outputs logic
227227
results = [(example, prediction, score) for _, example, prediction, score in predicted_devset]
228228

229+
def prediction_is_dictlike(prediction):
230+
try:
231+
dict(prediction)
232+
return True
233+
except Exception:
234+
return False
235+
229236
data = [
230-
merge_dicts(example, prediction) | {"correct": score} for _, example, prediction, score in predicted_devset
237+
(
238+
merge_dicts(example, prediction) | {"correct": score}
239+
if prediction_is_dictlike(prediction)
240+
else dict(example) | {"prediction": prediction, "correct": score}
241+
)
242+
for _, example, prediction, score in predicted_devset
231243
]
232244

233245
result_df = pd.DataFrame(data)

tests/evaluate/test_evaluate.py

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import dspy
99
from dspy.evaluate.evaluate import Evaluate
1010
from dspy.evaluate.metrics import answer_exact_match
11+
from dspy.functional import TypedPredictor
1112
from dspy.predict import Predict
1213
from dspy.utils.dummies import DummyLM
1314

@@ -120,14 +121,38 @@ def test_evaluate_call_bad():
120121
assert score == 0.0
121122

122123

124+
@pytest.mark.parametrize(
125+
"program_with_example",
126+
[
127+
(Predict("question -> answer"), new_example("What is 1+1?", "2")),
128+
(
129+
# Create a program that extracts entities from text and returns them as a list,
130+
# rather than returning a Predictor() wrapper. This is done intentionally to test
131+
# the case where the program does not output a dictionary-like object because
132+
# Evaluate() has failed for this case in the past
133+
lambda text: TypedPredictor("text: str -> entities: List[str]")(text=text).entities,
134+
dspy.Example(text="United States", entities=["United States"]).with_inputs("text"),
135+
),
136+
],
137+
)
123138
@pytest.mark.parametrize("display_table", [True, False, 1])
124139
@pytest.mark.parametrize("is_in_ipython_notebook_environment", [True, False])
125-
def test_evaluate_display_table(display_table, is_in_ipython_notebook_environment, capfd):
126-
devset = [new_example("What is 1+1?", "2")]
127-
program = Predict("question -> answer")
140+
def test_evaluate_display_table(program_with_example, display_table, is_in_ipython_notebook_environment, capfd):
141+
program, example = program_with_example
142+
example_input = next(iter(example.inputs().values()))
143+
example_output = {key: value for key, value in example.toDict().items() if key not in example.inputs()}
144+
145+
dspy.settings.configure(
146+
lm=DummyLM(
147+
{
148+
example_input: example_output,
149+
}
150+
)
151+
)
152+
128153
ev = Evaluate(
129-
devset=devset,
130-
metric=answer_exact_match,
154+
devset=[example],
155+
metric=lambda example, pred, **kwargs: example == pred,
131156
display_table=display_table,
132157
)
133158
assert ev.display_table == display_table
@@ -140,4 +165,5 @@ def test_evaluate_display_table(display_table, is_in_ipython_notebook_environmen
140165
if not is_in_ipython_notebook_environment and display_table:
141166
# In console environments where IPython is not available, the table should be printed
142167
# to the console
143-
assert "What is 1+1?" in out
168+
example_input = next(iter(example.inputs().values()))
169+
assert example_input in out

0 commit comments

Comments
 (0)