Skip to content

Commit f1fc6bc

Browse files
authored
Support structured outputs response format based on signature in JSON adapter (#1881)
* Fix Signed-off-by: dbczumar <[email protected]> * Fix Signed-off-by: dbczumar <[email protected]> * fix Signed-off-by: dbczumar <[email protected]> * Debug Signed-off-by: dbczumar <[email protected]> * Fix Signed-off-by: dbczumar <[email protected]> * Here Signed-off-by: dbczumar <[email protected]> * Here Signed-off-by: dbczumar <[email protected]> * Update json_adapter.py * Update json_adapter.py * Update json_adapter.py * Update json_adapter.py --------- Signed-off-by: dbczumar <[email protected]>
1 parent f11ccc0 commit f1fc6bc

File tree

4 files changed

+169
-7
lines changed

4 files changed

+169
-7
lines changed

dspy/adapters/json_adapter.py

Lines changed: 62 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,15 @@
22
import enum
33
import inspect
44
import json
5+
import logging
56
import textwrap
7+
from copy import deepcopy
68
from typing import Any, Dict, KeysView, Literal, NamedTuple, get_args, get_origin
79

810
import json_repair
911
import litellm
1012
import pydantic
11-
from pydantic import TypeAdapter
13+
from pydantic import TypeAdapter, create_model
1214
from pydantic.fields import FieldInfo
1315

1416
from dspy.adapters.base import Adapter
@@ -18,6 +20,8 @@
1820
from ..signatures.signature import SignatureMeta
1921
from ..signatures.utils import get_dspy_field_type
2022

23+
_logger = logging.getLogger(__name__)
24+
2125

2226
class FieldInfoWithName(NamedTuple):
2327
name: str
@@ -35,7 +39,16 @@ def __call__(self, lm, lm_kwargs, signature, demos, inputs, _parse_values=True):
3539
try:
3640
provider = lm.model.split("/", 1)[0] or "openai"
3741
if "response_format" in litellm.get_supported_openai_params(model=lm.model, custom_llm_provider=provider):
38-
outputs = lm(**inputs, **lm_kwargs, response_format={"type": "json_object"})
42+
try:
43+
response_format = _get_structured_outputs_response_format(signature)
44+
outputs = lm(**inputs, **lm_kwargs, response_format=response_format)
45+
except Exception:
46+
_logger.debug(
47+
"Failed to obtain response using signature-based structured outputs"
48+
" response format: Falling back to default 'json_object' response format."
49+
" Exception: {e}"
50+
)
51+
outputs = lm(**inputs, **lm_kwargs, response_format={"type": "json_object"})
3952
else:
4053
outputs = lm(**inputs, **lm_kwargs)
4154

@@ -303,3 +316,50 @@ def format_signature_fields_for_instructions(role, fields: Dict[str, FieldInfo])
303316
# ", and then ending with the marker for `completed`.")
304317

305318
return "\n\n".join(parts).strip()
319+
320+
321+
def _get_structured_outputs_response_format(signature: SignatureMeta) -> pydantic.BaseModel:
322+
"""
323+
Obtains the LiteLLM / OpenAI `response_format` parameter for generating structured outputs from
324+
an LM request, based on the output fields of the specified DSPy signature.
325+
326+
Args:
327+
signature: The DSPy signature for which to obtain the `response_format` request parameter.
328+
Returns:
329+
A Pydantic model representing the `response_format` parameter for the LM request.
330+
"""
331+
332+
def filter_json_schema_extra(field_name: str, field_info: FieldInfo) -> FieldInfo:
333+
"""
334+
Recursively filter the `json_schema_extra` of a FieldInfo to exclude DSPy internal attributes
335+
(e.g. `__dspy_field_type`) and remove descriptions that are placeholders for the field name.
336+
"""
337+
field_copy = deepcopy(field_info) # Make a copy to avoid mutating the original
338+
339+
# Update `json_schema_extra` for the copied field
340+
if field_copy.json_schema_extra:
341+
field_copy.json_schema_extra = {
342+
key: value
343+
for key, value in field_info.json_schema_extra.items()
344+
if key not in ("desc", "__dspy_field_type")
345+
}
346+
field_desc = field_info.json_schema_extra.get("desc")
347+
if field_desc is not None and field_desc != f"${{{field_name}}}":
348+
field_copy.json_schema_extra["desc"] = field_desc
349+
350+
# Handle nested fields
351+
if hasattr(field_copy.annotation, "__pydantic_model__"):
352+
# Recursively update fields of the nested model
353+
nested_model = field_copy.annotation.__pydantic_model__
354+
updated_fields = {
355+
key: filter_json_schema_extra(key, value) for key, value in nested_model.__fields__.items()
356+
}
357+
# Create a new model with the same name and updated fields
358+
field_copy.annotation = create_model(nested_model.__name__, **updated_fields)
359+
360+
return field_copy
361+
362+
output_pydantic_fields = {
363+
key: (value.annotation, filter_json_schema_extra(key, value)) for key, value in signature.output_fields.items()
364+
}
365+
return create_model("DSPyProgramOutputs", **output_pydantic_fields)

tests/reliability/test_pydantic_models.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from typing import List
33

44
import pydantic
5+
import pytest
56

67
import dspy
78
from tests.reliability.utils import assert_program_output_correct, known_failing_models
@@ -33,22 +34,29 @@ class QA(dspy.Signature):
3334
assert_program_output_correct(
3435
program_input=question,
3536
program_output=answer.comments,
36-
grading_guidelines="The comments should be relevant to the answer",
37+
grading_guidelines=(
38+
"The comments should be relevant to the answer. They don't need to restate the answer explicitly."
39+
),
3740
)
3841
assert answer.certainty >= 0
3942
assert answer.certainty <= 1
4043
assert len(answer.comments) >= 2
4144

4245

43-
def test_color_classification_using_enum():
46+
@pytest.mark.parametrize("module", [dspy.Predict, dspy.ChainOfThought])
47+
def test_color_classification_using_enum(module):
4448
Color = Enum("Color", ["RED", "GREEN", "BLUE"])
4549

4650
class Colorful(dspy.Signature):
4751
text: str = dspy.InputField()
4852
color: Color = dspy.OutputField()
4953

50-
program = dspy.Predict(Colorful)
51-
color = program(text="The sky is blue").color
54+
program = module(Colorful)
55+
# Note: The precise text, including the trailing period, is important here for ensuring that
56+
# the program is correctly extracting the color from the text; previous implementations have
57+
# produced invalid enum responses for "The sky is blue.", but they have produced valid enum
58+
# responses for "The sky is blue" (without the period).
59+
color = program(text="The sky is blue.").color
5260

5361
assert color == Color.BLUE
5462

tests/reliability/utils.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@ def assert_program_output_correct(
3131
grading_guidelines = [grading_guidelines]
3232

3333
with judge_dspy_configuration():
34-
print("GUIDELINES", grading_guidelines)
3534
for guideline_entry in grading_guidelines:
3635
judge_response = _get_judge_program()(
3736
program_input=str(program_input),

tests/test_json_adapter.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
from unittest import mock
2+
3+
import pydantic
4+
import pytest
5+
from pydantic import create_model
6+
7+
import dspy
8+
9+
10+
def test_json_adapter_passes_structured_output_when_supported_by_model():
11+
class OutputField3(pydantic.BaseModel):
12+
subfield1: int = pydantic.Field(description="Int subfield 1", ge=0, le=10)
13+
subfield2: float = pydantic.Field(description="Float subfield 2")
14+
15+
class TestSignature(dspy.Signature):
16+
input1: str = dspy.InputField()
17+
output1: str = dspy.OutputField() # Description intentionally left blank
18+
output2: bool = dspy.OutputField(desc="Boolean output field")
19+
output3: OutputField3 = dspy.OutputField(desc="Nested output field")
20+
output4_unannotated = dspy.OutputField(desc="Unannotated output field")
21+
22+
program = dspy.Predict(TestSignature)
23+
24+
# Configure DSPy to use an OpenAI LM that supports structured outputs
25+
dspy.configure(lm=dspy.LM(model="openai/gpt4o"), adapter=dspy.JSONAdapter())
26+
with mock.patch("litellm.completion") as mock_completion:
27+
program(input1="Test input")
28+
29+
def clean_schema_extra(field_name, field_info):
30+
attrs = dict(field_info.__repr_args__())
31+
if "json_schema_extra" in attrs:
32+
attrs["json_schema_extra"] = {
33+
k: v
34+
for k, v in attrs["json_schema_extra"].items()
35+
if k != "__dspy_field_type" and not (k == "desc" and v == f"${{{field_name}}}")
36+
}
37+
return attrs
38+
39+
mock_completion.assert_called_once()
40+
_, call_kwargs = mock_completion.call_args
41+
response_format = call_kwargs.get("response_format")
42+
assert response_format is not None
43+
assert issubclass(response_format, pydantic.BaseModel)
44+
assert response_format.model_fields.keys() == {"output1", "output2", "output3", "output4_unannotated"}
45+
for field_name in response_format.model_fields:
46+
assert dict(response_format.model_fields[field_name].__repr_args__()) == clean_schema_extra(
47+
field_name=field_name,
48+
field_info=TestSignature.output_fields[field_name],
49+
)
50+
51+
# Configure DSPy to use a model from a fake provider that doesn't support structured outputs
52+
dspy.configure(lm=dspy.LM(model="fakeprovider/fakemodel"), adapter=dspy.JSONAdapter())
53+
with mock.patch("litellm.completion") as mock_completion:
54+
program(input1="Test input")
55+
56+
mock_completion.assert_called_once()
57+
_, call_kwargs = mock_completion.call_args
58+
assert response_format not in call_kwargs
59+
60+
61+
def test_json_adapter_falls_back_when_structured_outputs_fails():
62+
class TestSignature(dspy.Signature):
63+
input1: str = dspy.InputField()
64+
output1: str = dspy.OutputField(desc="String output field")
65+
66+
dspy.configure(lm=dspy.LM(model="openai/gpt4o"), adapter=dspy.JSONAdapter())
67+
program = dspy.Predict(TestSignature)
68+
with mock.patch("litellm.completion") as mock_completion:
69+
mock_completion.side_effect = [Exception("Bad structured outputs!"), mock_completion.return_value]
70+
program(input1="Test input")
71+
assert mock_completion.call_count == 2
72+
_, first_call_kwargs = mock_completion.call_args_list[0]
73+
assert issubclass(first_call_kwargs.get("response_format"), pydantic.BaseModel)
74+
_, second_call_kwargs = mock_completion.call_args_list[1]
75+
assert second_call_kwargs.get("response_format") == {"type": "json_object"}
76+
77+
78+
def test_json_adapter_with_structured_outputs_does_not_mutate_original_signature():
79+
class OutputField3(pydantic.BaseModel):
80+
subfield1: int = pydantic.Field(description="Int subfield 1")
81+
subfield2: float = pydantic.Field(description="Float subfield 2")
82+
83+
class TestSignature(dspy.Signature):
84+
input1: str = dspy.InputField()
85+
output1: str = dspy.OutputField() # Description intentionally left blank
86+
output2: bool = dspy.OutputField(desc="Boolean output field")
87+
output3: OutputField3 = dspy.OutputField(desc="Nested output field")
88+
output4_unannotated = dspy.OutputField(desc="Unannotated output field")
89+
90+
dspy.configure(lm=dspy.LM(model="openai/gpt4o"), adapter=dspy.JSONAdapter())
91+
program = dspy.Predict(TestSignature)
92+
with mock.patch("litellm.completion"):
93+
program(input1="Test input")
94+
95+
assert program.signature.output_fields == TestSignature.output_fields

0 commit comments

Comments
 (0)