
Commit 2f04758

feat(models): use tool for litellm structured_output when supports_response_schema=false (#957)
1 parent: 92da544

File tree

3 files changed: +136 −31 lines

  src/strands/models/litellm.py
  tests/strands/models/test_litellm.py
  tests_integ/models/test_model_litellm.py
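In short, structured_output on the LiteLLM model now falls back to tool calling when litellm reports that the underlying model does not support response_format, instead of raising ValueError. A minimal caller-side sketch of what this enables (the Person model, the commented-out model construction, and the agent wiring are illustrative assumptions, not part of the commit):

import pydantic

from strands import Agent  # import path assumed from the Strands SDK


class Person(pydantic.BaseModel):
    """Hypothetical output model, used only for illustration."""

    name: str
    age: int


# model = ...   # construction of the litellm-backed model omitted here
# agent = Agent(model=model)
# With this commit, the call below succeeds even when litellm's
# supports_response_schema(model_id) returns False; the SDK switches to the
# tool-calling fallback instead of raising "Model does not support response_format".
# person = agent.structured_output(Person, "Generate a person named John, age 30")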

src/strands/models/litellm.py

Lines changed: 58 additions & 26 deletions
@@ -13,6 +13,7 @@
 from pydantic import BaseModel
 from typing_extensions import Unpack, override
 
+from ..tools import convert_pydantic_to_tool_spec
 from ..types.content import ContentBlock, Messages
 from ..types.exceptions import ContextWindowOverflowException
 from ..types.streaming import StreamEvent
@@ -202,6 +203,10 @@ async def structured_output(
     ) -> AsyncGenerator[dict[str, Union[T, Any]], None]:
         """Get structured output from the model.
 
+        Some models do not support native structured output via response_format.
+        In cases of proxies, we may not have a way to determine support, so we
+        fallback to using tool calling to achieve structured output.
+
         Args:
             output_model: The output model to use for the agent.
             prompt: The prompt messages to use for the agent.
@@ -211,42 +216,69 @@ async def structured_output(
         Yields:
             Model events with the last being the structured output.
         """
-        supports_schema = supports_response_schema(self.get_config()["model_id"])
+        if supports_response_schema(self.get_config()["model_id"]):
+            logger.debug("structuring output using response schema")
+            result = await self._structured_output_using_response_schema(output_model, prompt, system_prompt)
+        else:
+            logger.debug("model does not support response schema, structuring output using tool approach")
+            result = await self._structured_output_using_tool(output_model, prompt, system_prompt)
+
+        yield {"output": result}
+
+    async def _structured_output_using_response_schema(
+        self, output_model: Type[T], prompt: Messages, system_prompt: Optional[str] = None
+    ) -> T:
+        """Get structured output using native response_format support."""
+        response = await litellm.acompletion(
+            **self.client_args,
+            model=self.get_config()["model_id"],
+            messages=self.format_request(prompt, system_prompt=system_prompt)["messages"],
+            response_format=output_model,
+        )
 
-        # If the provider does not support response schemas, we cannot reliably parse structured output.
-        # In that case we must not call the provider and must raise the documented ValueError.
-        if not supports_schema:
-            raise ValueError("Model does not support response_format")
+        if len(response.choices) > 1:
+            raise ValueError("Multiple choices found in the response.")
+        if not response.choices or response.choices[0].finish_reason != "tool_calls":
+            raise ValueError("No tool_calls found in response")
 
-        # For providers that DO support response schemas, call litellm and map context-window errors.
+        choice = response.choices[0]
         try:
-            response = await litellm.acompletion(
-                **self.client_args,
-                model=self.get_config()["model_id"],
-                messages=self.format_request(prompt, system_prompt=system_prompt)["messages"],
-                response_format=output_model,
-            )
+            # Parse the message content as JSON
+            tool_call_data = json.loads(choice.message.content)
+            # Instantiate the output model with the parsed data
+            return output_model(**tool_call_data)
         except ContextWindowExceededError as e:
             logger.warning("litellm client raised context window overflow in structured_output")
             raise ContextWindowOverflowException(e) from e
+        except (json.JSONDecodeError, TypeError, ValueError) as e:
+            raise ValueError(f"Failed to parse or load content into model: {e}") from e
+
+    async def _structured_output_using_tool(
+        self, output_model: Type[T], prompt: Messages, system_prompt: Optional[str] = None
+    ) -> T:
+        """Get structured output using tool calling fallback."""
+        tool_spec = convert_pydantic_to_tool_spec(output_model)
+        request = self.format_request(prompt, [tool_spec], system_prompt, cast(ToolChoice, {"any": {}}))
+        args = {**self.client_args, **request, "stream": False}
+        response = await litellm.acompletion(**args)
 
         if len(response.choices) > 1:
             raise ValueError("Multiple choices found in the response.")
+        if not response.choices or response.choices[0].finish_reason != "tool_calls":
+            raise ValueError("No tool_calls found in response")
 
-        # Find the first choice with tool_calls
-        for choice in response.choices:
-            if choice.finish_reason == "tool_calls":
-                try:
-                    # Parse the tool call content as JSON
-                    tool_call_data = json.loads(choice.message.content)
-                    # Instantiate the output model with the parsed data
-                    yield {"output": output_model(**tool_call_data)}
-                    return
-                except (json.JSONDecodeError, TypeError, ValueError) as e:
-                    raise ValueError(f"Failed to parse or load content into model: {e}") from e
-
-        # If no tool_calls found, raise an error
-        raise ValueError("No tool_calls found in response")
+        choice = response.choices[0]
+        try:
+            # Parse the tool call content as JSON
+            tool_call = choice.message.tool_calls[0]
+            tool_call_data = json.loads(tool_call.function.arguments)
+            # Instantiate the output model with the parsed data
+            return output_model(**tool_call_data)
+        except ContextWindowExceededError as e:
+            logger.warning("litellm client raised context window overflow in structured_output")
+            raise ContextWindowOverflowException(e) from e
+        except (json.JSONDecodeError, TypeError, ValueError) as e:
+            raise ValueError(f"Failed to parse or load content into model: {e}") from e
 
     def _apply_proxy_prefix(self) -> None:
         """Apply litellm_proxy/ prefix to model_id when use_litellm_proxy is True.

tests/strands/models/test_litellm.py

Lines changed: 17 additions & 5 deletions
@@ -292,15 +292,27 @@ async def test_structured_output(litellm_acompletion, model, test_output_model_c
 
 
 @pytest.mark.asyncio
-async def test_structured_output_unsupported_model(litellm_acompletion, model, test_output_model_cls):
+async def test_structured_output_unsupported_model(litellm_acompletion, model, test_output_model_cls, alist):
     messages = [{"role": "user", "content": [{"text": "Generate a person"}]}]
 
+    mock_tool_call = unittest.mock.Mock()
+    mock_tool_call.function.arguments = '{"name": "John", "age": 30}'
+
+    mock_choice = unittest.mock.Mock()
+    mock_choice.finish_reason = "tool_calls"
+    mock_choice.message.tool_calls = [mock_tool_call]
+    mock_response = unittest.mock.Mock()
+    mock_response.choices = [mock_choice]
+
+    litellm_acompletion.return_value = mock_response
+
     with unittest.mock.patch.object(strands.models.litellm, "supports_response_schema", return_value=False):
-        with pytest.raises(ValueError, match="Model does not support response_format"):
-            stream = model.structured_output(test_output_model_cls, messages)
-            await stream.__anext__()
+        stream = model.structured_output(test_output_model_cls, messages)
+        events = await alist(stream)
+        tru_result = events[-1]
 
-    litellm_acompletion.assert_not_called()
+    exp_result = {"output": test_output_model_cls(name="John", age=30)}
+    assert tru_result == exp_result
 
 
 def test_config_validation_warns_on_unknown_keys(litellm_acompletion, captured_warnings):
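The rewritten test drains the structured_output async generator with an alist fixture and asserts on the final event, which carries the structured output. The fixture itself is defined elsewhere in the test suite; a plausible implementation of such a helper (an assumption, shown only so the test reads in isolation) would be:

async def alist(async_gen):
    """Collect every item yielded by an async generator into a list."""
    return [item async for item in async_gen]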

tests_integ/models/test_model_litellm.py

Lines changed: 61 additions & 0 deletions
@@ -1,3 +1,5 @@
+import unittest.mock
+
 import pydantic
 import pytest
 
@@ -40,6 +42,37 @@ class Weather(pydantic.BaseModel):
     return Weather(time="12:00", weather="sunny")
 
 
+class Location(pydantic.BaseModel):
+    """Location information."""
+
+    city: str = pydantic.Field(description="The city name")
+    country: str = pydantic.Field(description="The country name")
+
+
+class WeatherCondition(pydantic.BaseModel):
+    """Weather condition details."""
+
+    condition: str = pydantic.Field(description="The weather condition (e.g., 'sunny', 'rainy', 'cloudy')")
+    temperature: int = pydantic.Field(description="Temperature in Celsius")
+
+
+class NestedWeather(pydantic.BaseModel):
+    """Weather report with nested location and condition information."""
+
+    time: str = pydantic.Field(description="The time in HH:MM format")
+    location: Location = pydantic.Field(description="Location information")
+    weather: WeatherCondition = pydantic.Field(description="Weather condition details")
+
+
+@pytest.fixture
+def nested_weather():
+    return NestedWeather(
+        time="12:00",
+        location=Location(city="New York", country="USA"),
+        weather=WeatherCondition(condition="sunny", temperature=25),
+    )
+
+
 @pytest.fixture
 def yellow_color():
     class Color(pydantic.BaseModel):
@@ -134,3 +167,31 @@ def test_structured_output_multi_modal_input(agent, yellow_img, yellow_color):
     tru_color = agent.structured_output(type(yellow_color), content)
     exp_color = yellow_color
     assert tru_color == exp_color
+
+
+def test_structured_output_unsupported_model(model, nested_weather):
+    # Mock supports_response_schema to return False to test fallback mechanism
+    with (
+        unittest.mock.patch.multiple(
+            "strands.models.litellm",
+            supports_response_schema=unittest.mock.DEFAULT,
+        ) as mocks,
+        unittest.mock.patch.object(
+            model, "_structured_output_using_tool", wraps=model._structured_output_using_tool
+        ) as mock_tool,
+        unittest.mock.patch.object(
+            model, "_structured_output_using_response_schema", wraps=model._structured_output_using_response_schema
+        ) as mock_schema,
+    ):
+        mocks["supports_response_schema"].return_value = False
+
+        # Test that structured output still works via tool calling fallback
+        agent = Agent(model=model)
+        prompt = "The time is 12:00 in New York, USA and the weather is sunny with temperature 25 degrees Celsius"
+        tru_weather = agent.structured_output(NestedWeather, prompt)
+        exp_weather = nested_weather
+        assert tru_weather == exp_weather
+
+        # Verify that the tool method was called and schema method was not
+        mock_tool.assert_called_once()
+        mock_schema.assert_not_called()
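The integration test patches the two private methods with wraps= so the real implementations still run while their call counts remain observable. A minimal illustration of that spy pattern on a hypothetical object (not taken from the repository):

import unittest.mock


class Greeter:
    def hello(self, name):
        return f"hello {name}"


greeter = Greeter()
with unittest.mock.patch.object(greeter, "hello", wraps=greeter.hello) as spy:
    assert greeter.hello("world") == "hello world"  # original behaviour still executes
    spy.assert_called_once_with("world")            # but the call is recorded for assertions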
