-
Notifications
You must be signed in to change notification settings - Fork 425
feat: Add Structured Output as part of the agent loop #943
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 6 commits
dd86a56
dceb617
75a0ce7
b543cec
099e70a
882fcd6
7f2d73e
4979771
b9f9456
36bd507
dba8828
fb274ac
dcc6ac4
45dd56b
7ad09b2
9003fdd
7b192a1
89ea3c6
247e9c4
eeb97be
8f5ffad
a42171c
cfb39f5
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,8 +1,22 @@ | ||
"""A framework for building, deploying, and managing AI agents.""" | ||
|
||
from . import agent, models, telemetry, types | ||
from . import agent, models, output, telemetry, types | ||
from .agent.agent import Agent | ||
from .output import NativeMode, OutputSchema, PromptMode, ToolMode | ||
from .tools.decorator import tool | ||
from .types.tools import ToolContext | ||
|
||
__all__ = ["Agent", "agent", "models", "tool", "types", "telemetry", "ToolContext"] | ||
__all__ = [ | ||
"Agent", | ||
"agent", | ||
"models", | ||
"output", | ||
"NativeMode", | ||
"OutputSchema", | ||
"PromptMode", | ||
"tool", | ||
"ToolContext", | ||
"ToolMode", | ||
"types", | ||
"telemetry", | ||
] |
Original file line number | Diff line number | Diff line change | ||
---|---|---|---|---|
|
@@ -26,7 +26,7 @@ | |||
Union, | ||||
cast, | ||||
) | ||||
|
||||
from typing_extensions import deprecated | ||||
from opentelemetry import trace as trace_api | ||||
from pydantic import BaseModel | ||||
|
||||
|
@@ -43,6 +43,9 @@ | |||
) | ||||
from ..models.bedrock import BedrockModel | ||||
from ..models.model import Model | ||||
from ..output.base import OutputSchema | ||||
from ..output.modes import ToolMode | ||||
from ..output.utils import resolve_output_schema | ||||
from ..session.session_manager import SessionManager | ||||
from ..telemetry.metrics import EventLoopMetrics | ||||
from ..telemetry.tracer import get_tracer, serialize | ||||
|
@@ -210,6 +213,7 @@ def __init__( | |||
messages: Optional[Messages] = None, | ||||
tools: Optional[list[Union[str, dict[str, str], Any]]] = None, | ||||
system_prompt: Optional[str] = None, | ||||
structured_output_type: Optional[Type[BaseModel]] = None, | ||||
zastrowm marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||||
callback_handler: Optional[ | ||||
Union[Callable[..., Any], _DefaultCallbackHandlerSentinel] | ||||
] = _DEFAULT_CALLBACK_HANDLER, | ||||
|
@@ -245,6 +249,10 @@ def __init__( | |||
If provided, only these tools will be available. If None, all tools will be available. | ||||
system_prompt: System prompt to guide model behavior. | ||||
If None, the model will behave according to its default settings. | ||||
structured_output_type: Pydantic model type(s) for structured output. | ||||
When specified, all agent calls will attempt to return structured output of this type. | ||||
This can be overridden on the agent invocation. | ||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nit: I think this goes against convention for the agent class. This should either be a class attribute, or an kwargument on the invoke method. Im inclined to lean toward just the invoke kwargument. |
||||
Defaults to None (no structured output). | ||||
callback_handler: Callback for processing events as they happen during agent execution. | ||||
If not provided (using the default), a new PrintingCallbackHandler instance is created. | ||||
If explicitly set to None, null_callback_handler is used. | ||||
|
@@ -274,8 +282,8 @@ def __init__( | |||
""" | ||||
self.model = BedrockModel() if not model else BedrockModel(model_id=model) if isinstance(model, str) else model | ||||
self.messages = messages if messages is not None else [] | ||||
|
||||
self.system_prompt = system_prompt | ||||
self.default_output_schema = resolve_output_schema(structured_output_type) | ||||
zastrowm marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||||
self.agent_id = _identifier.validate(agent_id or _DEFAULT_AGENT_ID, _identifier.Identifier.AGENT) | ||||
self.name = name or _DEFAULT_AGENT_NAME | ||||
self.description = description | ||||
|
@@ -374,7 +382,9 @@ def tool_names(self) -> list[str]: | |||
all_tools = self.tool_registry.get_all_tools_config() | ||||
return list(all_tools.keys()) | ||||
|
||||
def __call__(self, prompt: AgentInput = None, **kwargs: Any) -> AgentResult: | ||||
def __call__( | ||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nit: pretty sure this has updated, so you will need to rebase |
||||
self, prompt: AgentInput = None, structured_output_type: Optional[Type[BaseModel]] = None, **kwargs: Any | ||||
afarntrog marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||||
) -> AgentResult: | ||||
"""Process a natural language prompt through the agent's event loop. | ||||
|
||||
This method implements the conversational interface with multiple input patterns: | ||||
|
@@ -389,6 +399,7 @@ def __call__(self, prompt: AgentInput = None, **kwargs: Any) -> AgentResult: | |||
- list[ContentBlock]: Multi-modal content blocks | ||||
- list[Message]: Complete messages with roles | ||||
- None: Use existing conversation history | ||||
structured_output_type: Pydantic model type(s) for structured output (overrides agent default). | ||||
**kwargs: Additional parameters to pass through the event loop. | ||||
|
||||
Returns: | ||||
|
@@ -398,16 +409,19 @@ def __call__(self, prompt: AgentInput = None, **kwargs: Any) -> AgentResult: | |||
- message: The final message from the model | ||||
- metrics: Performance metrics from the event loop | ||||
- state: The final state of the event loop | ||||
- structured_output: Parsed structured output when structured_output_type was specified | ||||
""" | ||||
|
||||
def execute() -> AgentResult: | ||||
return asyncio.run(self.invoke_async(prompt, **kwargs)) | ||||
return asyncio.run(self.invoke_async(prompt, structured_output_type, **kwargs)) | ||||
|
||||
with ThreadPoolExecutor() as executor: | ||||
future = executor.submit(execute) | ||||
return future.result() | ||||
|
||||
async def invoke_async(self, prompt: AgentInput = None, **kwargs: Any) -> AgentResult: | ||||
async def invoke_async( | ||||
self, prompt: AgentInput = None, structured_output_type: Optional[Type[BaseModel]] = None, **kwargs: Any | ||||
afarntrog marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||||
) -> AgentResult: | ||||
"""Process a natural language prompt through the agent's event loop. | ||||
|
||||
This method implements the conversational interface with multiple input patterns: | ||||
|
@@ -422,6 +436,7 @@ async def invoke_async(self, prompt: AgentInput = None, **kwargs: Any) -> AgentR | |||
- list[ContentBlock]: Multi-modal content blocks | ||||
- list[Message]: Complete messages with roles | ||||
- None: Use existing conversation history | ||||
structured_output_type: Pydantic model type(s) for structured output (overrides agent default). | ||||
**kwargs: Additional parameters to pass through the event loop. | ||||
|
||||
Returns: | ||||
|
@@ -432,12 +447,17 @@ async def invoke_async(self, prompt: AgentInput = None, **kwargs: Any) -> AgentR | |||
- metrics: Performance metrics from the event loop | ||||
- state: The final state of the event loop | ||||
""" | ||||
events = self.stream_async(prompt, **kwargs) | ||||
events = self.stream_async(prompt, structured_output_type=structured_output_type, **kwargs) | ||||
async for event in events: | ||||
_ = event | ||||
|
||||
return cast(AgentResult, event["result"]) | ||||
|
||||
@deprecated( | ||||
"Agent.structured_output method is deprecated." | ||||
" You should pass in `structured_output_type` directly into the agent invocation." | ||||
" see the <LINK> for more details" | ||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. TODO - update LINK There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The sdk-python/src/strands/tools/loader.py Line 160 in 7fbc9dc
|
||||
) | ||||
def structured_output(self, output_model: Type[T], prompt: AgentInput = None) -> T: | ||||
"""This method allows you to get structured output from the agent. | ||||
|
||||
|
@@ -467,6 +487,11 @@ def execute() -> T: | |||
future = executor.submit(execute) | ||||
return future.result() | ||||
|
||||
@deprecated( | ||||
"Agent.structured_output_async method is deprecated." | ||||
" You should pass in `structured_output_type` directly into the agent invocation." | ||||
" see the <LINK> for more details" | ||||
) | ||||
async def structured_output_async(self, output_model: Type[T], prompt: AgentInput = None) -> T: | ||||
"""This method allows you to get structured output from the agent. | ||||
|
||||
|
@@ -530,6 +555,7 @@ async def structured_output_async(self, output_model: Type[T], prompt: AgentInpu | |||
async def stream_async( | ||||
self, | ||||
prompt: AgentInput = None, | ||||
structured_output_type: Optional[Type[BaseModel]] = None, | ||||
**kwargs: Any, | ||||
) -> AsyncIterator[Any]: | ||||
"""Process a natural language prompt and yield events as an async iterator. | ||||
|
@@ -546,6 +572,7 @@ async def stream_async( | |||
- list[ContentBlock]: Multi-modal content blocks | ||||
- list[Message]: Complete messages with roles | ||||
- None: Use existing conversation history | ||||
structured_output_type: Pydantic model type(s) for structured output (overrides agent default). | ||||
**kwargs: Additional parameters to pass to the event loop. | ||||
|
||||
Yields: | ||||
|
@@ -569,14 +596,19 @@ async def stream_async( | |||
""" | ||||
callback_handler = kwargs.get("callback_handler", self.callback_handler) | ||||
|
||||
# runtime override or agent default TODO in the future, when we expose 'output_schema, we should consider allowing for halfway configuration. for example, the user should be able to define `output_mode` on the Agent level but `structured_output_type` on the `output_mode` | ||||
zastrowm marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||||
output_schema: Optional[OutputSchema] = ( | ||||
resolve_output_schema(structured_output_type) or self.default_output_schema | ||||
) | ||||
|
||||
# Process input and get message to add (if any) | ||||
messages = self._convert_prompt_to_messages(prompt) | ||||
|
||||
self.trace_span = self._start_agent_trace_span(messages) | ||||
|
||||
with trace_api.use_span(self.trace_span): | ||||
try: | ||||
events = self._run_loop(messages, invocation_state=kwargs) | ||||
events = self._run_loop(messages, kwargs, output_schema=output_schema) | ||||
|
||||
async for event in events: | ||||
event.prepare(invocation_state=kwargs) | ||||
|
@@ -596,7 +628,9 @@ async def stream_async( | |||
self._end_agent_trace_span(error=e) | ||||
raise | ||||
|
||||
async def _run_loop(self, messages: Messages, invocation_state: dict[str, Any]) -> AsyncGenerator[TypedEvent, None]: | ||||
async def _run_loop( | ||||
self, messages: Messages, invocation_state: dict[str, Any], output_schema: Optional[OutputSchema] = None | ||||
) -> AsyncGenerator[TypedEvent, None]: | ||||
"""Execute the agent's event loop with the given message and parameters. | ||||
|
||||
Args: | ||||
|
@@ -614,6 +648,8 @@ async def _run_loop(self, messages: Messages, invocation_state: dict[str, Any]) | |||
for message in messages: | ||||
self._append_message(message) | ||||
|
||||
invocation_state["output_schema"] = output_schema | ||||
zastrowm marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||||
|
||||
# Execute the event loop cycle with retry logic for context limits | ||||
events = self._execute_event_loop_cycle(invocation_state) | ||||
async for event in events: | ||||
|
@@ -648,6 +684,11 @@ async def _execute_event_loop_cycle(self, invocation_state: dict[str, Any]) -> A | |||
""" | ||||
# Add `Agent` to invocation_state to keep backwards-compatibility | ||||
invocation_state["agent"] = self | ||||
output_schema: OutputSchema = invocation_state.get("output_schema") | ||||
|
||||
if output_schema and isinstance(output_schema.mode, ToolMode): | ||||
for tool_instance in output_schema.mode.get_tool_instances(output_schema.type): | ||||
self.tool_registry.register_dynamic_tool(tool_instance) | ||||
|
||||
try: | ||||
# Execute the main event loop cycle | ||||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,7 +4,9 @@ | |
""" | ||
|
||
from dataclasses import dataclass | ||
from typing import Any | ||
from typing import Any, Optional | ||
|
||
from pydantic import BaseModel | ||
|
||
from ..telemetry.metrics import EventLoopMetrics | ||
from ..types.content import Message | ||
|
@@ -20,12 +22,14 @@ class AgentResult: | |
message: The last message generated by the agent. | ||
metrics: Performance metrics collected during processing. | ||
state: Additional state information from the event loop. | ||
structured_output: Parsed structured output when structured_output_type was specified. | ||
""" | ||
|
||
stop_reason: StopReason | ||
message: Message | ||
metrics: EventLoopMetrics | ||
state: Any | ||
structured_output: Optional[BaseModel] = None | ||
|
||
|
||
def __str__(self) -> str: | ||
"""Get the agent's last message as a string. | ||
|
Uh oh!
There was an error while loading. Please reload this page.