Skip to content

Commit 43e241e

Browse files
Fix issues in BAMLAdapter (#8654)
* init * increment * fix various issues in BAMLAdapter * lint
1 parent 84a4b7e commit 43e241e

File tree

2 files changed

+74
-178
lines changed

2 files changed

+74
-178
lines changed

dspy/adapters/baml_adapter.py

Lines changed: 67 additions & 117 deletions
Original file line numberDiff line numberDiff line change
@@ -17,141 +17,117 @@
1717
COMMENT_SYMBOL = "#"
1818

1919

20-
def _render_type_str(annotation: Any, _depth: int = 0, indent: int = 0) -> str:
20+
def _render_type_str(
21+
annotation: Any,
22+
depth: int = 0,
23+
indent: int = 0,
24+
seen_models: set[type] | None = None,
25+
) -> str:
2126
"""Recursively renders a type annotation into a simplified string.
2227
2328
Args:
2429
annotation: The type annotation to render
25-
_depth: Current recursion depth (prevents infinite recursion)
30+
depth: Current recursion depth (incremented on each nested call; no limit is enforced in this function)
2631
indent: Current indentation level for nested structures
2732
"""
28-
max_depth = 10
29-
if _depth > max_depth: # Prevent excessive recursion
30-
return f"<max depth of {max_depth} exceeded>"
33+
# Non-nested types
34+
if annotation is str:
35+
return "string"
36+
if annotation is int:
37+
return "int"
38+
if annotation is float:
39+
return "float"
40+
if annotation is bool:
41+
return "boolean"
42+
if inspect.isclass(annotation) and issubclass(annotation, BaseModel):
43+
return _build_simplified_schema(annotation, indent, seen_models)
3144

3245
try:
3346
origin = get_origin(annotation)
3447
args = get_args(annotation)
3548
except Exception:
3649
return str(annotation)
3750

38-
# Handle Optional[T] or T | None
51+
# Optional[T] or T | None
3952
if origin in (types.UnionType, Union):
4053
non_none_args = [arg for arg in args if arg is not type(None)]
4154
# Render the non-None part of the union
42-
type_render = " or ".join([_render_type_str(arg, _depth + 1, indent) for arg in non_none_args])
43-
# Add 'or null' if None was part of the union
55+
type_render = " or ".join([_render_type_str(arg, depth + 1, indent) for arg in non_none_args])
56+
# Add "or null" if None was part of the union
4457
if len(non_none_args) < len(args):
4558
return f"{type_render} or null"
4659
return type_render
4760

48-
# Base types
49-
if annotation is str:
50-
return "string"
51-
if annotation is int:
52-
return "int"
53-
if annotation is float:
54-
return "float"
55-
if annotation is bool:
56-
return "boolean"
57-
58-
# Composite types
61+
# Literal[T1, T2, ...]
5962
if origin is Literal:
6063
return " or ".join(f'"{arg}"' for arg in args)
64+
65+
# list[T]
6166
if origin is list:
6267
# For Pydantic models in lists, use bracket notation
6368
inner_type = args[0]
6469
if inspect.isclass(inner_type) and issubclass(inner_type, BaseModel):
6570
# Build inner schema - the Pydantic model inside should use indent level for array contents
66-
inner_schema = _build_simplified_schema(inner_type, indent + 1)
71+
inner_schema = _build_simplified_schema(inner_type, indent + 1, seen_models)
6772
# Format with proper bracket notation and indentation
6873
current_indent = " " * indent
6974
return f"[\n{inner_schema}\n{current_indent}]"
7075
else:
71-
return f"{_render_type_str(inner_type, _depth + 1, indent)}[]"
72-
if origin is dict:
73-
return f"dict[{_render_type_str(args[0], _depth + 1, indent)}, {_render_type_str(args[1], _depth + 1, indent)}]"
76+
return f"{_render_type_str(inner_type, depth + 1, indent)}[]"
7477

75-
# Pydantic models (we'll recurse in the main function)
76-
if inspect.isclass(annotation) and issubclass(annotation, BaseModel):
77-
try:
78-
return _build_simplified_schema(annotation, indent)
79-
except Exception:
80-
return f"<{annotation.__name__}>"
78+
# dict[T1, T2]
79+
if origin is dict:
80+
return f"dict[{_render_type_str(args[0], depth + 1, indent)}, {_render_type_str(args[1], depth + 1, indent)}]"
8181

82-
# Fallback
82+
# fallback
8383
if hasattr(annotation, "__name__"):
8484
return annotation.__name__
8585
return str(annotation)
8686

8787

88-
def _build_simplified_schema(model: type[BaseModel], indent: int = 0, _seen: set[type] | None = None) -> str:
88+
def _build_simplified_schema(
89+
pydantic_model: type[BaseModel],
90+
indent: int = 0,
91+
seen_models: set[type] | None = None,
92+
) -> str:
8993
"""Builds a simplified, human-readable schema from a Pydantic model.
9094
9195
Args:
92-
model: The Pydantic model to build schema for
96+
pydantic_model: The Pydantic model to build schema for
9397
indent: Current indentation level
94-
_seen: Set to track visited models (prevents infinite recursion)
98+
seen_models: Set to track visited pydantic models (prevents infinite recursion)
9599
"""
96-
if _seen is None:
97-
_seen = set()
100+
seen_models = seen_models or set()
98101

99-
if model in _seen:
100-
return f"<circular reference to {model.__name__}>"
102+
if pydantic_model in seen_models:
103+
raise ValueError("BAMLAdapter cannot handle recursive pydantic models, please use a different adapter.")
101104

102-
_seen.add(model)
105+
# Record `pydantic_model` in `seen_models` so that encountering it again further down is detected as a recursive reference.
106+
seen_models.add(pydantic_model)
103107

104-
try:
105-
lines = []
106-
current_indent = " " * indent
107-
next_indent = " " * (indent + 1)
108-
109-
lines.append(f"{current_indent}{{")
110-
111-
fields = model.model_fields
112-
if not fields:
113-
lines.append(f"{next_indent}{COMMENT_SYMBOL} No fields defined")
114-
for name, field in fields.items():
115-
if field.description:
116-
lines.append(f"{next_indent}{COMMENT_SYMBOL} {field.description}")
117-
elif field.alias and field.alias != name:
118-
# If there's an alias but no description, show the alias as a comment
119-
lines.append(f"{next_indent}{COMMENT_SYMBOL} alias: {field.alias}")
120-
121-
# Check for a nested Pydantic model
122-
field_type_to_render = field.annotation
123-
124-
# Unpack Optional[T] to get T
125-
origin = get_origin(field_type_to_render)
126-
if origin in (types.UnionType, Union):
127-
non_none_args = [arg for arg in get_args(field_type_to_render) if arg is not type(None)]
128-
if len(non_none_args) == 1:
129-
field_type_to_render = non_none_args[0]
130-
131-
# Unpack list[T] to get T
132-
origin = get_origin(field_type_to_render)
133-
if origin is list:
134-
field_type_to_render = get_args(field_type_to_render)[0]
135-
136-
if inspect.isclass(field_type_to_render) and issubclass(field_type_to_render, BaseModel):
137-
# Recursively build schema for nested models with circular reference protection
138-
nested_schema = _build_simplified_schema(field_type_to_render, indent + 1, _seen)
139-
rendered_type = _render_type_str(field.annotation, indent=indent + 1).replace(
140-
field_type_to_render.__name__, nested_schema
141-
)
142-
else:
143-
rendered_type = _render_type_str(field.annotation, indent=indent + 1)
144-
145-
line = f"{next_indent}{name}: {rendered_type},"
146-
147-
lines.append(line)
148-
149-
lines.append(f"{current_indent}}}")
150-
return "\n".join(lines)
151-
except Exception as e:
152-
return f"<error building schema for {model.__name__}: {e}>"
153-
finally:
154-
_seen.discard(model)
108+
lines = []
109+
current_indent = " " * indent
110+
next_indent = " " * (indent + 1)
111+
112+
lines.append(f"{current_indent}{{")
113+
114+
fields = pydantic_model.model_fields
115+
if not fields:
116+
lines.append(f"{next_indent}{COMMENT_SYMBOL} No fields defined")
117+
for name, field in fields.items():
118+
if field.description:
119+
lines.append(f"{next_indent}{COMMENT_SYMBOL} {field.description}")
120+
elif field.alias and field.alias != name:
121+
# If there's an alias but no description, show the alias as a comment
122+
lines.append(f"{next_indent}{COMMENT_SYMBOL} alias: {field.alias}")
123+
124+
rendered_type = _render_type_str(field.annotation, indent=indent + 1, seen_models=seen_models)
125+
line = f"{next_indent}{name}: {rendered_type},"
126+
127+
lines.append(line)
128+
129+
lines.append(f"{current_indent}}}")
130+
return "\n".join(lines)
155131

156132

157133
class BAMLAdapter(JSONAdapter):
@@ -176,6 +152,7 @@ class PatientAddress(BaseModel):
176152
street: str
177153
city: str
178154
country: Literal["US", "CA"]
155+
179156
class PatientDetails(BaseModel):
180157
name: str = Field(description="Full name of the patient.")
181158
age: int
@@ -245,41 +222,14 @@ def format_field_structure(self, signature: type[Signature]) -> str:
245222
if signature.output_fields:
246223
for name, field in signature.output_fields.items():
247224
field_type = field.annotation
248-
main_type = field_type
249-
250-
# Find the core type if it's wrapped in Optional or Union
251-
origin = get_origin(field_type)
252-
if origin in (types.UnionType, Union):
253-
non_none_args = [arg for arg in get_args(field_type) if arg is not type(None)]
254-
if len(non_none_args) == 1:
255-
main_type = non_none_args[0]
256-
257225
sections.append(f"[[ ## {name} ## ]]")
258-
259-
if inspect.isclass(main_type) and issubclass(main_type, BaseModel):
260-
# We have a pydantic model, so build the simplified schema for it.
261-
schema_str = _build_simplified_schema(main_type)
262-
sections.append(schema_str)
263-
else:
264-
# Handle non-pydantic or primitive types simply
265-
type_str = _render_type_str(field_type, indent=0)
266-
sections.append(f"Output field `{name}` should be of type: {type_str}")
267-
268-
sections.append("") # Empty line after each output
226+
sections.append(f"Output field `{name}` should be of type: {_render_type_str(field_type, indent=0)}\n")
269227

270228
# Add completed section
271229
sections.append("[[ ## completed ## ]]")
272230

273231
return "\n".join(sections)
274232

275-
def format_task_description(self, signature: type[Signature]) -> str:
276-
"""Format the task description for the system message."""
277-
import textwrap
278-
279-
instructions = textwrap.dedent(signature.instructions)
280-
objective = ("\n" + " " * 8).join([""] + instructions.splitlines())
281-
return f"In adhering to this structure, your objective is: {objective}"
282-
283233
def format_user_message_content(
284234
self,
285235
signature: type[Signature],

tests/adapters/test_baml_adapter.py

Lines changed: 7 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -40,14 +40,9 @@ class ImageWrapper(pydantic.BaseModel):
4040
tag: list[str]
4141

4242

43-
class CircularModelA(pydantic.BaseModel):
43+
class CircularModel(pydantic.BaseModel):
4444
name: str
45-
related_b: "CircularModelB | None" = None
46-
47-
48-
class CircularModelB(pydantic.BaseModel):
49-
id: int
50-
related_a: CircularModelA | None = None
45+
field: "CircularModel"
5146

5247

5348
def test_baml_adapter_basic_schema_generation():
@@ -143,20 +138,18 @@ class TestSignature(dspy.Signature):
143138
assert "metadata: dict[string, string]," in schema
144139

145140

146-
def test_baml_adapter_prevents_circular_references():
141+
def test_baml_adapter_raise_error_on_circular_references():
147142
"""Test that circular references raise a ValueError."""
148143

149144
class TestSignature(dspy.Signature):
150145
input: str = dspy.InputField()
151-
circular: CircularModelA = dspy.OutputField()
146+
circular: CircularModel = dspy.OutputField()
152147

153148
adapter = BAMLAdapter()
154-
schema = adapter.format_field_structure(TestSignature)
149+
with pytest.raises(ValueError) as error:
150+
adapter.format_field_structure(TestSignature)
155151

156-
# Should not cause infinite recursion and should handle ForwardRef gracefully
157-
assert "name: string," in schema
158-
# The actual output shows ForwardRef handling, which is acceptable
159-
assert "ForwardRef" in schema or "<circular reference" in schema
152+
assert "BAMLAdapter cannot handle recursive pydantic models" in str(error.value)
160153

161154

162155
def test_baml_adapter_formats_pydantic_inputs_as_clean_json():
@@ -227,53 +220,6 @@ class TestSignature(dspy.Signature):
227220
pass
228221

229222

230-
def test_baml_adapter_inherits_json_parsing_behavior():
231-
"""Test that BAMLAdapter maintains JSONAdapter's parsing compatibility."""
232-
233-
class TestSignature(dspy.Signature):
234-
question: str = dspy.InputField()
235-
answer: str = dspy.OutputField()
236-
237-
baml_adapter = BAMLAdapter()
238-
json_adapter = dspy.JSONAdapter()
239-
240-
# Both should parse the same JSON response identically
241-
completion = '{"answer": "Paris"}'
242-
243-
baml_result = baml_adapter.parse(TestSignature, completion)
244-
json_result = json_adapter.parse(TestSignature, completion)
245-
246-
assert baml_result == json_result == {"answer": "Paris"}
247-
248-
249-
def test_baml_adapter_parse_complex_pydantic_models():
250-
"""Test parsing JSON into complex Pydantic models."""
251-
252-
class TestSignature(dspy.Signature):
253-
input: str = dspy.InputField()
254-
patient: PatientDetails = dspy.OutputField()
255-
256-
adapter = BAMLAdapter()
257-
258-
completion = """{"patient": {
259-
"name": "John Smith",
260-
"age": 42,
261-
"address": {
262-
"street": "456 Oak St",
263-
"city": "Springfield",
264-
"country": "US"
265-
}
266-
}}"""
267-
268-
result = adapter.parse(TestSignature, completion)
269-
270-
assert isinstance(result["patient"], PatientDetails)
271-
assert result["patient"].name == "John Smith"
272-
assert result["patient"].age == 42
273-
assert isinstance(result["patient"].address, PatientAddress)
274-
assert result["patient"].address.street == "456 Oak St"
275-
276-
277223
def test_baml_adapter_raises_on_missing_fields():
278224
"""Test that missing required fields raise appropriate errors."""
279225

0 commit comments

Comments
 (0)