Skip to content

Commit fe3d9d1

Browse files
authored
Improve ChatAdapter's handling of typed values and Pydantic models (#1663)
* Improve ChatAdapter's handling of typed values and Pydantic models * Fixes for Literal * Fixes for formatting complex-typed values
1 parent 313aa66 commit fe3d9d1

File tree

1 file changed

+61
-26
lines changed

1 file changed

+61
-26
lines changed

dspy/adapters/chat_adapter.py

Lines changed: 61 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,29 @@
1+
import re
12
import ast
23
import json
3-
import re
4+
import enum
5+
import inspect
6+
import pydantic
47
import textwrap
5-
from typing import Any, Dict, KeysView, List, Literal, NamedTuple, get_args, get_origin
68

7-
import pydantic
89
from pydantic import TypeAdapter
910
from pydantic.fields import FieldInfo
11+
from typing import Any, Dict, KeysView, List, Literal, NamedTuple, get_args, get_origin
1012

13+
from dspy.adapters.base import Adapter
1114
from ..signatures.field import OutputField
1215
from ..signatures.signature import SignatureMeta
1316
from ..signatures.utils import get_dspy_field_type
14-
from .base import Adapter
1517

1618
field_header_pattern = re.compile(r"\[\[ ## (\w+) ## \]\]")
1719

1820

1921
class FieldInfoWithName(NamedTuple):
20-
"""
21-
A tuple containing a field name and its corresponding FieldInfo object.
22-
"""
23-
2422
name: str
2523
info: FieldInfo
2624

2725

28-
# Built-in field indicating that a chat turn (i.e. a user or assistant reply to a chat
29-
# thread) has been completed.
26+
# Built-in field indicating that a chat turn has been completed.
3027
BuiltInCompletedOutputFieldInfo = FieldInfoWithName(name="completed", info=OutputField())
3128

3229

@@ -114,6 +111,16 @@ def format_input_list_field_value(value: List[Any]) -> str:
114111
return "\n".join([f"[{idx+1}] {format_blob(txt)}" for idx, txt in enumerate(value)])
115112

116113

114+
def _serialize_for_json(value):
    """
    Recursively convert ``value`` into a structure that ``json.dumps`` can encode.

    Args:
        value: A pydantic model, a list/tuple, a dict, or any other value. Containers are
            walked recursively so that models nested inside them are converted too.

    Returns:
        For a pydantic model, its ``model_dump(mode="json")`` dict; for a list or tuple, a new
        list of converted items; for a dict, a new dict with converted values (keys are kept
        as-is); any other value is returned unchanged.
    """
    if isinstance(value, pydantic.BaseModel):
        # mode="json" asks pydantic to emit only JSON-compatible types, so fields such as
        # datetimes or enums do not make the subsequent json.dumps call fail.
        return value.model_dump(mode="json")
    if isinstance(value, (list, tuple)):
        # Tuples are encoded as JSON arrays anyway; recursing lets models inside them convert.
        return [_serialize_for_json(item) for item in value]
    if isinstance(value, dict):
        return {key: _serialize_for_json(val) for key, val in value.items()}
    return value
123+
117124
def _format_field_value(field_info: FieldInfo, value: Any) -> str:
118125
"""
119126
Formats the value of the specified field according to the field's DSPy type (input or output),
@@ -125,24 +132,17 @@ def _format_field_value(field_info: FieldInfo, value: Any) -> str:
125132
Returns:
126133
The formatted value of the field, represented as a string.
127134
"""
128-
dspy_field_type: Literal["input", "output"] = get_dspy_field_type(field_info)
129-
if isinstance(value, list):
130-
if dspy_field_type == "input" or field_info.annotation is str:
131-
# If the field is an input field or has no special type requirements, format it as
132-
# numbered list so that it's organized in a way suitable for presenting long context
133-
# to an LLM (i.e. not JSON)
134-
return format_input_list_field_value(value)
135-
else:
136-
# If the field is an output field that has strict parsing requirements, format the
137-
# value as a stringified JSON Array. This ensures that downstream routines can parse
138-
# the field value correctly using methods from the `ujson` or `json` packages.
139-
return json.dumps(value)
140-
elif isinstance(value, pydantic.BaseModel):
141-
return value.model_dump_json()
135+
136+
if isinstance(value, list) and field_info.annotation is str:
137+
# If the field has no special type requirements, format it as a nice numbered list for the LM.
138+
return format_input_list_field_value(value)
139+
elif isinstance(value, pydantic.BaseModel) or isinstance(value, dict) or isinstance(value, list):
140+
return json.dumps(_serialize_for_json(value))
142141
else:
143142
return str(value)
144143

