Skip to content

Commit c4f1f95

Browse files
authored
Adapters: Support JSON serialization of all pydantic types (e.g. datetimes, enums, etc.) (#1853)
* Add Signed-off-by: dbczumar <[email protected]> * fix Signed-off-by: dbczumar <[email protected]> * fix Signed-off-by: dbczumar <[email protected]> * fix Signed-off-by: dbczumar <[email protected]> * fix Signed-off-by: dbczumar <[email protected]> --------- Signed-off-by: dbczumar <[email protected]>
1 parent 1b10e23 commit c4f1f95

File tree

8 files changed

+304
-170
lines changed

8 files changed

+304
-170
lines changed

dspy/adapters/chat_adapter.py

Lines changed: 3 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
from dsp.adapters.base_template import Field
1616
from dspy.adapters.base import Adapter
17-
from dspy.adapters.image_utils import Image, encode_image
17+
from dspy.adapters.utils import find_enum_member, format_field_value
1818
from dspy.signatures.field import OutputField
1919
from dspy.signatures.signature import Signature, SignatureMeta
2020
from dspy.signatures.utils import get_dspy_field_type
@@ -115,86 +115,6 @@ def format_fields(self, signature, values, role):
115115
return format_fields(fields_with_values)
116116

117117

118-
def format_blob(blob):
119-
if "\n" not in blob and "«" not in blob and "»" not in blob:
120-
return f"«{blob}»"
121-
122-
modified_blob = blob.replace("\n", "\n ")
123-
return f"«««\n {modified_blob}\n»»»"
124-
125-
126-
def format_input_list_field_value(value: List[Any]) -> str:
127-
"""
128-
Formats the value of an input field of type List[Any].
129-
130-
Args:
131-
value: The value of the list-type input field.
132-
Returns:
133-
A string representation of the input field's list value.
134-
"""
135-
if len(value) == 0:
136-
return "N/A"
137-
if len(value) == 1:
138-
return format_blob(value[0])
139-
140-
return "\n".join([f"[{idx+1}] {format_blob(txt)}" for idx, txt in enumerate(value)])
141-
142-
143-
def _serialize_for_json(value):
144-
if isinstance(value, pydantic.BaseModel):
145-
return value.model_dump()
146-
elif isinstance(value, list):
147-
return [_serialize_for_json(item) for item in value]
148-
elif isinstance(value, dict):
149-
return {key: _serialize_for_json(val) for key, val in value.items()}
150-
else:
151-
return value
152-
153-
154-
def _format_field_value(field_info: FieldInfo, value: Any, assume_text=True) -> Union[str, dict]:
155-
"""
156-
Formats the value of the specified field according to the field's DSPy type (input or output),
157-
annotation (e.g. str, int, etc.), and the type of the value itself.
158-
159-
Args:
160-
field_info: Information about the field, including its DSPy field type and annotation.
161-
value: The value of the field.
162-
Returns:
163-
The formatted value of the field, represented as a string.
164-
"""
165-
string_value = None
166-
if isinstance(value, list) and field_info.annotation is str:
167-
# If the field has no special type requirements, format it as a nice numbered list for the LM.
168-
string_value = format_input_list_field_value(value)
169-
elif isinstance(value, pydantic.BaseModel) or isinstance(value, dict) or isinstance(value, list):
170-
string_value = json.dumps(_serialize_for_json(value), ensure_ascii=False)
171-
else:
172-
string_value = str(value)
173-
174-
if assume_text:
175-
return string_value
176-
elif isinstance(value, Image) or field_info.annotation == Image:
177-
# This validation should happen somewhere else
178-
# Safe to import PIL here because it's only imported when an image is actually being formatted
179-
try:
180-
import PIL
181-
except ImportError:
182-
raise ImportError("PIL is required to format images; Run `pip install pillow` to install it.")
183-
image_value = value
184-
if not isinstance(image_value, Image):
185-
if isinstance(image_value, dict) and "url" in image_value:
186-
image_value = image_value["url"]
187-
elif isinstance(image_value, str):
188-
image_value = encode_image(image_value)
189-
elif isinstance(image_value, PIL.Image.Image):
190-
image_value = encode_image(image_value)
191-
assert isinstance(image_value, str)
192-
image_value = Image(url=image_value)
193-
return {"type": "image_url", "image_url": image_value.model_dump()}
194-
else:
195-
return {"type": "text", "text": string_value}
196-
197-
198118
def format_fields(fields_with_values: Dict[FieldInfoWithName, Any], assume_text=True) -> Union[str, List[dict]]:
199119
"""
200120
Formats the values of the specified fields according to the field's DSPy type (input or output),
@@ -209,7 +129,7 @@ def format_fields(fields_with_values: Dict[FieldInfoWithName, Any], assume_text=
209129
"""
210130
output = []
211131
for field, field_value in fields_with_values.items():
212-
formatted_field_value = _format_field_value(field_info=field.info, value=field_value, assume_text=assume_text)
132+
formatted_field_value = format_field_value(field_info=field.info, value=field_value, assume_text=assume_text)
213133
if assume_text:
214134
output.append(f"[[ ## {field.name} ## ]]\n{formatted_field_value}")
215135
else:
@@ -231,7 +151,7 @@ def parse_value(value, annotation):
231151
parsed_value = value
232152

233153
if isinstance(annotation, enum.EnumMeta):
234-
parsed_value = annotation[value]
154+
return find_enum_member(annotation, value)
235155
elif isinstance(value, str):
236156
try:
237157
parsed_value = json.loads(value)

dspy/adapters/json_adapter.py

Lines changed: 41 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -1,39 +1,41 @@
11
import ast
2-
import json
32
import enum
43
import inspect
5-
import litellm
6-
import pydantic
4+
import json
75
import textwrap
8-
import json_repair
9-
6+
from typing import Any, Dict, KeysView, Literal, NamedTuple, get_args, get_origin
107

8+
import json_repair
9+
import litellm
10+
import pydantic
1111
from pydantic import TypeAdapter
1212
from pydantic.fields import FieldInfo
13-
from typing import Any, Dict, KeysView, List, Literal, NamedTuple, get_args, get_origin
1413

1514
from dspy.adapters.base import Adapter
15+
from dspy.adapters.utils import find_enum_member, format_field_value, serialize_for_json
16+
1617
from ..adapters.image_utils import Image
1718
from ..signatures.signature import SignatureMeta
1819
from ..signatures.utils import get_dspy_field_type
1920

21+
2022
class FieldInfoWithName(NamedTuple):
2123
name: str
2224
info: FieldInfo
2325

26+
2427
class JSONAdapter(Adapter):
2528
def __init__(self):
2629
pass
2730

2831
def __call__(self, lm, lm_kwargs, signature, demos, inputs, _parse_values=True):
2932
inputs = self.format(signature, demos, inputs)
3033
inputs = dict(prompt=inputs) if isinstance(inputs, str) else dict(messages=inputs)
31-
32-
34+
3335
try:
34-
provider = lm.model.split('/', 1)[0] or "openai"
35-
if 'response_format' in litellm.get_supported_openai_params(model=lm.model, custom_llm_provider=provider):
36-
outputs = lm(**inputs, **lm_kwargs, response_format={ "type": "json_object" })
36+
provider = lm.model.split("/", 1)[0] or "openai"
37+
if "response_format" in litellm.get_supported_openai_params(model=lm.model, custom_llm_provider=provider):
38+
outputs = lm(**inputs, **lm_kwargs, response_format={"type": "json_object"})
3739
else:
3840
outputs = lm(**inputs, **lm_kwargs)
3941

@@ -44,11 +46,12 @@ def __call__(self, lm, lm_kwargs, signature, demos, inputs, _parse_values=True):
4446

4547
for output in outputs:
4648
value = self.parse(signature, output, _parse_values=_parse_values)
47-
assert set(value.keys()) == set(signature.output_fields.keys()), f"Expected {signature.output_fields.keys()} but got {value.keys()}"
49+
assert set(value.keys()) == set(
50+
signature.output_fields.keys()
51+
), f"Expected {signature.output_fields.keys()} but got {value.keys()}"
4852
values.append(value)
49-
50-
return values
5153

54+
return values
5255

5356
def format(self, signature, demos, inputs):
5457
messages = []
@@ -71,7 +74,7 @@ def format(self, signature, demos, inputs):
7174
messages.append(format_turn(signature, demo, role="assistant", incomplete=demo in incomplete_demos))
7275

7376
messages.append(format_turn(signature, inputs, role="user"))
74-
77+
7578
return messages
7679

7780
def parse(self, signature, completion, _parse_values=True):
@@ -90,7 +93,7 @@ def parse(self, signature, completion, _parse_values=True):
9093

9194
def format_turn(self, signature, values, role, incomplete=False):
9295
return format_turn(signature, values, role, incomplete)
93-
96+
9497
def format_fields(self, signature, values, role):
9598
fields_with_values = {
9699
FieldInfoWithName(name=field_name, info=field_info): values.get(
@@ -101,16 +104,16 @@ def format_fields(self, signature, values, role):
101104
}
102105

103106
return format_fields(role=role, fields_with_values=fields_with_values)
104-
107+
105108

106109
def parse_value(value, annotation):
107110
if annotation is str:
108111
return str(value)
109-
112+
110113
parsed_value = value
111114

112115
if isinstance(annotation, enum.EnumMeta):
113-
parsed_value = annotation[value]
116+
parsed_value = find_enum_member(annotation, value)
114117
elif isinstance(value, str):
115118
try:
116119
parsed_value = json.loads(value)
@@ -119,45 +122,10 @@ def parse_value(value, annotation):
119122
parsed_value = ast.literal_eval(value)
120123
except (ValueError, SyntaxError):
121124
parsed_value = value
122-
123-
return TypeAdapter(annotation).validate_python(parsed_value)
124-
125125

126-
def format_blob(blob):
127-
if "\n" not in blob and "«" not in blob and "»" not in blob:
128-
return f"«{blob}»"
129-
130-
modified_blob = blob.replace("\n", "\n ")
131-
return f"«««\n {modified_blob}\n»»»"
132-
133-
134-
def format_input_list_field_value(value: List[Any]) -> str:
135-
"""
136-
Formats the value of an input field of type List[Any].
137-
138-
Args:
139-
value: The value of the list-type input field.
140-
Returns:
141-
A string representation of the input field's list value.
142-
"""
143-
if len(value) == 0:
144-
return "N/A"
145-
if len(value) == 1:
146-
return format_blob(value[0])
147-
148-
return "\n".join([f"[{idx+1}] {format_blob(txt)}" for idx, txt in enumerate(value)])
126+
return TypeAdapter(annotation).validate_python(parsed_value)
149127

150128

151-
def _serialize_for_json(value):
152-
if isinstance(value, pydantic.BaseModel):
153-
return value.model_dump()
154-
elif isinstance(value, list):
155-
return [_serialize_for_json(item) for item in value]
156-
elif isinstance(value, dict):
157-
return {key: _serialize_for_json(val) for key, val in value.items()}
158-
else:
159-
return value
160-
161129
def _format_field_value(field_info: FieldInfo, value: Any) -> str:
162130
"""
163131
Formats the value of the specified field according to the field's DSPy type (input or output),
@@ -169,17 +137,10 @@ def _format_field_value(field_info: FieldInfo, value: Any) -> str:
169137
Returns:
170138
The formatted value of the field, represented as a string.
171139
"""
172-
173-
if isinstance(value, list) and field_info.annotation is str:
174-
# If the field has no special type requirements, format it as a nice numbere list for the LM.
175-
return format_input_list_field_value(value)
176140
if field_info.annotation is Image:
177141
raise NotImplementedError("Images are not yet supported in JSON mode.")
178-
elif isinstance(value, pydantic.BaseModel) or isinstance(value, dict) or isinstance(value, list):
179-
return json.dumps(_serialize_for_json(value))
180-
else:
181-
return str(value)
182142

143+
return format_field_value(field_info=field_info, value=value, assume_text=True)
183144

184145

185146
def format_fields(role: str, fields_with_values: Dict[FieldInfoWithName, Any]) -> str:
@@ -197,9 +158,8 @@ def format_fields(role: str, fields_with_values: Dict[FieldInfoWithName, Any]) -
197158

198159
if role == "assistant":
199160
d = fields_with_values.items()
200-
d = {k.name: _serialize_for_json(v) for k, v in d}
201-
202-
return json.dumps(_serialize_for_json(d), indent=2)
161+
d = {k.name: v for k, v in d}
162+
return json.dumps(serialize_for_json(d), indent=2)
203163

204164
output = []
205165
for field, field_value in fields_with_values.items():
@@ -246,15 +206,19 @@ def format_turn(signature: SignatureMeta, values: Dict[str, Any], role, incomple
246206
field_name, "Not supplied for this particular example."
247207
)
248208
for field_name, field_info in fields.items()
249-
}
209+
},
250210
)
251211
content.append(formatted_fields)
252212

