Skip to content

Commit 2349c84

Browse files
authored
Fix TypedPredictor formatting with list output values (#1609)
* Changes and lint Signed-off-by: dbczumar <[email protected]> * Docstring Signed-off-by: dbczumar <[email protected]> * fix Signed-off-by: dbczumar <[email protected]> * fix Signed-off-by: dbczumar <[email protected]> * fix Signed-off-by: dbczumar <[email protected]> * fix Signed-off-by: dbczumar <[email protected]> * fix Signed-off-by: dbczumar <[email protected]> * fix Signed-off-by: dbczumar <[email protected]> * Fix test failures Signed-off-by: dbczumar <[email protected]> * Remove debug Signed-off-by: dbczumar <[email protected]> * fix Signed-off-by: dbczumar <[email protected]> * Test Signed-off-by: dbczumar <[email protected]> * name Signed-off-by: dbczumar <[email protected]> * quote Signed-off-by: dbczumar <[email protected]> * Update test_functional.py --------- Signed-off-by: dbczumar <[email protected]>
1 parent 24ab964 commit 2349c84

File tree

6 files changed

+198
-50
lines changed

6 files changed

+198
-50
lines changed

dspy/adapters/chat_adapter.py

Lines changed: 112 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,34 @@
22
import json
33
import re
44
import textwrap
5-
from typing import get_args, get_origin
5+
from typing import Any, Dict, KeysView, List, Literal, NamedTuple, get_args, get_origin
66

77
import pydantic
88
from pydantic import TypeAdapter
9+
from pydantic.fields import FieldInfo
910

11+
from ..signatures.field import OutputField
12+
from ..signatures.signature import SignatureMeta
13+
from ..signatures.utils import get_dspy_field_type
1014
from .base import Adapter
1115

1216
field_header_pattern = re.compile(r"\[\[ ## (\w+) ## \]\]")
1317

1418

19+
class FieldInfoWithName(NamedTuple):
20+
"""
21+
A tuple containing a field name and its corresponding FieldInfo object.
22+
"""
23+
24+
name: str
25+
info: FieldInfo
26+
27+
28+
# Built-in field indicating that a chat turn (i.e. a user or assistant reply to a chat
29+
# thread) has been completed.
30+
BuiltInCompletedOutputFieldInfo = FieldInfoWithName(name="completed", info=OutputField())
31+
32+
1533
class ChatAdapter(Adapter):
1634
def __init__(self):
1735
pass
@@ -79,29 +97,68 @@ def format_blob(blob):
7997
return f"«««\n {modified_blob}\n»»»"
8098

8199

82-
def format_list(items):
83-
if len(items) == 0:
100+
def format_input_list_field_value(value: List[Any]) -> str:
101+
"""
102+
Formats the value of an input field of type List[Any].
103+
104+
Args:
105+
value: The value of the list-type input field.
106+
Returns:
107+
A string representation of the input field's list value.
108+
"""
109+
if len(value) == 0:
84110
return "N/A"
85-
if len(items) == 1:
86-
return format_blob(items[0])
111+
if len(value) == 1:
112+
return format_blob(value[0])
87113

88-
return "\n".join([f"[{idx+1}] {format_blob(txt)}" for idx, txt in enumerate(items)])
114+
return "\n".join([f"[{idx+1}] {format_blob(txt)}" for idx, txt in enumerate(value)])
89115

90116

91-
def _format_field_value(value) -> str:
117+
def _format_field_value(field_info: FieldInfo, value: Any) -> str:
118+
"""
119+
Formats the value of the specified field according to the field's DSPy type (input or output),
120+
annotation (e.g. str, int, etc.), and the type of the value itself.
121+
122+
Args:
123+
field_info: Information about the field, including its DSPy field type and annotation.
124+
value: The value of the field.
125+
Returns:
126+
The formatted value of the field, represented as a string.
127+
"""
128+
dspy_field_type: Literal["input", "output"] = get_dspy_field_type(field_info)
92129
if isinstance(value, list):
93-
return format_list(value)
130+
if dspy_field_type == "input" or field_info.annotation is str:
131+
# If the field is an input field or has no special type requirements, format it as
132+
# numbered list so that it's organized in a way suitable for presenting long context
133+
# to an LLM (i.e. not JSON)
134+
return format_input_list_field_value(value)
135+
else:
136+
# If the field is an output field that has strict parsing requirements, format the
137+
# value as a stringified JSON Array. This ensures that downstream routines can parse
138+
# the field value correctly using methods from the `ujson` or `json` packages.
139+
return json.dumps(value)
94140
elif isinstance(value, pydantic.BaseModel):
95141
return value.model_dump_json()
96142
else:
97143
return str(value)
98144

99145

100-
def format_fields(fields):
146+
def format_fields(fields_with_values: Dict[FieldInfoWithName, Any]) -> str:
147+
"""
148+
Formats the values of the specified fields according to the field's DSPy type (input or output),
149+
annotation (e.g. str, int, etc.), and the type of the value itself. Joins the formatted values
150+
into a single string, which is is a multiline string if there are multiple fields.
151+
152+
Args:
153+
fields_with_values: A dictionary mapping information about a field to its corresponding
154+
value.
155+
Returns:
156+
The joined formatted values of the fields, represented as a string.
157+
"""
101158
output = []
102-
for k, v in fields.items():
103-
v = _format_field_value(v)
104-
output.append(f"[[ ## {k} ## ]]\n{v}")
159+
for field, field_value in fields_with_values.items():
160+
formatted_field_value = _format_field_value(field_info=field.info, value=field_value)
161+
output.append(f"[[ ## {field.name} ## ]]\n{formatted_field_value}")
105162

106163
return "\n\n".join(output).strip()
107164

@@ -121,21 +178,48 @@ def parse_value(value, annotation):
121178
return TypeAdapter(annotation).validate_python(parsed_value)
122179

123180

124-
def format_turn(signature, values, role, incomplete=False):
181+
def format_turn(signature: SignatureMeta, values: Dict[str, Any], role, incomplete=False) -> Dict[str, str]:
182+
"""
183+
Constructs a new message ("turn") to append to a chat thread. The message is carefully formatted
184+
so that it can instruct an LLM to generate responses conforming to the specified DSPy signature.
185+
186+
Args:
187+
signature: The DSPy signature to which future LLM responses should conform.
188+
values: A dictionary mapping field names (from the DSPy signature) to corresponding values
189+
that should be included in the message.
190+
role: The role of the message, which can be either "user" or "assistant".
191+
incomplete: If True, indicates that output field values are present in the set of specified
192+
``values``. If False, indicates that ``values`` only contains input field values.
193+
Returns:
194+
A chat message that can be appended to a chat thread. The message contains two string fields:
195+
``role`` ("user" or "assistant") and ``content`` (the message text).
196+
"""
125197
content = []
126198

127199
if role == "user":
128-
field_names = signature.input_fields.keys()
200+
fields: Dict[str, FieldInfo] = signature.input_fields
129201
if incomplete:
130202
content.append("This is an example of the task, though some input or output fields are not supplied.")
131203
else:
132-
field_names, values = list(signature.output_fields.keys()) + ["completed"], {**values, "completed": ""}
204+
fields: Dict[str, FieldInfo] = signature.output_fields
205+
# Add the built-in field indicating that the chat turn has been completed
206+
fields[BuiltInCompletedOutputFieldInfo.name] = BuiltInCompletedOutputFieldInfo.info
207+
values = {**values, BuiltInCompletedOutputFieldInfo.name: ""}
133208

134209
if not incomplete:
210+
field_names: KeysView = fields.keys()
135211
if not set(values).issuperset(set(field_names)):
136212
raise ValueError(f"Expected {field_names} but got {values.keys()}")
137213

138-
content.append(format_fields({k: values.get(k, "Not supplied for this particular example.") for k in field_names}))
214+
formatted_fields = format_fields(
215+
fields_with_values={
216+
FieldInfoWithName(name=field_name, info=field_info): values.get(
217+
field_name, "Not supplied for this particular example."
218+
)
219+
for field_name, field_info in fields.items()
220+
}
221+
)
222+
content.append(formatted_fields)
139223

140224
if role == "user":
141225
content.append(
@@ -170,15 +254,23 @@ def enumerate_fields(fields):
170254
return "\n".join(parts).strip()
171255

172256

173-
def prepare_instructions(signature):
257+
def prepare_instructions(signature: SignatureMeta):
174258
parts = []
175259
parts.append("Your input fields are:\n" + enumerate_fields(signature.input_fields))
176260
parts.append("Your output fields are:\n" + enumerate_fields(signature.output_fields))
177261
parts.append("All interactions will be structured in the following way, with the appropriate values filled in.")
178262

179-
parts.append(format_fields({f: f"{{{f}}}" for f in signature.input_fields}))
180-
parts.append(format_fields({f: f"{{{f}}}" for f in signature.output_fields}))
181-
parts.append(format_fields({"completed": ""}))
263+
def format_signature_fields_for_instructions(fields: Dict[str, FieldInfo]):
264+
return format_fields(
265+
fields_with_values={
266+
FieldInfoWithName(name=field_name, info=field_info): f"{{{field_name}}}"
267+
for field_name, field_info in fields.items()
268+
}
269+
)
270+
271+
parts.append(format_signature_fields_for_instructions(signature.input_fields))
272+
parts.append(format_signature_fields_for_instructions(signature.output_fields))
273+
parts.append(format_fields({BuiltInCompletedOutputFieldInfo: ""}))
182274

183275
instructions = textwrap.dedent(signature.instructions)
184276
objective = ("\n" + " " * 8).join([""] + instructions.splitlines())

dspy/predict/predict.py

Lines changed: 27 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,6 @@ def _load_state_legacy(self, state):
114114
*_, last_key = self.extended_signature.fields.keys()
115115
self.extended_signature = self.extended_signature.with_updated_fields(last_key, prefix=prefix)
116116

117-
118117
def __call__(self, **kwargs):
119118
return self.forward(**kwargs)
120119

@@ -148,15 +147,18 @@ def forward(self, **kwargs):
148147
print(f"WARNING: Not all input fields were provided to module. Present: {present}. Missing: {missing}.")
149148

150149
import dspy
150+
151151
if isinstance(lm, dspy.LM):
152152
completions = v2_5_generate(lm, config, signature, demos, kwargs, _parse_values=self._parse_values)
153153
else:
154-
warn_once("\t*** In DSPy 2.5, all LM clients except `dspy.LM` are deprecated. ***\n"
155-
f" \t\tYou are using the client {lm.__class__.__name__}, which will be removed in DSPy 2.6.\n"
156-
" \t\tChanging the client is straightforward and will let you use new features (Adapters) that"
157-
" improve the consistency of LM outputs, especially when using chat LMs. \n\n"
158-
" \t\tLearn more about the changes and how to migrate at\n"
159-
" \t\thttps://github.com/stanfordnlp/dspy/blob/main/examples/migration.ipynb")
154+
warn_once(
155+
"\t*** In DSPy 2.5, all LM clients except `dspy.LM` are deprecated. ***\n"
156+
f" \t\tYou are using the client {lm.__class__.__name__}, which will be removed in DSPy 2.6.\n"
157+
" \t\tChanging the client is straightforward and will let you use new features (Adapters) that"
158+
" improve the consistency of LM outputs, especially when using chat LMs. \n\n"
159+
" \t\tLearn more about the changes and how to migrate at\n"
160+
" \t\thttps://github.com/stanfordnlp/dspy/blob/main/examples/migration.ipynb"
161+
)
160162

161163
if dsp.settings.experimental:
162164
completions = new_generate(lm, signature, dsp.Example(demos=demos, **kwargs), **config)
@@ -181,7 +183,6 @@ def __repr__(self):
181183
return f"{self.__class__.__name__}({self.signature})"
182184

183185

184-
185186
def old_generate(demos, signature, kwargs, config, lm, stage):
186187
# Switch to legacy format for dsp.generate
187188
x = dsp.Example(demos=demos, **kwargs)
@@ -208,7 +209,7 @@ def old_generate(demos, signature, kwargs, config, lm, stage):
208209

209210

210211
def new_generate(lm, signature, example, max_depth=6, **kwargs):
211-
kwargs['stop'] = tuple(kwargs.get('stop', [])) or ('\n---', )
212+
kwargs["stop"] = tuple(kwargs.get("stop", [])) or ("\n---",)
212213

213214
# Generate and extract the fields.
214215
template = signature_to_template(signature, adapter=dsp.ExperimentalAdapter)
@@ -223,22 +224,28 @@ def new_generate(lm, signature, example, max_depth=6, **kwargs):
223224
for field_idx, key in enumerate(field_names):
224225
completions_ = [c for c in completions if key in c.keys() and c[key] is not None]
225226
completions = completions_ or completions
226-
if len(completions_) == 0: break
227+
if len(completions_) == 0:
228+
break
227229

228230
# If none of the completions is completed (i.e., none has the final field set).
229231
if len(completions_) == 0:
230232
# Pick the first completion that has gone farthest.
231233
completion = completions[0]
232234

233-
for field_idx_ in range(field_idx+1, len(field_names)):
234-
if field_names[field_idx_] in completion: del completion[field_names[field_idx_]]
235+
for field_idx_ in range(field_idx + 1, len(field_names)):
236+
if field_names[field_idx_] in completion:
237+
del completion[field_names[field_idx_]]
235238

236239
# Recurse with greedy decoding.
237-
new_kwargs = {**kwargs, "n": 1, "temperature": 0.0,}
240+
new_kwargs = {
241+
**kwargs,
242+
"n": 1,
243+
"temperature": 0.0,
244+
}
238245

239246
assert max_depth > 0
240-
return new_generate(lm, signature, completion, max_depth=max_depth-1, **new_kwargs)
241-
247+
return new_generate(lm, signature, completion, max_depth=max_depth - 1, **new_kwargs)
248+
242249
# Keep only output fields.
243250
completions = [{k: v for k, v in c.items() if k in signature.output_fields} for c in completions]
244251

@@ -247,14 +254,16 @@ def new_generate(lm, signature, example, max_depth=6, **kwargs):
247254

248255
def v2_5_generate(lm, lm_kwargs, signature, demos, inputs, _parse_values=True):
249256
import dspy
257+
250258
adapter = dspy.settings.adapter or dspy.ChatAdapter()
251259

252-
return adapter(lm, lm_kwargs=lm_kwargs, signature=signature, demos=demos, inputs=inputs, _parse_values=_parse_values)
253-
260+
return adapter(
261+
lm, lm_kwargs=lm_kwargs, signature=signature, demos=demos, inputs=inputs, _parse_values=_parse_values
262+
)
254263

255264

256265
# TODO: get some defaults during init from the context window?
257266
# # TODO: FIXME: Hmm, I guess expected behavior is that contexts can
258267
# affect execution. Well, we need to determine whether context dominates, __init__ demoninates, or forward dominates.
259268
# Generally, unless overwritten, we'd see n=None, temperature=None.
260-
# That will eventually mean we have to learn them.
269+
# That will eventually mean we have to learn them.

dspy/signatures/utils.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
from typing import Literal
2+
3+
from pydantic.fields import FieldInfo
4+
5+
6+
def get_dspy_field_type(field: FieldInfo) -> Literal["input", "output"]:
7+
field_type = field.json_schema_extra.get("__dspy_field_type")
8+
if field_type is None:
9+
raise ValueError(f"Field {field} does not have a __dspy_field_type")
10+
return field_type

dspy/utils/dummies.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
import random
22
import re
33
from collections import defaultdict
4-
from typing import Union
4+
from typing import Any, Dict, Union
55

66
import numpy as np
77

88
from dsp.modules import LM as DSPLM
99
from dsp.utils.utils import dotdict
10-
from dspy.adapters.chat_adapter import field_header_pattern, format_fields
10+
from dspy.adapters.chat_adapter import FieldInfoWithName, field_header_pattern, format_fields
1111
from dspy.clients.lm import LM
12+
from dspy.signatures.field import OutputField
1213

1314

1415
class DSPDummyLM(DSPLM):
@@ -170,6 +171,14 @@ def _use_example(self, messages):
170171
return output["content"]
171172

172173
def __call__(self, prompt=None, messages=None, **kwargs):
174+
def format_answer_fields(field_names_and_values: Dict[str, Any]):
175+
return format_fields(
176+
fields_with_values={
177+
FieldInfoWithName(name=field_name, info=OutputField()): value
178+
for field_name, value in field_names_and_values.items()
179+
}
180+
)
181+
173182
# Build the request.
174183
outputs = []
175184
for _ in range(kwargs.get("n", 1)):
@@ -181,12 +190,12 @@ def __call__(self, prompt=None, messages=None, **kwargs):
181190
elif isinstance(self.answers, dict):
182191
outputs.append(
183192
next(
184-
(format_fields(v) for k, v in self.answers.items() if k in messages[-1]["content"]),
193+
(format_answer_fields(v) for k, v in self.answers.items() if k in messages[-1]["content"]),
185194
"No more responses",
186195
)
187196
)
188197
else:
189-
outputs.append(format_fields(next(self.answers, {"answer": "No more responses"})))
198+
outputs.append(format_answer_fields(next(self.answers, {"answer": "No more responses"})))
190199

191200
# Logging, with removed api key & where `cost` is None on cache hit.
192201
kwargs = {k: v for k, v in kwargs.items() if not k.startswith("api_")}

tests/functional/test_functional.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import datetime
2+
import json
23
import textwrap
34
from typing import Annotated, Any, Generic, List, Literal, Optional, TypeVar
45

@@ -452,6 +453,17 @@ class TestSignature(dspy.Signature):
452453
assert output == [0, 1, 2]
453454

454455

456+
def test_list_inputs_and_outputs():
457+
lm = DummyLM([{"output": '["0", "1", "2"]'}])
458+
dspy.settings.configure(lm=lm)
459+
460+
test = TypedPredictor("input:list[str] -> output:list[str]")
461+
output = test(input=["3", "4", "5"]).completions.output[0]
462+
463+
# Verify that the format of the output list from the LM was not changed
464+
assert output == ["0", "1", "2"]
465+
466+
455467
def test_multiple_outputs_int_cot():
456468
# Note: Multiple outputs only work when the language model "speculatively" generates all the outputs in one go.
457469
lm = DummyLM(

0 commit comments

Comments
 (0)