diff --git a/src/strands/models/llamaapi.py b/src/strands/models/llamaapi.py
index 421b06e52..267b890d3 100644
--- a/src/strands/models/llamaapi.py
+++ b/src/strands/models/llamaapi.py
@@ -12,6 +12,7 @@
 
 import llama_api_client
 from llama_api_client import LlamaAPIClient
+from llama_api_client.types.chat.completion_create_params import ResponseFormat
 from pydantic import BaseModel
 from typing_extensions import TypedDict, Unpack, override
 
@@ -406,7 +407,7 @@ async def stream(
         logger.debug("finished streaming response from model")
 
     @override
-    def structured_output(
+    async def structured_output(
         self, output_model: Type[T], prompt: Messages, system_prompt: Optional[str] = None, **kwargs: Any
     ) -> AsyncGenerator[dict[str, Union[T, Any]], None]:
         """Get structured output from the model.
@@ -419,20 +420,30 @@ def structured_output(
 
         Yields:
             Model events with the last being the structured output.
-
-        Raises:
-            NotImplementedError: Structured output is not currently supported for LlamaAPI models.
         """
-        # response_format: ResponseFormat = {
-        #     "type": "json_schema",
-        #     "json_schema": {
-        #         "name": output_model.__name__,
-        #         "schema": output_model.model_json_schema(),
-        #     },
-        # }
-        # response = self.client.chat.completions.create(
-        #     model=self.config["model_id"],
-        #     messages=self.format_request(prompt)["messages"],
-        #     response_format=response_format,
-        # )
-        raise NotImplementedError("Strands sdk-python does not implement this in the Llama API Preview.")
+        response_format: ResponseFormat = {
+            "type": "json_schema",
+            "json_schema": {
+                "name": output_model.__name__,
+                "schema": output_model.model_json_schema(),
+            },
+        }
+
+        try:
+            response = self.client.chat.completions.create(
+                model=self.config["model_id"],
+                messages=self.format_request(prompt)["messages"],
+                response_format=response_format,
+            )
+
+            content = response.completion_message.content
+            if content is None:
+                raise ValueError("No content found in Llama API response")
+            elif not isinstance(content, str):
+                content = content.text
+
+            output_response = json.loads(content)
+            yield {"output": output_model(**output_response)}
+
+        except Exception as e:
+            raise ValueError(f"Llama API structured output error: {e}") from e
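Reviewer note (not part of the patch): a minimal sketch of calling the new async structured_output directly on the model. The Person model, the api_key placeholder, and the model_id value are illustrative assumptions; the message shape follows the strands Messages format that format_request consumes.

import asyncio

import pydantic

from strands.models.llamaapi import LlamaAPIModel


class Person(pydantic.BaseModel):
    name: str
    age: int


async def main() -> None:
    # Illustrative configuration: substitute a real API key and any Llama API model id.
    model = LlamaAPIModel(
        client_args={"api_key": "<api-key>"},
        model_id="Llama-4-Maverick-17B-128E-Instruct-FP8",
    )
    messages = [{"role": "user", "content": [{"text": "John is 30 years old."}]}]

    # structured_output is now an async generator; the final event carries
    # the parsed pydantic instance under the "output" key.
    events = [event async for event in model.structured_output(Person, messages)]
    print(events[-1]["output"])  # Person(name='John', age=30)


asyncio.run(main())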
diff --git a/tests/strands/models/test_llamaapi.py b/tests/strands/models/test_llamaapi.py
index 309dac2e9..fc9277a33 100644
--- a/tests/strands/models/test_llamaapi.py
+++ b/tests/strands/models/test_llamaapi.py
@@ -1,6 +1,7 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates
 import unittest.mock
 
+import pydantic
 import pytest
 
 import strands
@@ -35,6 +36,15 @@ def system_prompt():
     return "s1"
 
 
+@pytest.fixture
+def test_output_model_cls():
+    class TestOutputModel(pydantic.BaseModel):
+        name: str
+        age: int
+
+    return TestOutputModel
+
+
 def test__init__model_configs(llamaapi_client, model_id):
     _ = llamaapi_client
 
@@ -361,3 +371,16 @@ def test_format_chunk_other(model):
 
     with pytest.raises(RuntimeError, match="chunk_type=<unknown> | unknown type"):
         model.format_chunk(event)
+
+
+@pytest.mark.asyncio
+async def test_structured_output(llamaapi_client, model, messages, test_output_model_cls, alist):
+    mock_api_response = unittest.mock.Mock()
+    mock_api_response.completion_message.content.text = '{"name": "John", "age": 30}'
+
+    llamaapi_client.chat.completions.create = unittest.mock.Mock(return_value=mock_api_response)
+
+    stream = model.structured_output(test_output_model_cls, messages)
+    events = await alist(stream)
+
+    assert events[-1] == {"output": test_output_model_cls(name="John", age=30)}
diff --git a/tests_integ/models/test_model_llamaapi.py b/tests_integ/models/test_model_llamaapi.py
index b36a63a28..1f1a3771c 100644
--- a/tests_integ/models/test_model_llamaapi.py
+++ b/tests_integ/models/test_model_llamaapi.py
@@ -1,6 +1,7 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates
 import os
 
+import pydantic
 import pytest
 
 import strands
@@ -40,8 +41,32 @@ def agent(model, tools):
     return Agent(model=model, tools=tools)
 
 
+@pytest.fixture
+def weather():
+    class Weather(pydantic.BaseModel):
+        """Extracts the time and weather from the user's message with the exact strings."""
+
+        time: str
+        weather: str
+
+    return Weather(time="12:00", weather="sunny")
+
+
 def test_agent(agent):
     result = agent("What is the time and weather in New York?")
     text = result.message["content"][0]["text"].lower()
 
     assert all(string in text for string in ["12:00", "sunny"])
+
+
+def test_agent_structured_output(agent, weather):
+    tru_weather = agent.structured_output(type(weather), "The time is 12:00 and the weather is sunny")
+    exp_weather = weather
+    assert tru_weather == exp_weather
+
+
+@pytest.mark.asyncio
+async def test_agent_structured_output_async(agent, weather):
+    tru_weather = await agent.structured_output_async(type(weather), "The time is 12:00 and the weather is sunny")
+    exp_weather = weather
+    assert tru_weather == exp_weather
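Reviewer note (not part of the patch): a sketch of the response_format payload the new code sends to the Llama API. TestOutputModel mirrors the unit-test fixture above; the schema dict is whatever pydantic's model_json_schema() emits for it.

import json

import pydantic


class TestOutputModel(pydantic.BaseModel):
    name: str
    age: int


# The same payload structured_output builds before calling chat.completions.create.
response_format = {
    "type": "json_schema",
    "json_schema": {
        "name": TestOutputModel.__name__,
        "schema": TestOutputModel.model_json_schema(),
    },
}
print(json.dumps(response_format, indent=2))
# "schema" comes out roughly as:
# {"properties": {"name": {"title": "Name", "type": "string"},
#                 "age": {"title": "Age", "type": "integer"}},
#  "required": ["name", "age"], "title": "TestOutputModel", "type": "object"}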