Skip to content

Add typed events #716

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 7 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ writer = [
sagemaker = [
"boto3>=1.26.0,<2.0.0",
"botocore>=1.29.0,<2.0.0",
"boto3-stubs[sagemaker-runtime]>=1.26.0,<2.0.0"
"boto3-stubs[sagemaker-runtime]>=1.26.0,<2.0.0",
]

a2a = [
Expand Down
98 changes: 51 additions & 47 deletions src/strands/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from ..tools.watcher import ToolWatcher
from ..types.content import ContentBlock, Message, Messages
from ..types.exceptions import ContextWindowOverflowException
from ..types.signals import AgentSignal, InitSignalLoopSignal, ModelStreamSignal, StopSignal, ToolResultSignal
from ..types.tools import ToolResult, ToolUse
from ..types.traces import AttributeValue
from .agent_result import AgentResult
Expand Down Expand Up @@ -136,19 +137,19 @@ def caller(
"input": kwargs.copy(),
}

async def acall() -> ToolResult:
async def acall() -> ToolResultSignal:
# Pass kwargs as invocation_state
async for event in run_tool(self._agent, tool_use, kwargs):
_ = event
async for signal in run_tool(self._agent, tool_use, kwargs):
_ = signal

return cast(ToolResult, event)
return cast(ToolResultSignal, signal)

def tcall() -> ToolResult:
def tcall() -> ToolResultSignal:
return asyncio.run(acall())

with ThreadPoolExecutor() as executor:
future = executor.submit(tcall)
tool_result = future.result()
last_signal = future.result()

if record_direct_tool_call is not None:
should_record_direct_tool_call = record_direct_tool_call
Expand All @@ -157,12 +158,12 @@ def tcall() -> ToolResult:

if should_record_direct_tool_call:
# Create a record of this tool execution in the message history
self._agent._record_tool_execution(tool_use, tool_result, user_message_override)
self._agent._record_tool_execution(tool_use, last_signal.tool_result, user_message_override)

# Apply window management
self._agent.conversation_manager.apply_management(self._agent)

return tool_result
return last_signal.tool_result

return caller

Expand Down Expand Up @@ -394,11 +395,11 @@ async def invoke_async(self, prompt: Union[str, list[ContentBlock]], **kwargs: A
- metrics: Performance metrics from the event loop
- state: The final state of the event loop
"""
events = self.stream_async(prompt, **kwargs)
async for event in events:
_ = event
signals = self.stream_async(prompt, **kwargs)
async for signal in signals:
_ = signal

return cast(AgentResult, event["result"])
return cast(AgentResult, signal["result"])

def structured_output(self, output_model: Type[T], prompt: Optional[Union[str, list[ContentBlock]]] = None) -> T:
"""This method allows you to get structured output from the agent.
Expand Down Expand Up @@ -476,19 +477,23 @@ async def structured_output_async(
"gen_ai.system.message",
attributes={"role": "system", "content": serialize([{"text": self.system_prompt}])},
)
events = self.model.structured_output(output_model, temp_messages, system_prompt=self.system_prompt)
async for event in events:
if "callback" in event:
self.callback_handler(**cast(dict, event["callback"]))
signals = self.model.structured_output(output_model, temp_messages, system_prompt=self.system_prompt)
async for signal in signals:
if "stop" not in signal:
typed_signal = ModelStreamSignal(delta_data=signal)
typed_signal._prepare_and_invoke(
agent=self, invocation_state={}, callback_handler=self.callback_handler
)

structured_output_span.add_event(
"gen_ai.choice", attributes={"message": serialize(event["output"].model_dump())}
"gen_ai.choice", attributes={"message": serialize(signal["output"].model_dump())}
)
return event["output"]
return signal["output"]

finally:
self.hooks.invoke_callbacks(AfterInvocationEvent(agent=self))

async def stream_async(self, prompt: Union[str, list[ContentBlock]], **kwargs: Any) -> AsyncIterator[Any]:
async def stream_async(self, prompt: Union[str, list[ContentBlock]], **kwargs: Any) -> AsyncIterator[AgentSignal]:
"""Process a natural language prompt and yield events as an async iterator.

This method provides an asynchronous interface for streaming agent events, allowing
Expand Down Expand Up @@ -527,25 +532,24 @@ async def stream_async(self, prompt: Union[str, list[ContentBlock]], **kwargs: A
self.trace_span = self._start_agent_trace_span(message)
with trace_api.use_span(self.trace_span):
try:
events = self._run_loop(message, invocation_state=kwargs)
async for event in events:
if "callback" in event:
callback_handler(**event["callback"])
yield event["callback"]
invocation_state = kwargs
signals = self._run_loop(message, invocation_state=invocation_state)
async for signal in signals:
signal._prepare_and_invoke(
agent=self, invocation_state=invocation_state, callback_handler=callback_handler
)
yield signal

result = AgentResult(*event["stop"])
callback_handler(result=result)
yield {"result": result}
if not isinstance(signal, StopSignal):
raise ValueError(f"Last signal was not a StopSignal; was: {signal}")

self._end_agent_trace_span(response=result)
self._end_agent_trace_span(response=signal.result)

except Exception as e:
self._end_agent_trace_span(error=e)
raise

async def _run_loop(
self, message: Message, invocation_state: dict[str, Any]
) -> AsyncGenerator[dict[str, Any], None]:
async def _run_loop(self, message: Message, invocation_state: dict[str, Any]) -> AsyncGenerator[AgentSignal, None]:
"""Execute the agent's event loop with the given message and parameters.

Args:
Expand All @@ -558,33 +562,33 @@ async def _run_loop(
self.hooks.invoke_callbacks(BeforeInvocationEvent(agent=self))

try:
yield {"callback": {"init_event_loop": True, **invocation_state}}
yield InitSignalLoopSignal()

self._append_message(message)

# Execute the event loop cycle with retry logic for context limits
events = self._execute_event_loop_cycle(invocation_state)
async for event in events:
signals = self._execute_event_loop_cycle(invocation_state)
async for signal in signals:
# Signal from the model provider that the message sent by the user should be redacted,
# likely due to a guardrail.
if (
event.get("callback")
and event["callback"].get("event")
and event["callback"]["event"].get("redactContent")
and event["callback"]["event"]["redactContent"].get("redactUserContentMessage")
isinstance(signal, ModelStreamSignal)
and signal.delta_data.get("event")
and signal.delta_data["event"].get("redactContent")
and signal.delta_data["event"]["redactContent"].get("redactUserContentMessage")
):
self.messages[-1]["content"] = [
{"text": event["callback"]["event"]["redactContent"]["redactUserContentMessage"]}
{"text": signal.delta_data["event"]["redactContent"]["redactUserContentMessage"]}
]
if self._session_manager:
self._session_manager.redact_latest_message(self.messages[-1], self)
yield event
yield signal

finally:
self.conversation_manager.apply_management(self)
self.hooks.invoke_callbacks(AfterInvocationEvent(agent=self))

async def _execute_event_loop_cycle(self, invocation_state: dict[str, Any]) -> AsyncGenerator[dict[str, Any], None]:
async def _execute_event_loop_cycle(self, invocation_state: dict[str, Any]) -> AsyncGenerator[AgentSignal, None]:
"""Execute the event loop cycle with retry logic for context window limits.

This internal method handles the execution of the event loop cycle and implements
Expand All @@ -599,12 +603,12 @@ async def _execute_event_loop_cycle(self, invocation_state: dict[str, Any]) -> A

try:
# Execute the main event loop cycle
events = event_loop_cycle(
signals = event_loop_cycle(
agent=self,
invocation_state=invocation_state,
)
async for event in events:
yield event
async for signal in signals:
yield signal

except ContextWindowOverflowException as e:
# Try reducing the context size and retrying
Expand All @@ -614,9 +618,9 @@ async def _execute_event_loop_cycle(self, invocation_state: dict[str, Any]) -> A
if self._session_manager:
self._session_manager.sync_agent(self)

events = self._execute_event_loop_cycle(invocation_state)
async for event in events:
yield event
signals = self._execute_event_loop_cycle(invocation_state)
async for signal in signals:
yield signal

def _record_tool_execution(
self,
Expand Down
Loading