253213
if role == "user":
214+
254215
def type_info(v):
255-
return f" (must be formatted as a valid Python {get_annotation_name(v.annotation)})" \
256-
if v.annotation is not str else ""
257-
216+
return (
217+
f" (must be formatted as a valid Python {get_annotation_name(v.annotation)})"
218+
if v.annotation is not str
219+
else ""
220+
)
221+
258222
# TODO: Consider if not incomplete:
259223
content.append(
260224
"Respond with a JSON object in the following order of fields: "
@@ -297,15 +261,15 @@ def prepare_instructions(signature: SignatureMeta):
297261
def field_metadata(field_name, field_info):
298262
type_ = field_info.annotation
299263

300-
if get_dspy_field_type(field_info) == 'input' or type_ is str:
264+
if get_dspy_field_type(field_info) == "input" or type_ is str:
301265
desc = ""
302266
elif type_ is bool:
303267
desc = "must be True or False"
304268
elif type_ in (int, float):
305269
desc = f"must be a single {type_.__name__} value"
306270
elif inspect.isclass(type_) and issubclass(type_, enum.Enum):
307-
desc= f"must be one of: {'; '.join(type_.__members__)}"
308-
elif hasattr(type_, '__origin__') and type_.__origin__ is Literal:
271+
desc = f"must be one of: {'; '.join(type_.__members__)}"
272+
elif hasattr(type_, "__origin__") and type_.__origin__ is Literal:
309273
desc = f"must be one of: {'; '.join([str(x) for x in type_.__args__])}"
310274
else:
311275
desc = "must be pareseable according to the following JSON schema: "
@@ -320,13 +284,13 @@ def format_signature_fields_for_instructions(role, fields: Dict[str, FieldInfo])
320284
fields_with_values={
321285
FieldInfoWithName(name=field_name, info=field_info): field_metadata(field_name, field_info)
322286
for field_name, field_info in fields.items()
323-
}
287+
},
324288
)
325-
289+
326290
parts.append("Inputs will have the following structure:")
327-
parts.append(format_signature_fields_for_instructions('user', signature.input_fields))
291+
parts.append(format_signature_fields_for_instructions("user", signature.input_fields))
328292
parts.append("Outputs will be a JSON object with the following fields.")
329-
parts.append(format_signature_fields_for_instructions('assistant', signature.output_fields))
293+
parts.append(format_signature_fields_for_instructions("assistant", signature.output_fields))
330294
# parts.append(format_fields({BuiltInCompletedOutputFieldInfo: ""}))
331295

332296
instructions = textwrap.dedent(signature.instructions)

0 commit comments

Comments
 (0)