Skip to content

Commit d646cb2

Browse files
committed
Adapter: Don't return logprobs unless requested
1 parent 469d037 commit d646cb2

File tree

1 file changed

+14
-9
lines changed

1 file changed

+14
-9
lines changed

dspy/adapters/base.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -22,16 +22,21 @@ def __call__(self, lm, lm_kwargs, signature, demos, inputs, _parse_values=True):
2222

2323
try:
2424
for output in outputs:
25-
if type(output) is str:
26-
output_text, output_logprobs = output, None
27-
elif type(output) is dict:
28-
output_text, output_logprobs = output["text"], output["logprobs"]
29-
else:
30-
raise ValueError(f"Expected str or dict but got {type(output)}")
31-
value = self.parse(signature, output_text, _parse_values=_parse_values)
32-
assert set(value.keys()) == set(signature.output_fields.keys()), f"Expected {signature.output_fields.keys()} but got {value.keys()}"
33-
value["logprobs"] = output_logprobs
25+
output_logprobs = None
26+
27+
if isinstance(output, dict):
28+
output, output_logprobs = output["text"], output["logprobs"]
29+
30+
value = self.parse(signature, output, _parse_values=_parse_values)
31+
32+
assert set(value.keys()) == set(signature.output_fields.keys()), \
33+
f"Expected {signature.output_fields.keys()} but got {value.keys()}"
34+
35+
if output_logprobs is not None:
36+
value["logprobs"] = output_logprobs
37+
3438
values.append(value)
39+
3540
return values
3641

3742
except Exception as e:

0 commit comments

Comments
 (0)