feat(agent): make structured output part of the agent loop #670
base: main
@@ -10,17 +10,20 @@
"""

import asyncio
import copy
import json
import logging
import random
from concurrent.futures import ThreadPoolExecutor
from contextlib import suppress
from typing import Any, AsyncGenerator, AsyncIterator, Callable, Mapping, Optional, Type, TypeVar, Union, cast

from opentelemetry import trace as trace_api
from pydantic import BaseModel

from .. import _identifier
from ..event_loop.event_loop import event_loop_cycle, run_tool
from ..experimental.hooks import AfterToolInvocationEvent
from ..handlers.callback_handler import PrintingCallbackHandler, null_callback_handler
from ..hooks import (
    AfterInvocationEvent,

@@ -34,7 +37,8 @@
from ..models.model import Model
from ..session.session_manager import SessionManager
from ..telemetry.metrics import EventLoopMetrics
from ..telemetry.tracer import get_tracer, serialize
from ..telemetry.tracer import get_tracer
from ..tools.decorator import tool
from ..tools.registry import ToolRegistry
from ..tools.watcher import ToolWatcher
from ..types.content import ContentBlock, Message, Messages
@@ -404,7 +408,12 @@ async def invoke_async(self, prompt: Union[str, list[ContentBlock]], **kwargs: A

        return cast(AgentResult, event["result"])

    def structured_output(self, output_model: Type[T], prompt: Optional[Union[str, list[ContentBlock]]] = None) -> T:
    def structured_output(
        self,
        output_model: Type[T],
        prompt: Optional[Union[str, list[ContentBlock]]] = None,
        preserve_conversation: bool = False,
    ) -> T:
        """This method allows you to get structured output from the agent.

        If you pass in a prompt, it will be used temporarily without adding it to the conversation history.

@@ -417,20 +426,33 @@ def structured_output(self, output_model: Type[T], prompt: Optional[Union[str, l
            output_model: The output model (a JSON schema written as a Pydantic BaseModel)
                that the agent will use when responding.
            prompt: The prompt to use for the agent (will not be added to conversation history).
            preserve_conversation: If False (default), restores original conversation state after execution.
                If True, allows structured output execution to modify conversation history.

        Raises:
            ValueError: If no conversation history or prompt is provided.
        """

        def execute() -> T:
            return asyncio.run(self.structured_output_async(output_model, prompt))
            return asyncio.run(self.structured_output_async(output_model, prompt, preserve_conversation))

        with ThreadPoolExecutor() as executor:
            future = executor.submit(execute)
            return future.result()
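For context on the proposed interface, a minimal usage sketch follows; the `ContactCard` model, the helper function names, and the prompt text are illustrative assumptions and do not appear in this PR:

```python
from pydantic import BaseModel


class ContactCard(BaseModel):
    """Illustrative output schema; any Pydantic model can serve as output_model."""

    name: str
    age: int
    email: str


def extract_contact(agent) -> ContactCard:
    # Default behavior: the agent's conversation history is restored after the call.
    return agent.structured_output(ContactCard, "Jane Doe is 30; email [email protected]")


def extract_contact_and_keep_history(agent) -> ContactCard:
    # Proposed opt-in: keep the structured-output exchange in the conversation history.
    return agent.structured_output(
        ContactCard,
        "Jane Doe is 30; email [email protected]",
        preserve_conversation=True,
    )
```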
    def _register_structured_output_tool(self, output_model: type[BaseModel]) -> Any:
Review comment: It's unclear to me why this tool is needed. The current structured output implementation within Strands does utilize a tool to generate structured output (which is subject to change); however, I am hesitant to retry the entire agent loop in the pattern suggested in this PR. Another way to implement this is returning the schema validation failures as tool failures to the model. Structured output is currently an open roadmap item that we are planning to redesign, and I am not comfortable merging this PR in its current state given that the underlying implementation is subject to change and due to the points raised above. Currently, the approach of using a tool to generate structured output is an anti-pattern, since most model providers natively support structured output responses; these native approaches improve structured output performance substantially. Given these natively supported features, retries can simply be implemented on the actual model API request; a full retry of the agent loop is not required.

Reply: JSON output and tool use at the model level are likely an identical implementation that forces the tokens into the JSON schema space. The performance of the models greatly improves in the agent-loop context, as the model will have the context to evaluate its own work; please see the Anthropic docs. This is an additional feature on top of native support. Think of it this way: tool/JSON output at the model level forces the model to write in cursive, but doesn't actually check the contents of the writing. With a Pydantic model you can have additional logical validations checking the content. Please explain why it is an anti-pattern to use a tool for structured output. It reuses the execution guardrails and testing of tool execution, where failures are fed back to the model. You would be reimplementing the same logic at the model structured-output function level if you were to try to feed back any validation errors. Under this setup the agent has the chance of performing all the actions it would under the run command, and then responding with a structured output, which is very useful for interacting with the output programmatically.

Reply: How is what you describe here different from rerunning the agent loop? If you feed back the schema failure as a tool result and the agent then responds, that is effectively an agent loop with extra steps.
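To make the alternative raised above concrete, here is a rough sketch of reporting schema validation failures back to the model as tool failures rather than retrying the loop. The function names are hypothetical, and the result-dict shape simply mirrors the `{"status": ..., "content": [...]}` structure that appears elsewhere in this diff; it is not the library's actual tool API:

```python
from typing import Any

from pydantic import BaseModel, ValidationError


def make_validating_output_tool(output_model: type[BaseModel]):
    """Return a callable that validates tool input against output_model."""

    def _structured_output(tool_input: dict[str, Any]) -> dict[str, Any]:
        try:
            instance = output_model(**tool_input)
        except ValidationError as exc:
            # Hand the schema failure back to the model as a tool error so it can
            # correct the next tool call, instead of restarting the whole agent loop.
            return {"status": "error", "content": [{"text": str(exc)}]}
        return {"status": "success", "content": [{"text": instance.model_dump_json()}]}

    return _structured_output
```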
        @tool
        def _structured_output(input: output_model) -> output_model:  # type: ignore[valid-type]
            """If this tool is present it MUST be used to return structured data for the user."""
            return input

        return _structured_output

    async def structured_output_async(
        self, output_model: Type[T], prompt: Optional[Union[str, list[ContentBlock]]] = None
        self,
        output_model: Type[T],
        prompt: Optional[Union[str, list[ContentBlock]]] = None,
        preserve_conversation: bool = False,
    ) -> T:
        """This method allows you to get structured output from the agent.

@@ -444,53 +466,141 @@ async def structured_output_async(
            output_model: The output model (a JSON schema written as a Pydantic BaseModel)
                that the agent will use when responding.
            prompt: The prompt to use for the agent (will not be added to conversation history).
            preserve_conversation: If False (default), restores original conversation state after execution.
                If True, allows structured output execution to modify conversation history.

        Raises:
            ValueError: If no conversation history or prompt is provided.
        """
        self.hooks.invoke_callbacks(BeforeInvocationEvent(agent=self))
        with self.tracer.tracer.start_as_current_span(
            "execute_structured_output", kind=trace_api.SpanKind.CLIENT
        ) as structured_output_span:
            try:

        # Store references to what we'll add temporarily
        added_tool_name = None
        added_callback = None

        # Save original messages if we need to restore them later
        original_messages = copy.deepcopy(self.messages) if not preserve_conversation else None

        # Create and add the structured output tool
        structured_output_tool = self._register_structured_output_tool(output_model)
        self.tool_registry.register_tool(structured_output_tool)
        added_tool_name = structured_output_tool.tool_name

        # Variable to capture the structured result
        captured_result = None

        # Hook to capture structured output tool invocation
        def capture_structured_output_hook(event: AfterToolInvocationEvent) -> None:
            nonlocal captured_result

            if (
                event.selected_tool
                and hasattr(event.selected_tool, "tool_name")
                and event.selected_tool.tool_name == "_structured_output"
                and event.result
                and event.result.get("status") == "success"
            ):
                # Parse the validated Pydantic model from the tool result
                with suppress(Exception):
                    content = event.result.get("content", [])
                    if content and isinstance(content[0], dict) and "text" in content[0]:
                        # The tool returns the model instance as string, but we need the actual instance
                        # Since our tool returns the input directly, we can reconstruct it
                        tool_input = event.tool_use.get("input", {}).get("input")
                        if tool_input:
                            captured_result = output_model(**tool_input)

        self.hooks.add_callback(AfterToolInvocationEvent, capture_structured_output_hook)
        added_callback = capture_structured_output_hook

        # Create message for tracing
        message: Message
        if prompt:
            content: list[ContentBlock] = [{"text": prompt}] if isinstance(prompt, str) else prompt
            message = {"role": "user", "content": content}
        else:
            # Use existing conversation history
            message = {
                "role": "user",
                "content": [
                    {"text": "Please provide the information from our conversation in the requested structured format."}
                ],
            }

        # Start agent trace span (same as stream_async)
        self.trace_span = self._start_agent_trace_span(message)

        try:
            with trace_api.use_span(self.trace_span):
                if not self.messages and not prompt:
                    raise ValueError("No conversation history or prompt provided")
                # Create temporary messages array if prompt is provided
                if prompt:
                    content: list[ContentBlock] = [{"text": prompt}] if isinstance(prompt, str) else prompt
                    temp_messages = self.messages + [{"role": "user", "content": content}]
                else:
                    temp_messages = self.messages

                structured_output_span.set_attributes(
                    {
                        "gen_ai.system": "strands-agents",
                        "gen_ai.agent.name": self.name,
                        "gen_ai.agent.id": self.agent_id,
                        "gen_ai.operation.name": "execute_structured_output",
                    }
                )
                for message in temp_messages:
                    structured_output_span.add_event(
                        f"gen_ai.{message['role']}.message",
                        attributes={"role": message["role"], "content": serialize(message["content"])},
                    )
                if self.system_prompt:
                    structured_output_span.add_event(
                        "gen_ai.system.message",
                        attributes={"role": "system", "content": serialize([{"text": self.system_prompt}])},

                invocation_state = {
                    "structured_output_mode": True,
                    "structured_output_model": output_model,
                }

                # Run the event loop
                async for event in self._run_loop(message=message, invocation_state=invocation_state):
                    if "stop" in event:
                        break

                # Return the captured structured result if we got it from the tool
                if captured_result:
                    self._end_agent_trace_span(
                        response=AgentResult(
                            message={"role": "assistant", "content": [{"text": str(captured_result)}]},
                            stop_reason="end_turn",
                            metrics=self.event_loop_metrics,
                            state={},
                        )
                    )
                    return captured_result

                # Fallback: Use the original model.structured_output approach
                # This maintains backward compatibility with existing tests and implementations
                # Use original_messages to get clean message state, or self.messages if preserve_conversation=True
                base_messages = original_messages if original_messages is not None else self.messages
                temp_messages = base_messages if not prompt else base_messages + [message]

                events = self.model.structured_output(output_model, temp_messages, system_prompt=self.system_prompt)
                async for event in events:
                    if "callback" in event:
                        self.callback_handler(**cast(dict, event["callback"]))
                structured_output_span.add_event(
                    "gen_ai.choice", attributes={"message": serialize(event["output"].model_dump())}

Review comment: I am concerned that the tracing behavior has changed and hasn't been manually or unit tested.

Reply: OK, we can assert it further.
                self._end_agent_trace_span(
                    response=AgentResult(
                        message={"role": "assistant", "content": [{"text": str(event["output"])}]},
                        stop_reason="end_turn",
                        metrics=self.event_loop_metrics,
                        state={},
                    )
                )
                return event["output"]
                return cast(T, event["output"])

            finally:
                self.hooks.invoke_callbacks(AfterInvocationEvent(agent=self))
        except Exception as e:
            self._end_agent_trace_span(error=e)
            raise
        finally:
            # Clean up what we added - remove the callback
            if added_callback is not None:
                with suppress(ValueError, KeyError):
                    callbacks = self.hooks._registered_callbacks.get(AfterToolInvocationEvent, [])
                    if added_callback in callbacks:
                        callbacks.remove(added_callback)

            # Remove the tool we added
            if added_tool_name:
                if added_tool_name in self.tool_registry.registry:
                    del self.tool_registry.registry[added_tool_name]
                if added_tool_name in self.tool_registry.dynamic_tools:
                    del self.tool_registry.dynamic_tools[added_tool_name]

            # Restore original messages if preserve_conversation is False
            if original_messages is not None:
                self.messages = original_messages

            self.hooks.invoke_callbacks(AfterInvocationEvent(agent=self))

    async def stream_async(self, prompt: Union[str, list[ContentBlock]], **kwargs: Any) -> AsyncIterator[Any]:
        """Process a natural language prompt and yield events as an async iterator.
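As an aside on the review discussion above, here is a sketch of what retrying at the model request level (rather than rerunning the agent loop) might look like. It assumes a model object whose `structured_output` method is the async event generator used in the fallback path of this diff; the retry wrapper itself and its parameters are hypothetical:

```python
from typing import Type, TypeVar

from pydantic import BaseModel, ValidationError

T = TypeVar("T", bound=BaseModel)


async def structured_output_with_retries(
    model, output_model: Type[T], messages, system_prompt=None, max_attempts=3
) -> T:
    """Retry only the structured output request against the model API."""
    last_error: Exception | None = None
    for _ in range(max_attempts):
        try:
            events = model.structured_output(output_model, messages, system_prompt=system_prompt)
            async for event in events:
                if "output" in event:
                    return event["output"]
        except ValidationError as exc:
            last_error = exc  # malformed response; issue the request again
    raise RuntimeError(f"Structured output failed after {max_attempts} attempts") from last_error
```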
@@ -992,13 +992,12 @@ def test_agent_callback_handler_custom_handler_used():


def test_agent_structured_output(agent, system_prompt, user, agenerator):
    # Setup mock tracer and span
    mock_strands_tracer = unittest.mock.MagicMock()
    mock_otel_tracer = unittest.mock.MagicMock()
    # Mock the agent tracing methods instead of direct OpenTelemetry calls
    agent._start_agent_trace_span = unittest.mock.Mock()
    agent._end_agent_trace_span = unittest.mock.Mock()
    mock_span = unittest.mock.MagicMock()
    mock_strands_tracer.tracer = mock_otel_tracer
    mock_otel_tracer.start_as_current_span.return_value.__enter__.return_value = mock_span
    agent.tracer = mock_strands_tracer
    agent._start_agent_trace_span.return_value = mock_span
    agent.trace_span = mock_span

    agent.model.structured_output = unittest.mock.Mock(return_value=agenerator([{"output": user}]))

@@ -1019,34 +1018,19 @@ def test_agent_structured_output(agent, system_prompt, user, agenerator):
        type(user), [{"role": "user", "content": [{"text": prompt}]}], system_prompt=system_prompt
    )

    mock_span.set_attributes.assert_called_once_with(
        {
            "gen_ai.system": "strands-agents",
            "gen_ai.agent.name": "Strands Agents",
            "gen_ai.agent.id": "default",
            "gen_ai.operation.name": "execute_structured_output",
        }
    )

    mock_span.add_event.assert_any_call(
        "gen_ai.user.message",
        attributes={"role": "user", "content": '[{"text": "Jane Doe is 30 years old and her email is [email protected]"}]'},
    )

    mock_span.add_event.assert_called_with(
        "gen_ai.choice",
        attributes={"message": json.dumps(user.model_dump())},
    )
    # Verify agent-level tracing was called
Review comment: The assertions have been weakened here; it is preferred to keep a high bar on assertions.

Reply: What assertion would you like to call here?
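One possible answer to the question above, sketched as an assumption about what a tighter check could look like: assert on the message actually handed to the tracing helpers rather than only on the call counts.

```python
# Hypothetical tighter assertions for this test: verify what was traced,
# not merely that the tracing helpers were invoked.
agent._start_agent_trace_span.assert_called_once_with(
    {"role": "user", "content": [{"text": prompt}]}
)
assert "response" in agent._end_agent_trace_span.call_args.kwargs
```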
    agent._start_agent_trace_span.assert_called_once()
    agent._end_agent_trace_span.assert_called_once()


def test_agent_structured_output_multi_modal_input(agent, system_prompt, user, agenerator):
    # Setup mock tracer and span
    mock_strands_tracer = unittest.mock.MagicMock()
    mock_otel_tracer = unittest.mock.MagicMock()
    # Mock the agent tracing methods instead of direct OpenTelemetry calls
    agent._start_agent_trace_span = unittest.mock.Mock()
    agent._end_agent_trace_span = unittest.mock.Mock()
    mock_span = unittest.mock.MagicMock()
    mock_strands_tracer.tracer = mock_otel_tracer
    mock_otel_tracer.start_as_current_span.return_value.__enter__.return_value = mock_span
    agent.tracer = mock_strands_tracer
    agent._start_agent_trace_span.return_value = mock_span
    agent.trace_span = mock_span

    agent.model.structured_output = unittest.mock.Mock(return_value=agenerator([{"output": user}]))

    prompt = [

@@ -1076,10 +1060,9 @@ def test_agent_structured_output_multi_modal_input(agent, system_prompt, user, a
        type(user), [{"role": "user", "content": prompt}], system_prompt=system_prompt
    )

    mock_span.add_event.assert_called_with(
        "gen_ai.choice",
        attributes={"message": json.dumps(user.model_dump())},
    )
    # Verify agent-level tracing was called
    agent._start_agent_trace_span.assert_called_once()
    agent._end_agent_trace_span.assert_called_once()


@pytest.mark.asyncio

Review comment: The existence (or lack thereof) of the prompt parameter determines this behavior, so I'm not sure we want to modify this interface and add another parameter here.
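To illustrate the point about the existing interface, a small sketch using the hypothetical `ContactCard` model from the earlier example; the prompt text is illustrative:

```python
# With a prompt: the text is used for this call only and is not added to history.
card = agent.structured_output(ContactCard, "Extract Jane's contact details")

# Without a prompt: the existing conversation history is used; per the docstring,
# a ValueError is raised if there is no history to draw from.
card = agent.structured_output(ContactCard)
```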