feat(llamaapi): enable structured_output support #724

Open · wants to merge 1 commit into main
45 changes: 28 additions & 17 deletions src/strands/models/llamaapi.py
@@ -12,6 +12,7 @@

import llama_api_client
from llama_api_client import LlamaAPIClient
from llama_api_client.types.chat.completion_create_params import ResponseFormat
from pydantic import BaseModel
from typing_extensions import TypedDict, Unpack, override

@@ -406,7 +407,7 @@ async def stream(
logger.debug("finished streaming response from model")

@override
def structured_output(
async def structured_output(
self, output_model: Type[T], prompt: Messages, system_prompt: Optional[str] = None, **kwargs: Any
) -> AsyncGenerator[dict[str, Union[T, Any]], None]:
"""Get structured output from the model.
@@ -419,20 +420,30 @@ def structured_output(

Yields:
Model events with the last being the structured output.

Raises:
NotImplementedError: Structured output is not currently supported for LlamaAPI models.
"""
# response_format: ResponseFormat = {
# "type": "json_schema",
# "json_schema": {
# "name": output_model.__name__,
# "schema": output_model.model_json_schema(),
# },
# }
# response = self.client.chat.completions.create(
# model=self.config["model_id"],
# messages=self.format_request(prompt)["messages"],
# response_format=response_format,
# )
raise NotImplementedError("Strands sdk-python does not implement this in the Llama API Preview.")
# Constrain the response to the output model's JSON schema.
response_format: ResponseFormat = {
    "type": "json_schema",
    "json_schema": {
        "name": output_model.__name__,
        "schema": output_model.model_json_schema(),
    },
}

try:
    response = self.client.chat.completions.create(
        model=self.config["model_id"],
        messages=self.format_request(prompt)["messages"],
        response_format=response_format,
    )

    # The API may return the content as a plain string or as a text block.
    content = response.completion_message.content
    if content is None:
        raise ValueError("No content found in Llama API response")
    elif not isinstance(content, str):
        content = content.text

    # Parse the JSON payload and validate it against the output model.
    output_response = json.loads(content)
    yield {"output": output_model(**output_response)}

except Exception as e:
    raise ValueError(f"Llama API structured output error: {e}") from e
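For context, here is a minimal usage sketch of the new `structured_output` support. It is not part of the diff: the `LlamaAPIModel` import path, the message shape, and the model id are assumptions based on the surrounding strands code.

```python
import asyncio

from pydantic import BaseModel

from strands.models.llamaapi import LlamaAPIModel  # assumed import path/class name


class Person(BaseModel):
    name: str
    age: int


async def main() -> None:
    # Hypothetical model id; substitute one available to your account.
    model = LlamaAPIModel(model_id="Llama-4-Maverick-17B-128E-Instruct-FP8")
    messages = [{"role": "user", "content": [{"text": "John is 30 years old."}]}]

    # structured_output is an async generator; the final event carries the parsed model.
    async for event in model.structured_output(Person, messages):
        if "output" in event:
            print(event["output"])  # e.g. Person(name='John', age=30)


asyncio.run(main())
```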
23 changes: 23 additions & 0 deletions tests/strands/models/test_llamaapi.py
@@ -1,6 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
import unittest.mock

import pydantic
import pytest

import strands
@@ -35,6 +36,15 @@ def system_prompt():
return "s1"


@pytest.fixture
def test_output_model_cls():
class TestOutputModel(pydantic.BaseModel):
name: str
age: int

return TestOutputModel


def test__init__model_configs(llamaapi_client, model_id):
_ = llamaapi_client

@@ -361,3 +371,16 @@ def test_format_chunk_other(model):

with pytest.raises(RuntimeError, match="chunk_type=<other> | unknown type"):
model.format_chunk(event)


@pytest.mark.asyncio
async def test_structured_output(llamaapi_client, model, messages, test_output_model_cls, alist):
mock_api_response = unittest.mock.Mock()
mock_api_response.completion_message.content.text = '{"name": "John", "age": 30}'

llamaapi_client.chat.completions.create = unittest.mock.Mock(return_value=mock_api_response)

stream = model.structured_output(test_output_model_cls, messages)
events = await alist(stream)

assert events[-1] == {"output": test_output_model_cls(name="John", age=30)}
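The test above awaits an `alist` fixture that is not shown in this diff. A plausible sketch, assuming it simply drains an async generator into a list:

```python
import pytest


@pytest.fixture
def alist():
    # Assumed helper: collect every event yielded by an async generator.
    async def collect(agen):
        return [event async for event in agen]

    return collect
```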
25 changes: 25 additions & 0 deletions tests_integ/models/test_model_llamaapi.py
@@ -1,6 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
import os

import pydantic
import pytest

import strands
@@ -40,8 +41,32 @@ def agent(model, tools):
return Agent(model=model, tools=tools)


@pytest.fixture
def weather():
class Weather(pydantic.BaseModel):
"""Extracts the time and weather from the user's message with the exact strings."""

time: str
weather: str

return Weather(time="12:00", weather="sunny")


def test_agent(agent):
result = agent("What is the time and weather in New York?")
text = result.message["content"][0]["text"].lower()

assert all(string in text for string in ["12:00", "sunny"])


def test_agent_structured_output(agent, weather):
tru_weather = agent.structured_output(type(weather), "The time is 12:00 and the weather is sunny")
exp_weather = weather
assert tru_weather == exp_weather


@pytest.mark.asyncio
async def test_agent_structured_output_async(agent, weather):
tru_weather = await agent.structured_output(type(weather), "The time is 12:00 and the weather is sunny")
exp_weather = weather
assert tru_weather == exp_weather
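Note that running these integration tests requires Llama API credentials. A plausible module-level skip guard consistent with the `os` import above (the exact environment variable name is an assumption, not confirmed by the diff):

```python
import os

import pytest

# Assumed gating; the real test module may use a different variable name.
if "LLAMA_API_KEY" not in os.environ:
    pytest.skip(reason="LLAMA_API_KEY environment variable missing", allow_module_level=True)
```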