145144

145+
146146
def format_fields(fields_with_values: Dict[FieldInfoWithName, Any]) -> str:
147147
"""
148148
Formats the values of the specified fields according to the field's DSPy type (input or output),
@@ -166,15 +166,20 @@ def format_fields(fields_with_values: Dict[FieldInfoWithName, Any]) -> str:
166166
def parse_value(value, annotation):
167167
if annotation is str:
168168
return str(value)
169+
169170
parsed_value = value
170-
if isinstance(value, str):
171+
172+
if isinstance(annotation, enum.EnumMeta):
173+
parsed_value = annotation[value]
174+
elif isinstance(value, str):
171175
try:
172176
parsed_value = json.loads(value)
173177
except json.JSONDecodeError:
174178
try:
175179
parsed_value = ast.literal_eval(value)
176180
except (ValueError, SyntaxError):
177181
parsed_value = value
182+
178183
return TypeAdapter(annotation).validate_python(parsed_value)
179184

180185

@@ -222,6 +227,16 @@ def format_turn(signature: SignatureMeta, values: Dict[str, Any], role, incomple
222227
content.append(formatted_fields)
223228

224229
if role == "user":
230+
# def type_info(v):
231+
# return f" (must be formatted as a valid Python {get_annotation_name(v.annotation)})" \
232+
# if v.annotation is not str else ""
233+
#
234+
# content.append(
235+
# "Respond with the corresponding output fields, starting with the field "
236+
# + ", then ".join(f"`[[ ## {f} ## ]]`{type_info(v)}" for f, v in signature.output_fields.items())
237+
# + ", and then ending with the marker for `[[ ## completed ## ]]`."
238+
# )
239+
225240
content.append(
226241
"Respond with the corresponding output fields, starting with the field "
227242
+ ", then ".join(f"`{f}`" for f in signature.output_fields)
@@ -260,10 +275,30 @@ def prepare_instructions(signature: SignatureMeta):
260275
parts.append("Your output fields are:\n" + enumerate_fields(signature.output_fields))
261276
parts.append("All interactions will be structured in the following way, with the appropriate values filled in.")
262277

278+
def field_metadata(field_name, field_info):
    """
    Build the placeholder text shown for a field in the instructions template.

    Args:
        field_name: Name of the signature field; rendered as ``{field_name}``.
        field_info: The field's ``FieldInfo``; its ``annotation`` selects the type note.

    Returns:
        ``{field_name}`` alone for input fields and ``str``-typed fields. Otherwise the
        placeholder followed by eight spaces and a ``# note: the value you produce ...``
        comment describing the values that are acceptable for the field's type.
    """
    type_ = field_info.annotation

    if get_dspy_field_type(field_info) == 'input' or type_ is str:
        desc = ""
    elif type_ is bool:
        desc = "must be True or False"
    elif type_ in (int, float):
        desc = f"must be a single {type_.__name__} value"
    elif inspect.isclass(type_) and issubclass(type_, enum.Enum):
        # Iterating __members__ yields the member names, so names (not values) are listed.
        desc = f"must be one of: {'; '.join(type_.__members__)}"
    elif get_origin(type_) is Literal:
        desc = f"must be one of: {'; '.join(str(x) for x in get_args(type_))}"
    else:
        # Any other annotation is described by the JSON schema pydantic derives for it.
        desc = "must be parseable according to the following JSON schema: "
        desc += json.dumps(pydantic.TypeAdapter(type_).json_schema())

    desc = (" " * 8) + f"# note: the value you produce {desc}" if desc else ""
    return f"{{{field_name}}}{desc}"
297+
263298
def format_signature_fields_for_instructions(fields: Dict[str, FieldInfo]):
264299
return format_fields(
265300
fields_with_values={
266-
FieldInfoWithName(name=field_name, info=field_info): f"{{{field_name}}}"
301+
FieldInfoWithName(name=field_name, info=field_info): field_metadata(field_name, field_info)
267302
for field_name, field_info in fields.items()
268303
}
269304
)

0 commit comments

Comments
 (0)