From 5b37108270088278f8e626c32708df63ddeb47db Mon Sep 17 00:00:00 2001 From: Mackenzie Zastrow Date: Thu, 14 Aug 2025 16:22:33 -0400 Subject: [PATCH 1/7] First cut --- src/strands/agent/agent.py | 3 +- src/strands/event_loop/event_loop.py | 10 +++---- src/strands/types/event_loop.py | 45 ++++++++++++++++++++++++++-- 3 files changed, 50 insertions(+), 8 deletions(-) diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index 43b5cbf8c..dc30a4e57 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -37,6 +37,7 @@ from ..tools.registry import ToolRegistry from ..tools.watcher import ToolWatcher from ..types.content import ContentBlock, Message, Messages +from ..types.event_loop import InitEventLoopEvent from ..types.exceptions import ContextWindowOverflowException from ..types.tools import ToolResult, ToolUse from ..types.traces import AttributeValue @@ -558,7 +559,7 @@ async def _run_loop( self.hooks.invoke_callbacks(BeforeInvocationEvent(agent=self)) try: - yield {"callback": {"init_event_loop": True, **invocation_state}} + yield InitEventLoopEvent() self._append_message(message) diff --git a/src/strands/event_loop/event_loop.py b/src/strands/event_loop/event_loop.py index b36f73155..b5dcce8a5 100644 --- a/src/strands/event_loop/event_loop.py +++ b/src/strands/event_loop/event_loop.py @@ -28,6 +28,7 @@ from ..telemetry.tracer import get_tracer from ..tools.executor import run_tools, validate_and_prepare_tools from ..types.content import Message +from ..types.event_loop import StartEventLoop, StartEventLoopEvent, MessageEvent, ForceStopEvent, EventLoopThrottleDelay from ..types.exceptions import ( ContextWindowOverflowException, EventLoopException, @@ -93,8 +94,7 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) -> cycle_start_time, cycle_trace = agent.event_loop_metrics.start_cycle(attributes=attributes) invocation_state["event_loop_cycle_trace"] = cycle_trace - yield {"callback": {"start": True}} - yield {"callback": {"start_event_loop": True}} + yield StartEventLoopEvent() # Create tracer span for this event loop cycle tracer = get_tracer() @@ -177,7 +177,7 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) -> if isinstance(e, ModelThrottledException): if attempt + 1 == MAX_ATTEMPTS: - yield {"callback": {"force_stop": True, "force_stop_reason": str(e)}} + yield ForceStopEvent(e) raise e logger.debug( @@ -191,7 +191,7 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) -> time.sleep(current_delay) current_delay = min(current_delay * 2, MAX_DELAY) - yield {"callback": {"event_loop_throttled_delay": current_delay, **invocation_state}} + yield EventLoopThrottleDelay(delay=current_delay) else: raise e @@ -203,7 +203,7 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) -> # Add the response message to the conversation agent.messages.append(message) agent.hooks.invoke_callbacks(MessageAddedEvent(agent=agent, message=message)) - yield {"callback": {"message": message}} + yield MessageEvent(message=message) # Update metrics agent.event_loop_metrics.update_usage(usage) diff --git a/src/strands/types/event_loop.py b/src/strands/types/event_loop.py index 7be33b6fd..8cb302cfe 100644 --- a/src/strands/types/event_loop.py +++ b/src/strands/types/event_loop.py @@ -1,9 +1,11 @@ """Event loop-related type definitions for the SDK.""" - -from typing import Literal +from dataclasses import dataclass +from typing import Literal, Any from typing_extensions import TypedDict +from .content import Message + class Usage(TypedDict): """Token usage information for model interactions. @@ -46,3 +48,42 @@ class Metrics(TypedDict): - "stop_sequence": Stop sequence encountered - "tool_use": Model requested to use a tool """ + +class TypedEvent(dict): + pass + + +@dataclass +class StartEventLoopEvent(TypedEvent): + + def as_callback(self) -> list[dict[str, Any]]: + return [{"start": True}, + {"start_event_loop": True}] + +@dataclass +class InitEventLoopEvent(TypedEvent): + + def as_callback(self) -> dict[str, Any]: + return {"init_event_loop": True} + +@dataclass +class ForceStopEvent(TypedEvent): + + reason: str | Exception + + def as_callback(self) -> dict[str, Any]: + return {"force_stop": True, "force_stop_reason": str(self.reason)} + +@dataclass +class MessageEvent(TypedEvent): + message: Message + + def as_callback(self) -> dict[str, Any]: + return {"message": self.message} + +@dataclass +class EventLoopThrottleDelay(TypedEvent): + delay: int + + def as_callback(self) -> dict[str, Any]: + return {"event_loop_throttled_delay": self.delay} \ No newline at end of file From e23a34c79f4cd74a8c8b33c00b08ca496c28c040 Mon Sep 17 00:00:00 2001 From: Mackenzie Zastrow Date: Mon, 18 Aug 2025 14:03:55 -0400 Subject: [PATCH 2/7] 2nd cut --- pyproject.toml | 2 +- src/strands/agent/agent.py | 39 +++++---- src/strands/event_loop/event_loop.py | 57 ++++++++---- src/strands/event_loop/streaming.py | 10 +-- src/strands/types/event_loop.py | 124 +++++++++++++++++++++++++-- tests/strands/agent/test_agent.py | 74 +++++++++++++--- 6 files changed, 249 insertions(+), 57 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index d4a4b79dc..7424a763a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -92,7 +92,7 @@ writer = [ sagemaker = [ "boto3>=1.26.0,<2.0.0", "botocore>=1.29.0,<2.0.0", - "boto3-stubs[sagemaker-runtime]>=1.26.0,<2.0.0" + "boto3-stubs[sagemaker-runtime]>=1.26.0,<2.0.0", ] a2a = [ diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index dc30a4e57..ce96b7539 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -37,7 +37,7 @@ from ..tools.registry import ToolRegistry from ..tools.watcher import ToolWatcher from ..types.content import ContentBlock, Message, Messages -from ..types.event_loop import InitEventLoopEvent +from ..types.event_loop import InitEventLoopEvent, StopEvent, StreamDeltaEvent, ResultEvent, StreamChunkEvent from ..types.exceptions import ContextWindowOverflowException from ..types.tools import ToolResult, ToolUse from ..types.traces import AttributeValue @@ -479,8 +479,10 @@ async def structured_output_async( ) events = self.model.structured_output(output_model, temp_messages, system_prompt=self.system_prompt) async for event in events: - if "callback" in event: - self.callback_handler(**cast(dict, event["callback"])) + if "stop" not in event: + typed_event = StreamDeltaEvent(delta_data=event) + typed_event._invoke_callback(self.callback_handler) + structured_output_span.add_event( "gen_ai.choice", attributes={"message": serialize(event["output"].model_dump())} ) @@ -528,15 +530,22 @@ async def stream_async(self, prompt: Union[str, list[ContentBlock]], **kwargs: A self.trace_span = self._start_agent_trace_span(message) with trace_api.use_span(self.trace_span): try: - events = self._run_loop(message, invocation_state=kwargs) + invocation_state = kwargs + events = self._run_loop(message, invocation_state=invocation_state) async for event in events: - if "callback" in event: - callback_handler(**event["callback"]) - yield event["callback"] + if not isinstance(event, StopEvent): + event._invoke_callback(callback_handler, invocation_state=invocation_state) + yield event + + if not isinstance(event, StopEvent): + raise ValueError(f"Last event was not a StopEvent; was: {event}") + + event = cast(StopEvent, event) - result = AgentResult(*event["stop"]) - callback_handler(result=result) - yield {"result": result} + result = event.result + result_event = ResultEvent(result=result) + result_event._invoke_callback(callback_handler, invocation_state=invocation_state) + yield result_event self._end_agent_trace_span(response=result) @@ -569,13 +578,13 @@ async def _run_loop( # Signal from the model provider that the message sent by the user should be redacted, # likely due to a guardrail. if ( - event.get("callback") - and event["callback"].get("event") - and event["callback"]["event"].get("redactContent") - and event["callback"]["event"]["redactContent"].get("redactUserContentMessage") + isinstance(event, StreamDeltaEvent) + and event.delta_data.get("event") + and event.delta_data["event"].get("redactContent") + and event.delta_data["event"]["redactContent"].get("redactUserContentMessage") ): self.messages[-1]["content"] = [ - {"text": event["callback"]["event"]["redactContent"]["redactUserContentMessage"]} + {"text": event.delta_data["event"]["redactContent"]["redactUserContentMessage"]} ] if self._session_manager: self._session_manager.redact_latest_message(self.messages[-1], self) diff --git a/src/strands/event_loop/event_loop.py b/src/strands/event_loop/event_loop.py index b5dcce8a5..df232a11a 100644 --- a/src/strands/event_loop/event_loop.py +++ b/src/strands/event_loop/event_loop.py @@ -15,6 +15,7 @@ from opentelemetry import trace as trace_api +from experimental.hooks.test_events import tool_result from ..experimental.hooks import ( AfterModelInvocationEvent, AfterToolInvocationEvent, @@ -28,7 +29,8 @@ from ..telemetry.tracer import get_tracer from ..tools.executor import run_tools, validate_and_prepare_tools from ..types.content import Message -from ..types.event_loop import StartEventLoop, StartEventLoopEvent, MessageEvent, ForceStopEvent, EventLoopThrottleDelay +from ..types.event_loop import StartEventLoopEvent, MessageEvent, ForceStopEvent, EventLoopThrottleDelay, StopEvent, \ + StartEvent, ToolResultMessageEvent, StreamStopEvent, StreamDeltaEvent, ToolResultEvent, TypedEvent from ..types.exceptions import ( ContextWindowOverflowException, EventLoopException, @@ -94,6 +96,7 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) -> cycle_start_time, cycle_trace = agent.event_loop_metrics.start_cycle(attributes=attributes) invocation_state["event_loop_cycle_trace"] = cycle_trace + yield StartEvent() yield StartEventLoopEvent() # Create tracer span for this event loop cycle @@ -136,13 +139,9 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) -> # before yielding to the callback handler. This will be revisited when migrating to strongly # typed events. async for event in stream_messages(agent.model, agent.system_prompt, agent.messages, tool_specs): - if "callback" in event: - yield { - "callback": { - **event["callback"], - **(invocation_state if "delta" in event["callback"] else {}), - } - } + # NEED TO DISCUSS - what events *shouldn't* be yield back? + if "stop" not in event: + yield StreamDeltaEvent(delta_data=event) stop_reason, message, usage, metrics = event["stop"] invocation_state.setdefault("request_state", {}) @@ -238,7 +237,16 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) -> invocation_state=invocation_state, ) async for event in events: - yield event + # Note: Once we allow streaming of tool results, we need to do a check here for it + if isinstance(event, TypedEvent): + yield event + elif isinstance(event, dict): + if "content" in event and "status" in event and "toolUseId" in event: + yield ToolResultEvent(tool_result=event) + else: + print("tool_use", event) + yield event + return @@ -266,11 +274,16 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) -> tracer.end_span_with_error(cycle_span, str(e), e) # Handle any other exceptions - yield {"callback": {"force_stop": True, "force_stop_reason": str(e)}} + yield ForceStopEvent(reason=str(e)) logger.exception("cycle failed") raise EventLoopException(e, invocation_state["request_state"]) from e - yield {"stop": (stop_reason, message, agent.event_loop_metrics, invocation_state["request_state"])} + yield StopEvent( + stop_reason=stop_reason, + message=message, + metrics=agent.event_loop_metrics, + request_state=invocation_state["request_state"] + ) async def recurse_event_loop(agent: "Agent", invocation_state: dict[str, Any]) -> AsyncGenerator[dict[str, Any], None]: @@ -297,7 +310,7 @@ async def recurse_event_loop(agent: "Agent", invocation_state: dict[str, Any]) - recursive_trace = Trace("Recursive call", parent_id=cycle_trace.id) cycle_trace.add_child(recursive_trace) - yield {"callback": {"start": True}} + yield StartEvent() events = event_loop_cycle(agent=agent, invocation_state=invocation_state) async for event in events: @@ -383,7 +396,7 @@ async def run_tool(agent: "Agent", tool_use: ToolUse, invocation_state: dict[str result=result, ) ) - yield after_event.result + yield ToolResultEvent(tool_result=after_event.result) return async for event in selected_tool.stream(tool_use, invocation_state): @@ -419,7 +432,7 @@ async def run_tool(agent: "Agent", tool_use: ToolUse, invocation_state: dict[str exception=e, ) ) - yield after_event.result + yield ToolResultEvent(tool_result=after_event.result) async def _handle_tool_execution( @@ -459,7 +472,12 @@ async def _handle_tool_execution( validate_and_prepare_tools(message, tool_uses, tool_results, invalid_tool_use_ids) if not tool_uses: - yield {"stop": (stop_reason, message, agent.event_loop_metrics, invocation_state["request_state"])} + yield StopEvent( + stop_reason=stop_reason, + message=message, + metrics=agent.event_loop_metrics, + request_state=invocation_state["request_state"] + ) return def tool_handler(tool_use: ToolUse) -> ToolGenerator: @@ -487,7 +505,7 @@ def tool_handler(tool_use: ToolUse) -> ToolGenerator: agent.messages.append(tool_result_message) agent.hooks.invoke_callbacks(MessageAddedEvent(agent=agent, message=tool_result_message)) - yield {"callback": {"message": tool_result_message}} + yield ToolResultMessageEvent(message=tool_result_message) if cycle_span: tracer = get_tracer() @@ -495,7 +513,12 @@ def tool_handler(tool_use: ToolUse) -> ToolGenerator: if invocation_state["request_state"].get("stop_event_loop", False): agent.event_loop_metrics.end_cycle(cycle_start_time, cycle_trace) - yield {"stop": (stop_reason, message, agent.event_loop_metrics, invocation_state["request_state"])} + yield StopEvent( + stop_reason=stop_reason, + message=message, + metrics=agent.event_loop_metrics, + request_state=invocation_state["request_state"] + ) return events = recurse_event_loop(agent=agent, invocation_state=invocation_state) diff --git a/src/strands/event_loop/streaming.py b/src/strands/event_loop/streaming.py index 74cadaf9e..016012135 100644 --- a/src/strands/event_loop/streaming.py +++ b/src/strands/event_loop/streaming.py @@ -122,11 +122,11 @@ def handle_content_block_delta( state["current_tool_use"]["input"] = "" state["current_tool_use"]["input"] += delta_content["toolUse"]["input"] - callback_event["callback"] = {"delta": delta_content, "current_tool_use": state["current_tool_use"]} + callback_event = {"delta": delta_content, "current_tool_use": state["current_tool_use"]} elif "text" in delta_content: state["text"] += delta_content["text"] - callback_event["callback"] = {"data": delta_content["text"], "delta": delta_content} + callback_event = {"data": delta_content["text"], "delta": delta_content} elif "reasoningContent" in delta_content: if "text" in delta_content["reasoningContent"]: @@ -134,7 +134,7 @@ def handle_content_block_delta( state["reasoningText"] = "" state["reasoningText"] += delta_content["reasoningContent"]["text"] - callback_event["callback"] = { + callback_event = { "reasoningText": delta_content["reasoningContent"]["text"], "delta": delta_content, "reasoning": True, @@ -145,7 +145,7 @@ def handle_content_block_delta( state["signature"] = "" state["signature"] += delta_content["reasoningContent"]["signature"] - callback_event["callback"] = { + callback_event= { "reasoning_signature": delta_content["reasoningContent"]["signature"], "delta": delta_content, "reasoning": True, @@ -271,7 +271,7 @@ async def process_stream(chunks: AsyncIterable[StreamEvent]) -> AsyncGenerator[d metrics: Metrics = Metrics(latencyMs=0) async for chunk in chunks: - yield {"callback": {"event": chunk}} + yield {"event": chunk} if "messageStart" in chunk: state["message"] = handle_message_start(chunk["messageStart"], state["message"]) diff --git a/src/strands/types/event_loop.py b/src/strands/types/event_loop.py index 8cb302cfe..c0f5ff26e 100644 --- a/src/strands/types/event_loop.py +++ b/src/strands/types/event_loop.py @@ -1,10 +1,15 @@ """Event loop-related type definitions for the SDK.""" from dataclasses import dataclass -from typing import Literal, Any +from typing import Literal, Any, TYPE_CHECKING, Callable from typing_extensions import TypedDict + from .content import Message +from .tools import ToolResult + +if TYPE_CHECKING: + from ..telemetry import EventLoopMetrics class Usage(TypedDict): @@ -49,20 +54,43 @@ class Metrics(TypedDict): - "tool_use": Model requested to use a tool """ +_sentinal_value = 34235234523423 + class TypedEvent(dict): - pass + def __post_init__(self): + if hasattr(self, "as_callback"): + props = self.as_callback() + self.update(**props) + + def supports_callback(self): + return hasattr(self, "as_callback") + + def _invoke_callback(self, callback_handler: Callable, invocation_state = {}) -> None: + args = self.get_callback_args(invocation_state) + if args: + callback_handler(**args) + + def get_callback_args(self, invocation_state = {}): + if hasattr(self, "as_callback"): + if hasattr(self, "include_state"): + self.update(**invocation_state) + args = self.as_callback() + return args + return {} @dataclass class StartEventLoopEvent(TypedEvent): def as_callback(self) -> list[dict[str, Any]]: - return [{"start": True}, - {"start_event_loop": True}] + return {"start_event_loop": True} @dataclass class InitEventLoopEvent(TypedEvent): + def include_state(self): + return True + def as_callback(self) -> dict[str, Any]: return {"init_event_loop": True} @@ -86,4 +114,90 @@ class EventLoopThrottleDelay(TypedEvent): delay: int def as_callback(self) -> dict[str, Any]: - return {"event_loop_throttled_delay": self.delay} \ No newline at end of file + return {"event_loop_throttled_delay": self.delay} + +@dataclass +class StopEvent(TypedEvent): + stop_reason: StopReason + message: Message + metrics: "EventLoopMetrics" + request_state: Any + + @property + def result(self) -> "AgentResult": + from strands.agent import AgentResult + + return AgentResult( + stop_reason=self.stop_reason, + message=self.message, + metrics=self.metrics, + state=self.request_state, + ) + + # def as_callback(self) -> dict[str, Any]: + # return { "stop": (self.stop_reason, self.message, self.metrics, self.request_state)} + +@dataclass +class StartEvent(TypedEvent): + + + + def as_callback(self) -> dict[str, Any]: + return {"start": True} + +@dataclass +class ToolResultEvent(TypedEvent): + tool_result: ToolResult + +@dataclass +class ToolResultMessageEvent(TypedEvent): + message: Any + + def as_callback(self) -> dict[str, Any]: + return {"message": self.message} + +@dataclass +class StreamChunkEvent(TypedEvent): + chunk: Any + + def as_callback(self) -> dict[str, Any]: + return {"event": self.chunk} + +@dataclass +class StreamDeltaEvent(TypedEvent): + delta_data: dict[str, Any] + + def include_state(self): + return True + + def as_callback(self) -> dict[str, Any]: + # TODO include invocation state in here + return self.delta_data + + def _invoke_callback(self, callback_handler: Callable, invocation_state = {}) -> None: + + args = self.as_callback() + + if "delta" in self.delta_data: + all_args = {**invocation_state, **args} + else: + all_args = {**args} + + callback_handler(**all_args) + +@dataclass +class ResultEvent(TypedEvent): + result: "AgentResult" + + def as_callback(self) -> dict[str, Any]: + return {"result": self.result} + +@dataclass +class StreamStopEvent(TypedEvent): + stop_reason: StopReason + message: Any + usage: Any + metrics: Any + + def as_dict(self) -> dict[str, Any]: + return {"stop": (self.stop_reason, self.message, self.usage, self.metrics)} \ No newline at end of file diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index fdce7c368..1ac71fc93 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -19,6 +19,7 @@ from strands.models.bedrock import DEFAULT_BEDROCK_MODEL_ID, BedrockModel from strands.session.repository_session_manager import RepositorySessionManager from strands.types.content import Messages +from strands.types.event_loop import StreamStopEvent, StopEvent, StreamDeltaEvent, InitEventLoopEvent, ResultEvent from strands.types.exceptions import ContextWindowOverflowException, EventLoopException from strands.types.session import Session, SessionAgent, SessionMessage, SessionType from tests.fixtures.mock_session_repository import MockedSessionRepository @@ -399,7 +400,12 @@ async def check_invocation_state(**kwargs): assert invocation_state["agent"] == agent # Return expected values from event_loop_cycle - yield {"stop": ("stop", {"role": "assistant", "content": [{"text": "Response"}]}, {}, {})} + yield StopEvent( + stop_reason="stop", + message={"role": "assistant", "content": [{"text": "Response"}]}, + request_state={}, + metrics={}, + ) mock_event_loop_cycle.side_effect = check_invocation_state @@ -661,8 +667,7 @@ def test_agent__call__callback(mock_model, agent, callback_handler, agenerator): ) agent("test") - callback_handler.assert_has_calls( - [ + assert callback_handler.call_args_list == [ unittest.mock.call(init_event_loop=True), unittest.mock.call(start=True), unittest.mock.call(start_event_loop=True), @@ -726,8 +731,22 @@ def test_agent__call__callback(mock_model, agent, callback_handler, agenerator): ], }, ), - ], - ) + unittest.mock.call( + result=AgentResult( + stop_reason="end_turn", + message={ + "role": "assistant", + "content": [ + {"toolUse": {"toolUseId": "123", "name": "test", "input": {}}}, + {"reasoningContent": {"reasoningText": {"text": "value", "signature": "value"}}}, + {"text": "value"}, + ], + }, + metrics=unittest.mock.ANY, + state=unittest.mock.ANY, + ), + ) + ] @pytest.mark.asyncio @@ -1140,20 +1159,38 @@ async def test_stream_async_returns_all_events(mock_event_loop_cycle, alist): # Define the side effect to simulate callback handler being called multiple times async def test_event_loop(*args, **kwargs): - yield {"callback": {"data": "First chunk"}} - yield {"callback": {"data": "Second chunk"}} - yield {"callback": {"data": "Final chunk", "complete": True}} + yield StreamDeltaEvent(delta_data={"data": "First chunk"}) + yield StreamDeltaEvent(delta_data={"data": "Second chunk"}) + yield StreamDeltaEvent(delta_data={"data": "Final chunk", "complete": True}) # Return expected values from event_loop_cycle - yield {"stop": ("stop", {"role": "assistant", "content": [{"text": "Response"}]}, {}, {})} + yield StopEvent( + stop_reason="stop", + message={"role": "assistant", "content": [{"text": "Response"}]}, + metrics={}, + request_state={}, + ) mock_event_loop_cycle.side_effect = test_event_loop mock_callback = unittest.mock.Mock() stream = agent.stream_async("test message", callback_handler=mock_callback) + # TODO tru_events = await alist(stream) exp_events = [ + InitEventLoopEvent(), + StreamDeltaEvent(delta_data={"data": "First chunk"}), + StreamDeltaEvent(delta_data={"data": "Second chunk"}), + StreamDeltaEvent(delta_data={"data": "Final chunk", "complete": True}), + ResultEvent(result=AgentResult(stop_reason="stop", + message={"role": "assistant", "content": [{"text": "Response"}]}, + metrics={}, + state={})) + ] + assert tru_events == exp_events + + assert [{**it} for it in tru_events] == [ {"init_event_loop": True, "callback_handler": mock_callback}, {"data": "First chunk"}, {"data": "Second chunk"}, @@ -1167,10 +1204,9 @@ async def test_event_loop(*args, **kwargs): ), }, ] - assert tru_events == exp_events - exp_calls = [unittest.mock.call(**event) for event in exp_events] - mock_callback.assert_has_calls(exp_calls) + exp_calls = [unittest.mock.call(**event.get_callback_args()) for event in exp_events] + assert mock_callback.call_args_list == exp_calls @pytest.mark.asyncio @@ -1230,7 +1266,12 @@ async def check_invocation_state(**kwargs): invocation_state = kwargs["invocation_state"] assert invocation_state["some_value"] == "a_value" # Return expected values from event_loop_cycle - yield {"stop": ("stop", {"role": "assistant", "content": [{"text": "Response"}]}, {}, {})} + yield StopEvent( + stop_reason="stop", + message={"role": "assistant", "content": [{"text": "Response"}]}, + metrics={}, + request_state={}, + ) mock_event_loop_cycle.side_effect = check_invocation_state @@ -1362,7 +1403,12 @@ async def test_agent_stream_async_creates_and_ends_span_on_success(mock_get_trac mock_get_tracer.return_value = mock_tracer async def test_event_loop(*args, **kwargs): - yield {"stop": ("stop", {"role": "assistant", "content": [{"text": "Agent Response"}]}, {}, {})} + yield StopEvent( + stop_reason="stop", + message={"role": "assistant", "content": [{"text": "Agent Response"}]}, + metrics={}, + request_state={}, + ) mock_event_loop_cycle.side_effect = test_event_loop From e3e8f435ab6f3d1f4425f329b0d88359eeeed641 Mon Sep 17 00:00:00 2001 From: Mackenzie Zastrow Date: Mon, 18 Aug 2025 14:15:22 -0400 Subject: [PATCH 3/7] 3rd cut --- src/strands/agent/agent.py | 2 +- src/strands/types/event_loop.py | 17 +++++------------ tests/strands/event_loop/test_event_loop.py | 13 +++++++------ 3 files changed, 13 insertions(+), 19 deletions(-) diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index ce96b7539..95de2f0ed 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -37,7 +37,7 @@ from ..tools.registry import ToolRegistry from ..tools.watcher import ToolWatcher from ..types.content import ContentBlock, Message, Messages -from ..types.event_loop import InitEventLoopEvent, StopEvent, StreamDeltaEvent, ResultEvent, StreamChunkEvent +from ..types.event_loop import InitEventLoopEvent, StopEvent, StreamDeltaEvent, ResultEvent from ..types.exceptions import ContextWindowOverflowException from ..types.tools import ToolResult, ToolUse from ..types.traces import AttributeValue diff --git a/src/strands/types/event_loop.py b/src/strands/types/event_loop.py index c0f5ff26e..6c69f8212 100644 --- a/src/strands/types/event_loop.py +++ b/src/strands/types/event_loop.py @@ -61,6 +61,9 @@ def __post_init__(self): if hasattr(self, "as_callback"): props = self.as_callback() self.update(**props) + elif hasattr(self, "as_dict"): + props = self.as_dict() + self.update(**props) def supports_callback(self): return hasattr(self, "as_callback") @@ -134,14 +137,11 @@ def result(self) -> "AgentResult": state=self.request_state, ) - # def as_callback(self) -> dict[str, Any]: - # return { "stop": (self.stop_reason, self.message, self.metrics, self.request_state)} + def as_dict(self) -> dict[str, Any]: + return { "stop": (self.stop_reason, self.message, self.metrics, self.request_state)} @dataclass class StartEvent(TypedEvent): - - - def as_callback(self) -> dict[str, Any]: return {"start": True} @@ -156,13 +156,6 @@ class ToolResultMessageEvent(TypedEvent): def as_callback(self) -> dict[str, Any]: return {"message": self.message} -@dataclass -class StreamChunkEvent(TypedEvent): - chunk: Any - - def as_callback(self) -> dict[str, Any]: - return {"event": self.chunk} - @dataclass class StreamDeltaEvent(TypedEvent): delta_data: dict[str, Any] diff --git a/tests/strands/event_loop/test_event_loop.py b/tests/strands/event_loop/test_event_loop.py index 191ab51ba..c3865246e 100644 --- a/tests/strands/event_loop/test_event_loop.py +++ b/tests/strands/event_loop/test_event_loop.py @@ -19,6 +19,7 @@ ) from strands.telemetry.metrics import EventLoopMetrics from strands.tools.registry import ToolRegistry +from strands.types.event_loop import ForceStopEvent, ToolResultEvent from strands.types.exceptions import ( ContextWindowOverflowException, EventLoopException, @@ -483,7 +484,7 @@ async def test_cycle_exception( ] tru_stop_event = None - exp_stop_event = {"callback": {"force_stop": True, "force_stop_reason": "Invalid error presented"}} + exp_stop_event = ForceStopEvent(reason="Invalid error presented") with pytest.raises(EventLoopException): stream = strands.event_loop.event_loop.event_loop_cycle( @@ -836,11 +837,11 @@ async def test_run_tool_missing_tool(agent, alist): tru_events = await alist(process) exp_events = [ - { + ToolResultEvent(tool_result={ "toolUseId": "missing", "status": "error", "content": [{"text": "Unknown tool: missing"}], - }, + }) ] assert tru_events == exp_events @@ -1009,7 +1010,7 @@ def modify_hook(event: AfterToolInvocationEvent): ) result = (await alist(process))[-1] - assert result == updated_result + assert result == ToolResultEvent(tool_result=updated_result) @pytest.mark.asyncio @@ -1050,11 +1051,11 @@ def after_tool_call(self, event: AfterToolInvocationEvent): result = (await alist(process))[-1] - assert result == { + assert result == ToolResultEvent(tool_result={ "status": "error", "toolUseId": "test", "content": [{"text": "This tool has been used too many times!"}], - } + }) assert mock_logger.debug.call_args_list == [ call("tool_use=<%s> | streaming", {"toolUseId": "test", "name": "test_quota", "input": {"x": 5}}), From 9b2f7c0c4b3ccace8ce17ccf92dca6dcae0be613 Mon Sep 17 00:00:00 2001 From: Mackenzie Zastrow Date: Tue, 19 Aug 2025 12:44:36 -0400 Subject: [PATCH 4/7] Fix all remaining tests --- src/strands/agent/agent.py | 24 +- src/strands/event_loop/event_loop.py | 58 +++-- src/strands/event_loop/streaming.py | 2 +- src/strands/tools/executor.py | 16 +- src/strands/types/event_loop.py | 149 +----------- src/strands/types/events.py | 243 ++++++++++++++++++++ tests/strands/agent/test_agent.py | 229 +++++++++++------- tests/strands/event_loop/test_event_loop.py | 42 ++-- tests/strands/event_loop/test_streaming.py | 199 +++++++--------- tests/strands/tools/test_executor.py | 53 +++-- 10 files changed, 593 insertions(+), 422 deletions(-) create mode 100644 src/strands/types/events.py diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index 95de2f0ed..363cf4025 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -37,7 +37,7 @@ from ..tools.registry import ToolRegistry from ..tools.watcher import ToolWatcher from ..types.content import ContentBlock, Message, Messages -from ..types.event_loop import InitEventLoopEvent, StopEvent, StreamDeltaEvent, ResultEvent +from ..types.events import InitEventLoopEvent, ResultEvent, StopEvent, StreamDeltaEvent, ToolResultEvent, TypedEvent from ..types.exceptions import ContextWindowOverflowException from ..types.tools import ToolResult, ToolUse from ..types.traces import AttributeValue @@ -137,19 +137,19 @@ def caller( "input": kwargs.copy(), } - async def acall() -> ToolResult: + async def acall() -> ToolResultEvent: # Pass kwargs as invocation_state async for event in run_tool(self._agent, tool_use, kwargs): _ = event - return cast(ToolResult, event) + return cast(ToolResultEvent, event) - def tcall() -> ToolResult: + def tcall() -> ToolResultEvent: return asyncio.run(acall()) with ThreadPoolExecutor() as executor: future = executor.submit(tcall) - tool_result = future.result() + last_event = future.result() if record_direct_tool_call is not None: should_record_direct_tool_call = record_direct_tool_call @@ -158,12 +158,12 @@ def tcall() -> ToolResult: if should_record_direct_tool_call: # Create a record of this tool execution in the message history - self._agent._record_tool_execution(tool_use, tool_result, user_message_override) + self._agent._record_tool_execution(tool_use, last_event, user_message_override) # Apply window management self._agent.conversation_manager.apply_management(self._agent) - return tool_result + return last_event.tool_result return caller @@ -481,7 +481,7 @@ async def structured_output_async( async for event in events: if "stop" not in event: typed_event = StreamDeltaEvent(delta_data=event) - typed_event._invoke_callback(self.callback_handler) + typed_event.prepare_and_invoke(invocation_state={}, callback_handler=self.callback_handler) structured_output_span.add_event( "gen_ai.choice", attributes={"message": serialize(event["output"].model_dump())} @@ -534,7 +534,7 @@ async def stream_async(self, prompt: Union[str, list[ContentBlock]], **kwargs: A events = self._run_loop(message, invocation_state=invocation_state) async for event in events: if not isinstance(event, StopEvent): - event._invoke_callback(callback_handler, invocation_state=invocation_state) + event.prepare_and_invoke(invocation_state, callback_handler) yield event if not isinstance(event, StopEvent): @@ -544,7 +544,7 @@ async def stream_async(self, prompt: Union[str, list[ContentBlock]], **kwargs: A result = event.result result_event = ResultEvent(result=result) - result_event._invoke_callback(callback_handler, invocation_state=invocation_state) + result_event.prepare_and_invoke(invocation_state, callback_handler) yield result_event self._end_agent_trace_span(response=result) @@ -553,9 +553,7 @@ async def stream_async(self, prompt: Union[str, list[ContentBlock]], **kwargs: A self._end_agent_trace_span(error=e) raise - async def _run_loop( - self, message: Message, invocation_state: dict[str, Any] - ) -> AsyncGenerator[dict[str, Any], None]: + async def _run_loop(self, message: Message, invocation_state: dict[str, Any]) -> AsyncGenerator[TypedEvent, None]: """Execute the agent's event loop with the given message and parameters. Args: diff --git a/src/strands/event_loop/event_loop.py b/src/strands/event_loop/event_loop.py index df232a11a..d51c70b8a 100644 --- a/src/strands/event_loop/event_loop.py +++ b/src/strands/event_loop/event_loop.py @@ -15,7 +15,6 @@ from opentelemetry import trace as trace_api -from experimental.hooks.test_events import tool_result from ..experimental.hooks import ( AfterModelInvocationEvent, AfterToolInvocationEvent, @@ -29,8 +28,20 @@ from ..telemetry.tracer import get_tracer from ..tools.executor import run_tools, validate_and_prepare_tools from ..types.content import Message -from ..types.event_loop import StartEventLoopEvent, MessageEvent, ForceStopEvent, EventLoopThrottleDelay, StopEvent, \ - StartEvent, ToolResultMessageEvent, StreamStopEvent, StreamDeltaEvent, ToolResultEvent, TypedEvent +from ..types.events import ( + EventLoopThrottleDelay, + ForceStopEvent, + MessageEvent, + StartEvent, + StartEventLoopEvent, + StopEvent, + StreamDeltaEvent, + ToolResultEvent, + ToolResultMessageEvent, + ToolStreamEvent, + TypedEvent, + TypedToolGenerator, +) from ..types.exceptions import ( ContextWindowOverflowException, EventLoopException, @@ -38,7 +49,7 @@ ModelThrottledException, ) from ..types.streaming import Metrics, StopReason -from ..types.tools import ToolChoice, ToolChoiceAuto, ToolConfig, ToolGenerator, ToolResult, ToolUse +from ..types.tools import ToolChoice, ToolChoiceAuto, ToolConfig, ToolResult, ToolUse from ._recover_message_on_max_tokens_reached import recover_message_on_max_tokens_reached from .streaming import stream_messages @@ -52,7 +63,7 @@ MAX_DELAY = 240 # 4 minutes -async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) -> AsyncGenerator[dict[str, Any], None]: +async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) -> AsyncGenerator[TypedEvent, None]: """Execute a single cycle of the event loop. This core function processes a single conversation turn, handling model inference, tool execution, and error @@ -237,17 +248,13 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) -> invocation_state=invocation_state, ) async for event in events: + print("evloop", event) # Note: Once we allow streaming of tool results, we need to do a check here for it if isinstance(event, TypedEvent): yield event - elif isinstance(event, dict): - if "content" in event and "status" in event and "toolUseId" in event: - yield ToolResultEvent(tool_result=event) else: - print("tool_use", event) - yield event - - + logger.warning(f"event=<{event} | non-typed event") + raise ValueError("Event was not a subclass of TypedEvent") return # End the cycle and return results @@ -282,11 +289,11 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) -> stop_reason=stop_reason, message=message, metrics=agent.event_loop_metrics, - request_state=invocation_state["request_state"] + request_state=invocation_state["request_state"], ) -async def recurse_event_loop(agent: "Agent", invocation_state: dict[str, Any]) -> AsyncGenerator[dict[str, Any], None]: +async def recurse_event_loop(agent: "Agent", invocation_state: dict[str, Any]) -> AsyncGenerator[TypedEvent, None]: """Make a recursive call to event_loop_cycle with the current state. This function is used when the event loop needs to continue processing after tool execution. @@ -319,7 +326,7 @@ async def recurse_event_loop(agent: "Agent", invocation_state: dict[str, Any]) - recursive_trace.end() -async def run_tool(agent: "Agent", tool_use: ToolUse, invocation_state: dict[str, Any]) -> ToolGenerator: +async def run_tool(agent: "Agent", tool_use: ToolUse, invocation_state: dict[str, Any]) -> TypedToolGenerator: """Process a tool invocation. Looks up the tool in the registry and streams it with the provided parameters. @@ -400,8 +407,12 @@ async def run_tool(agent: "Agent", tool_use: ToolUse, invocation_state: dict[str return async for event in selected_tool.stream(tool_use, invocation_state): - yield event + # wrap all events in a ToolStreamEvent + print(event) + yield ToolStreamEvent(tool_use, event) + if not ("content" in event and "status" in event and "toolUseId" in event): + raise ValueError("Last event of tool invocation was not a tool result") result = event after_event = agent.hooks.invoke_callbacks( @@ -413,7 +424,8 @@ async def run_tool(agent: "Agent", tool_use: ToolUse, invocation_state: dict[str result=result, ) ) - yield after_event.result + + yield ToolResultEvent(tool_result=after_event.result) except Exception as e: logger.exception("tool_name=<%s> | failed to process tool", tool_name) @@ -443,7 +455,7 @@ async def _handle_tool_execution( cycle_span: Any, cycle_start_time: float, invocation_state: dict[str, Any], -) -> AsyncGenerator[dict[str, Any], None]: +) -> AsyncGenerator[TypedEvent, None]: tool_uses: list[ToolUse] = [] tool_results: list[ToolResult] = [] invalid_tool_use_ids: list[str] = [] @@ -476,11 +488,11 @@ async def _handle_tool_execution( stop_reason=stop_reason, message=message, metrics=agent.event_loop_metrics, - request_state=invocation_state["request_state"] + request_state=invocation_state["request_state"], ) return - def tool_handler(tool_use: ToolUse) -> ToolGenerator: + def tool_handler(tool_use: ToolUse) -> TypedToolGenerator: return run_tool(agent, tool_use, invocation_state) tool_events = run_tools( @@ -493,6 +505,10 @@ def tool_handler(tool_use: ToolUse) -> ToolGenerator: parent_span=cycle_span, ) async for tool_event in tool_events: + # For now, until we implement #543, supress tool stream events + if isinstance(tool_event, ToolStreamEvent): + continue + yield tool_event # Store parent cycle ID for the next cycle @@ -517,7 +533,7 @@ def tool_handler(tool_use: ToolUse) -> ToolGenerator: stop_reason=stop_reason, message=message, metrics=agent.event_loop_metrics, - request_state=invocation_state["request_state"] + request_state=invocation_state["request_state"], ) return diff --git a/src/strands/event_loop/streaming.py b/src/strands/event_loop/streaming.py index 016012135..749537627 100644 --- a/src/strands/event_loop/streaming.py +++ b/src/strands/event_loop/streaming.py @@ -145,7 +145,7 @@ def handle_content_block_delta( state["signature"] = "" state["signature"] += delta_content["reasoningContent"]["signature"] - callback_event= { + callback_event = { "reasoning_signature": delta_content["reasoningContent"]["signature"], "delta": delta_content, "reasoning": True, diff --git a/src/strands/tools/executor.py b/src/strands/tools/executor.py index d90f9a5aa..1c3fc990e 100644 --- a/src/strands/tools/executor.py +++ b/src/strands/tools/executor.py @@ -11,7 +11,8 @@ from ..telemetry.tracer import get_tracer from ..tools.tools import InvalidToolUseNameException, validate_tool_use from ..types.content import Message -from ..types.tools import RunToolHandler, ToolGenerator, ToolResult, ToolUse +from ..types.events import ToolResultEvent, TypedToolGenerator +from ..types.tools import RunToolHandler, ToolResult, ToolUse logger = logging.getLogger(__name__) @@ -24,7 +25,7 @@ async def run_tools( tool_results: list[ToolResult], cycle_trace: Trace, parent_span: Optional[trace_api.Span] = None, -) -> ToolGenerator: +) -> TypedToolGenerator: """Execute tools concurrently. Args: @@ -60,19 +61,20 @@ async def work( await worker_event.wait() worker_event.clear() - result = cast(ToolResult, event) + result_event = cast(ToolResultEvent, event) finally: worker_queue.put_nowait((worker_id, stop_event)) - tool_success = result.get("status") == "success" + tool_success = result_event.get("status") == "success" + tool_result = result_event.tool_result tool_duration = time.time() - tool_start_time - message = Message(role="user", content=[{"toolResult": result}]) + message = Message(role="user", content=[{"toolResult": tool_result}]) event_loop_metrics.add_tool_usage(tool_use, tool_duration, tool_trace, tool_success, message) cycle_trace.add_child(tool_trace) - tracer.end_tool_call_span(tool_call_span, result) + tracer.end_tool_call_span(tool_call_span, tool_result) - return result + return tool_result tool_uses = [tool_use for tool_use in tool_uses if tool_use.get("toolUseId") not in invalid_tool_use_ids] worker_queue: asyncio.Queue[tuple[int, Any]] = asyncio.Queue() diff --git a/src/strands/types/event_loop.py b/src/strands/types/event_loop.py index 6c69f8212..62dbef0b3 100644 --- a/src/strands/types/event_loop.py +++ b/src/strands/types/event_loop.py @@ -1,15 +1,11 @@ """Event loop-related type definitions for the SDK.""" -from dataclasses import dataclass -from typing import Literal, Any, TYPE_CHECKING, Callable - -from typing_extensions import TypedDict +from typing import TYPE_CHECKING, Literal -from .content import Message -from .tools import ToolResult +from typing_extensions import TypedDict if TYPE_CHECKING: - from ..telemetry import EventLoopMetrics + pass class Usage(TypedDict): @@ -55,142 +51,3 @@ class Metrics(TypedDict): """ _sentinal_value = 34235234523423 - -class TypedEvent(dict): - def __post_init__(self): - if hasattr(self, "as_callback"): - props = self.as_callback() - self.update(**props) - elif hasattr(self, "as_dict"): - props = self.as_dict() - self.update(**props) - - def supports_callback(self): - return hasattr(self, "as_callback") - - def _invoke_callback(self, callback_handler: Callable, invocation_state = {}) -> None: - args = self.get_callback_args(invocation_state) - if args: - callback_handler(**args) - - def get_callback_args(self, invocation_state = {}): - if hasattr(self, "as_callback"): - if hasattr(self, "include_state"): - self.update(**invocation_state) - args = self.as_callback() - return args - - return {} - -@dataclass -class StartEventLoopEvent(TypedEvent): - - def as_callback(self) -> list[dict[str, Any]]: - return {"start_event_loop": True} - -@dataclass -class InitEventLoopEvent(TypedEvent): - - def include_state(self): - return True - - def as_callback(self) -> dict[str, Any]: - return {"init_event_loop": True} - -@dataclass -class ForceStopEvent(TypedEvent): - - reason: str | Exception - - def as_callback(self) -> dict[str, Any]: - return {"force_stop": True, "force_stop_reason": str(self.reason)} - -@dataclass -class MessageEvent(TypedEvent): - message: Message - - def as_callback(self) -> dict[str, Any]: - return {"message": self.message} - -@dataclass -class EventLoopThrottleDelay(TypedEvent): - delay: int - - def as_callback(self) -> dict[str, Any]: - return {"event_loop_throttled_delay": self.delay} - -@dataclass -class StopEvent(TypedEvent): - stop_reason: StopReason - message: Message - metrics: "EventLoopMetrics" - request_state: Any - - @property - def result(self) -> "AgentResult": - from strands.agent import AgentResult - - return AgentResult( - stop_reason=self.stop_reason, - message=self.message, - metrics=self.metrics, - state=self.request_state, - ) - - def as_dict(self) -> dict[str, Any]: - return { "stop": (self.stop_reason, self.message, self.metrics, self.request_state)} - -@dataclass -class StartEvent(TypedEvent): - def as_callback(self) -> dict[str, Any]: - return {"start": True} - -@dataclass -class ToolResultEvent(TypedEvent): - tool_result: ToolResult - -@dataclass -class ToolResultMessageEvent(TypedEvent): - message: Any - - def as_callback(self) -> dict[str, Any]: - return {"message": self.message} - -@dataclass -class StreamDeltaEvent(TypedEvent): - delta_data: dict[str, Any] - - def include_state(self): - return True - - def as_callback(self) -> dict[str, Any]: - # TODO include invocation state in here - return self.delta_data - - def _invoke_callback(self, callback_handler: Callable, invocation_state = {}) -> None: - - args = self.as_callback() - - if "delta" in self.delta_data: - all_args = {**invocation_state, **args} - else: - all_args = {**args} - - callback_handler(**all_args) - -@dataclass -class ResultEvent(TypedEvent): - result: "AgentResult" - - def as_callback(self) -> dict[str, Any]: - return {"result": self.result} - -@dataclass -class StreamStopEvent(TypedEvent): - stop_reason: StopReason - message: Any - usage: Any - metrics: Any - - def as_dict(self) -> dict[str, Any]: - return {"stop": (self.stop_reason, self.message, self.usage, self.metrics)} \ No newline at end of file diff --git a/src/strands/types/events.py b/src/strands/types/events.py new file mode 100644 index 000000000..78a8229ea --- /dev/null +++ b/src/strands/types/events.py @@ -0,0 +1,243 @@ +from dataclasses import MISSING +from typing import Any, AsyncGenerator, Callable, Literal, override + +from ..telemetry import EventLoopMetrics +from .content import Message +from .event_loop import StopReason +from .tools import ToolResult, ToolUse + + +class TypedEvent(dict): + # Do I have to be careful about tool stream results? + invocation_state: dict[str, Any] + + def __init__(self, data: dict[str, Any] = None): + super().__init__(data or {}) + + def _get_callback_fields(self) -> list[str] | Literal["all"] | None: + return None + + def set_invocation_state(self, invocation_state: dict): + self.invocation_state = invocation_state + + def prepare_and_invoke(self, invocation_state: dict, callback_handler: Callable): + self.set_invocation_state(invocation_state) + + # invoking + args = TypedEvent.get_callback_fields(self) + + if args: + callback_handler(**args) + return True + + return False + + @staticmethod + def get_callback_fields(event: "TypedEvent") -> dict[str, Any] | None: + allow_listed = event._get_callback_fields() + + if allow_listed is None: + return None + else: + if allow_listed is Any: + return {**event} + else: + return {k: v for k, v in event.items() if k in allow_listed} + + +class StartEventLoopEvent(TypedEvent): + def __init__(self): + super().__init__({"start_event_loop": True}) + + def _get_callback_fields(self): + return ["start_event_loop"] + + +class InitEventLoopEvent(TypedEvent): + def __init__(self): + super().__init__({"init_event_loop": True}) + + @override + def set_invocation_state(self, invocation_state: dict): + super().set_invocation_state(invocation_state) + + # For backwards compatability, make sure that we're merging the + # invocation state as a readonly copy into ourselves + self.update(**invocation_state) + + def _get_callback_fields(self): + return ["init_event_loop"] + + +class ForceStopEvent(TypedEvent): + @property + def reason(self) -> str: + return self.get("force_stop_reason") + + @property + def reason_exception(self) -> Exception | None: + return self.get("force_stop_reason_exception") + + def __init__(self, reason: str | Exception): + super().__init__( + {"force_stop": True, "force_stop_reason": str(reason), "force_stop_reason_exception": reason or MISSING} + ) + + def _get_callback_fields(self): + return ["force_stop", "force_stop_reason"] + + +class MessageEvent(TypedEvent): + def __init__(self, message: Message): + super().__init__({"message": message}) + + @property + def message(self) -> Message: + return self.get("message") + + def _get_callback_fields(self): + return ["message"] + + +class EventLoopThrottleDelay(TypedEvent): + def __init__(self, delay: int): + super().__init__({"event_loop_throttled_delay": delay}) + + @property + def delay(self) -> Message: + return self.get("event_loop_throttled_delay") + + def _get_callback_fields(self): + return ["event_loop_throttled_delay"] + + +class StopEvent(TypedEvent): + def __init__( + self, + stop_reason: StopReason, + message: Message, + metrics: "EventLoopMetrics", + request_state: Any, + ): + super().__init__( + { + "stop": (stop_reason, message, metrics, request_state), + } + ) + + @property + def stop_reason(self) -> StopReason: + return self.get("stop")[0] + + @property + def message(self) -> Message: + return self.get("stop")[1] + + @property + def metrics(self) -> "EventLoopMetrics": + return self.get("stop")[2] + + @property + def request_state(self) -> Any: + return self.get("stop")[3] + + @property + def result(self) -> "AgentResult": + from strands.agent import AgentResult + + return AgentResult( + stop_reason=self.stop_reason, + message=self.message, + metrics=self.metrics, + state=self.request_state, + ) + + def _get_callback_fields(self): + return ["stop"] + + +class StartEvent(TypedEvent): + def __init__(self): + super().__init__({"start": True}) + + def _get_callback_fields(self): + return ["start"] + + +class ToolStreamEvent(TypedEvent): + def __init__(self, tool_use: ToolUse, stream_data: any): + super().__init__({"tool_stream_tool_use": tool_use, "tool_stream_data": stream_data}) + + @property + def tool_use(self) -> ToolUse: + return self.get("tool_stream_tool_use") + + @property + def tool_stream_data(self) -> Any: + return self.get("tool_stream_data") + + +class ToolResultEvent(TypedEvent): + def __init__(self, tool_result: ToolResult): + super().__init__({"tool_result": tool_result}) + + @property + def tool_result(self) -> ToolResult: + return self.get("tool_result") + + +class ToolResultMessageEvent(TypedEvent): + def __init__(self, message: Any): + super().__init__({"message": message}) + + @property + def message(self) -> Any: + return self.get("message") + + @override + def _get_callback_fields(self): + return ["message"] + + +class StreamDeltaEvent(TypedEvent): + def __init__(self, delta_data: dict[str, Any]): + super().__init__(delta_data) + + @property + def delta_data(self) -> Any: + return self + + def include_state(self): + return True + + def as_callback(self) -> dict[str, Any]: + return self.delta_data + + def set_invocation_state(self, invocation_state: dict): + super().set_invocation_state(invocation_state) + + # For backwards compatability, make sure that we're merging the + # invocation state as a readonly copy into ourselves + if "delta" in self.delta_data: + self.update(**invocation_state) + + @override + def _get_callback_fields(self): + return Any + + +class ResultEvent(TypedEvent): + def __init__(self, result: "AgentResult"): + super().__init__({"result": result}) + + @property + def result(self) -> "AgentResult": + return self.get("result") + + @override + def _get_callback_fields(self): + return ["result"] + + +TypedToolGenerator = AsyncGenerator[TypedEvent, None] +"""Generator of tool events where all events are typed as TypedEvents.""" diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index 1ac71fc93..125bbe718 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -4,6 +4,7 @@ import os import textwrap import unittest.mock +from unittest.mock import ANY from uuid import uuid4 import pytest @@ -19,7 +20,13 @@ from strands.models.bedrock import DEFAULT_BEDROCK_MODEL_ID, BedrockModel from strands.session.repository_session_manager import RepositorySessionManager from strands.types.content import Messages -from strands.types.event_loop import StreamStopEvent, StopEvent, StreamDeltaEvent, InitEventLoopEvent, ResultEvent +from strands.types.events import ( + InitEventLoopEvent, + ResultEvent, + StopEvent, + StreamDeltaEvent, + ToolResultEvent, +) from strands.types.exceptions import ContextWindowOverflowException, EventLoopException from strands.types.session import Session, SessionAgent, SessionMessage, SessionType from tests.fixtures.mock_session_repository import MockedSessionRepository @@ -35,6 +42,8 @@ def mock_randint(): @pytest.fixture def mock_model(request): async def stream(*args, **kwargs): + print("SA", type(args), args) + print("SK", type(kwargs), kwargs) result = mock.mock_stream(*copy.deepcopy(args), **copy.deepcopy(kwargs)) # If result is already an async generator, yield from it if hasattr(result, "__aiter__"): @@ -668,60 +677,70 @@ def test_agent__call__callback(mock_model, agent, callback_handler, agenerator): agent("test") assert callback_handler.call_args_list == [ - unittest.mock.call(init_event_loop=True), - unittest.mock.call(start=True), - unittest.mock.call(start_event_loop=True), - unittest.mock.call( - event={"contentBlockStart": {"start": {"toolUse": {"toolUseId": "123", "name": "test"}}}} - ), - unittest.mock.call(event={"contentBlockDelta": {"delta": {"toolUse": {"input": '{"value"}'}}}}), - unittest.mock.call( - agent=agent, - current_tool_use={"toolUseId": "123", "name": "test", "input": {}}, - delta={"toolUse": {"input": '{"value"}'}}, - event_loop_cycle_id=unittest.mock.ANY, - event_loop_cycle_span=unittest.mock.ANY, - event_loop_cycle_trace=unittest.mock.ANY, - request_state={}, - ), - unittest.mock.call(event={"contentBlockStop": {}}), - unittest.mock.call(event={"contentBlockStart": {"start": {}}}), - unittest.mock.call(event={"contentBlockDelta": {"delta": {"reasoningContent": {"text": "value"}}}}), - unittest.mock.call( - agent=agent, - delta={"reasoningContent": {"text": "value"}}, - event_loop_cycle_id=unittest.mock.ANY, - event_loop_cycle_span=unittest.mock.ANY, - event_loop_cycle_trace=unittest.mock.ANY, - reasoning=True, - reasoningText="value", - request_state={}, - ), - unittest.mock.call(event={"contentBlockDelta": {"delta": {"reasoningContent": {"signature": "value"}}}}), - unittest.mock.call( - agent=agent, - delta={"reasoningContent": {"signature": "value"}}, - event_loop_cycle_id=unittest.mock.ANY, - event_loop_cycle_span=unittest.mock.ANY, - event_loop_cycle_trace=unittest.mock.ANY, - reasoning=True, - reasoning_signature="value", - request_state={}, - ), - unittest.mock.call(event={"contentBlockStop": {}}), - unittest.mock.call(event={"contentBlockStart": {"start": {}}}), - unittest.mock.call(event={"contentBlockDelta": {"delta": {"text": "value"}}}), - unittest.mock.call( - agent=agent, - data="value", - delta={"text": "value"}, - event_loop_cycle_id=unittest.mock.ANY, - event_loop_cycle_span=unittest.mock.ANY, - event_loop_cycle_trace=unittest.mock.ANY, - request_state={}, - ), - unittest.mock.call(event={"contentBlockStop": {}}), - unittest.mock.call( + unittest.mock.call(init_event_loop=True), + unittest.mock.call(start=True), + unittest.mock.call(start_event_loop=True), + unittest.mock.call(event={"contentBlockStart": {"start": {"toolUse": {"toolUseId": "123", "name": "test"}}}}), + unittest.mock.call(event={"contentBlockDelta": {"delta": {"toolUse": {"input": '{"value"}'}}}}), + unittest.mock.call( + agent=agent, + current_tool_use={"toolUseId": "123", "name": "test", "input": {}}, + delta={"toolUse": {"input": '{"value"}'}}, + event_loop_cycle_id=unittest.mock.ANY, + event_loop_cycle_span=unittest.mock.ANY, + event_loop_cycle_trace=unittest.mock.ANY, + request_state={}, + ), + unittest.mock.call(event={"contentBlockStop": {}}), + unittest.mock.call(event={"contentBlockStart": {"start": {}}}), + unittest.mock.call(event={"contentBlockDelta": {"delta": {"reasoningContent": {"text": "value"}}}}), + unittest.mock.call( + agent=agent, + delta={"reasoningContent": {"text": "value"}}, + event_loop_cycle_id=unittest.mock.ANY, + event_loop_cycle_span=unittest.mock.ANY, + event_loop_cycle_trace=unittest.mock.ANY, + reasoning=True, + reasoningText="value", + request_state={}, + ), + unittest.mock.call(event={"contentBlockDelta": {"delta": {"reasoningContent": {"signature": "value"}}}}), + unittest.mock.call( + agent=agent, + delta={"reasoningContent": {"signature": "value"}}, + event_loop_cycle_id=unittest.mock.ANY, + event_loop_cycle_span=unittest.mock.ANY, + event_loop_cycle_trace=unittest.mock.ANY, + reasoning=True, + reasoning_signature="value", + request_state={}, + ), + unittest.mock.call(event={"contentBlockStop": {}}), + unittest.mock.call(event={"contentBlockStart": {"start": {}}}), + unittest.mock.call(event={"contentBlockDelta": {"delta": {"text": "value"}}}), + unittest.mock.call( + agent=agent, + data="value", + delta={"text": "value"}, + event_loop_cycle_id=unittest.mock.ANY, + event_loop_cycle_span=unittest.mock.ANY, + event_loop_cycle_trace=unittest.mock.ANY, + request_state={}, + ), + unittest.mock.call(event={"contentBlockStop": {}}), + unittest.mock.call( + message={ + "role": "assistant", + "content": [ + {"toolUse": {"toolUseId": "123", "name": "test", "input": {}}}, + {"reasoningContent": {"reasoningText": {"text": "value", "signature": "value"}}}, + {"text": "value"}, + ], + }, + ), + unittest.mock.call( + result=AgentResult( + stop_reason="end_turn", message={ "role": "assistant", "content": [ @@ -730,23 +749,11 @@ def test_agent__call__callback(mock_model, agent, callback_handler, agenerator): {"text": "value"}, ], }, + metrics=unittest.mock.ANY, + state=unittest.mock.ANY, ), - unittest.mock.call( - result=AgentResult( - stop_reason="end_turn", - message={ - "role": "assistant", - "content": [ - {"toolUse": {"toolUseId": "123", "name": "test", "input": {}}}, - {"reasoningContent": {"reasoningText": {"text": "value", "signature": "value"}}}, - {"text": "value"}, - ], - }, - metrics=unittest.mock.ANY, - state=unittest.mock.ANY, - ), - ) - ] + ), + ] @pytest.mark.asyncio @@ -895,7 +902,7 @@ def test_agent_init_with_no_model_or_model_id(): def test_agent_tool_no_parameter_conflict(agent, tool_registry, mock_randint, mock_run_tool, agenerator): - mock_run_tool.return_value = agenerator([{}]) + mock_run_tool.return_value = agenerator([ToolResultEvent({})]) @strands.tools.tool(name="system_prompter") def function(system_prompt: str) -> str: @@ -919,7 +926,7 @@ def function(system_prompt: str) -> str: def test_agent_tool_with_name_normalization(agent, tool_registry, mock_randint, mock_run_tool, agenerator): - mock_run_tool.return_value = agenerator([{}]) + mock_run_tool.return_value = agenerator([ToolResultEvent({})]) tool_name = "system-prompter" @@ -1177,16 +1184,23 @@ async def test_event_loop(*args, **kwargs): stream = agent.stream_async("test message", callback_handler=mock_callback) # TODO + init_event = InitEventLoopEvent() + init_event.update({"callback_handler": mock_callback}) + tru_events = await alist(stream) exp_events = [ - InitEventLoopEvent(), + init_event, StreamDeltaEvent(delta_data={"data": "First chunk"}), StreamDeltaEvent(delta_data={"data": "Second chunk"}), StreamDeltaEvent(delta_data={"data": "Final chunk", "complete": True}), - ResultEvent(result=AgentResult(stop_reason="stop", + ResultEvent( + result=AgentResult( + stop_reason="stop", message={"role": "assistant", "content": [{"text": "Response"}]}, metrics={}, - state={})) + state={}, + ) + ), ] assert tru_events == exp_events @@ -1205,8 +1219,67 @@ async def test_event_loop(*args, **kwargs): }, ] - exp_calls = [unittest.mock.call(**event.get_callback_args()) for event in exp_events] - assert mock_callback.call_args_list == exp_calls + # TODO MDZ + # exp_calls = [unittest.mock.call(event) for event in exp_events] + # assert mock_callback.call_args_list == exp_calls + + +@pytest.mark.asyncio +async def test_stream_async_e2e(alist): + @strands.tool + def fake_tool(agent: Agent): + return "Done!" + + mock_provider = MockedModelProvider( + [ + {"redactedUserContent": "BLOCKED!", "redactedAssistantContent": "INPUT BLOCKED!"}, + {"role": "assistant", "content": [{"text": "Okay invoking tool!"}]}, + { + "role": "assistant", + "content": [{"toolUse": {"name": "fake_tool", "toolUseId": "123", "input": {}}}], + }, + {"role": "assistant", "content": [{"text": "I invoked a tool!"}]}, + ] + ) + agent = Agent(model=mock_provider, tools=[fake_tool]) + + stream = agent.stream_async("Do the stuff", arg1=1013) + + tru_events = [{**item} for item in await alist(stream)] + exp_events = [ + {"arg1": 1013, "init_event_loop": True}, + {"start": True}, + {"start_event_loop": True}, + {"event": {"messageStart": {"role": "assistant"}}}, + {"event": {"redactContent": {"redactUserContentMessage": "BLOCKED!"}}}, + {"event": {"contentBlockStart": {"start": {}}}}, + {"event": {"contentBlockDelta": {"delta": {"text": "INPUT BLOCKED!"}}}}, + { + "agent": ANY, + "arg1": 1013, + "data": "INPUT BLOCKED!", + "delta": {"text": "INPUT BLOCKED!"}, + "event_loop_cycle_id": ANY, + "event_loop_cycle_span": ANY, + "event_loop_cycle_trace": ANY, + "request_state": {}, + }, + {"event": {"contentBlockStop": {}}}, + {"event": {"messageStop": {"stopReason": "guardrail_intervened"}}}, + {"message": {"content": [{"text": "INPUT BLOCKED!"}], "role": "assistant"}}, + { + "result": AgentResult( + stop_reason="guardrail_intervened", + message={"content": [{"text": "INPUT BLOCKED!"}], "role": "assistant"}, + metrics=ANY, + state={}, + ), + }, + ] + assert tru_events == exp_events + + exp_calls = [unittest.mock.call(**event) for event in exp_events] + # mock_callback.assert_has_calls(exp_calls) @pytest.mark.asyncio diff --git a/tests/strands/event_loop/test_event_loop.py b/tests/strands/event_loop/test_event_loop.py index c3865246e..20f9895d9 100644 --- a/tests/strands/event_loop/test_event_loop.py +++ b/tests/strands/event_loop/test_event_loop.py @@ -19,7 +19,7 @@ ) from strands.telemetry.metrics import EventLoopMetrics from strands.tools.registry import ToolRegistry -from strands.types.event_loop import ForceStopEvent, ToolResultEvent +from strands.types.events import ForceStopEvent, StopEvent, ToolResultEvent from strands.types.exceptions import ( ContextWindowOverflowException, EventLoopException, @@ -786,11 +786,11 @@ async def test_prepare_next_cycle_in_tool_execution(agent, model, tool_stream, a # Set up mock to return a valid response mock_recurse.return_value = agenerator( [ - ( - "end_turn", - {"role": "assistant", "content": [{"text": "test text"}]}, - strands.telemetry.metrics.EventLoopMetrics(), - {}, + StopEvent( + stop_reason="end_turn", + message={"role": "assistant", "content": [{"text": "test text"}]}, + metrics=strands.telemetry.metrics.EventLoopMetrics(), + request_state={}, ), ] ) @@ -822,7 +822,7 @@ async def test_run_tool(agent, tool, alist): ) tru_result = (await alist(process))[-1] - exp_result = {"toolUseId": "tool_use_id", "status": "success", "content": [{"text": "a_string"}]} + exp_result = ToolResultEvent({"toolUseId": "tool_use_id", "status": "success", "content": [{"text": "a_string"}]}) assert tru_result == exp_result @@ -837,11 +837,13 @@ async def test_run_tool_missing_tool(agent, alist): tru_events = await alist(process) exp_events = [ - ToolResultEvent(tool_result={ - "toolUseId": "missing", - "status": "error", - "content": [{"text": "Unknown tool: missing"}], - }) + ToolResultEvent( + tool_result={ + "toolUseId": "missing", + "status": "error", + "content": [{"text": "Unknown tool: missing"}], + } + ) ] assert tru_events == exp_events @@ -957,7 +959,7 @@ def modify_hook(event: BeforeToolInvocationEvent): result = (await alist(process))[-1] # Should use replacement_tool (5 * 3 = 15) instead of original_tool (1 * 2 = 2) - assert result == {"toolUseId": "modified", "status": "success", "content": [{"text": "15"}]} + assert result == ToolResultEvent({"toolUseId": "modified", "status": "success", "content": [{"text": "15"}]}) assert hook_provider.events_received[1] == AfterToolInvocationEvent( agent=agent, @@ -988,7 +990,7 @@ def modify_hook(event: AfterToolInvocationEvent): ) result = (await alist(process))[-1] - assert result == updated_result + assert result == ToolResultEvent(updated_result) @pytest.mark.asyncio @@ -1051,11 +1053,13 @@ def after_tool_call(self, event: AfterToolInvocationEvent): result = (await alist(process))[-1] - assert result == ToolResultEvent(tool_result={ - "status": "error", - "toolUseId": "test", - "content": [{"text": "This tool has been used too many times!"}], - }) + assert result == ToolResultEvent( + tool_result={ + "status": "error", + "toolUseId": "test", + "content": [{"text": "This tool has been used too many times!"}], + } + ) assert mock_logger.debug.call_args_list == [ call("tool_use=<%s> | streaming", {"toolUseId": "test", "name": "test_quota", "input": {"x": 5}}), diff --git a/tests/strands/event_loop/test_streaming.py b/tests/strands/event_loop/test_streaming.py index 921fd91de..6261c214d 100644 --- a/tests/strands/event_loop/test_streaming.py +++ b/tests/strands/event_loop/test_streaming.py @@ -143,9 +143,12 @@ def test_handle_content_block_start(chunk: ContentBlockStartEvent, exp_tool_use) ], ) def test_handle_content_block_delta(event: ContentBlockDeltaEvent, state, exp_updated_state, callback_args): - exp_callback_event = {"callback": {**callback_args, "delta": event["delta"]}} if callback_args else {} + exp_callback_event = {**callback_args, "delta": event["delta"]} if callback_args else {} tru_updated_state, tru_callback_event = strands.event_loop.streaming.handle_content_block_delta(event, state) + print(callback_args) + print(tru_updated_state) + print(tru_callback_event) assert tru_updated_state == exp_updated_state assert tru_callback_event == exp_callback_event @@ -286,85 +289,71 @@ def test_extract_usage_metrics(): ], [ { - "callback": { - "event": { - "messageStart": { - "role": "assistant", - }, + "event": { + "messageStart": { + "role": "assistant", }, }, }, { - "callback": { - "event": { - "contentBlockStart": { - "start": { - "toolUse": { - "name": "test", - "toolUseId": "123", - }, + "event": { + "contentBlockStart": { + "start": { + "toolUse": { + "name": "test", + "toolUseId": "123", }, }, }, }, }, { - "callback": { - "event": { - "contentBlockDelta": { - "delta": { - "toolUse": { - "input": '{"key": "value"}', - }, + "event": { + "contentBlockDelta": { + "delta": { + "toolUse": { + "input": '{"key": "value"}', }, }, }, }, }, { - "callback": { - "current_tool_use": { - "input": { - "key": "value", - }, - "name": "test", - "toolUseId": "123", + "current_tool_use": { + "input": { + "key": "value", }, - "delta": { - "toolUse": { - "input": '{"key": "value"}', - }, + "name": "test", + "toolUseId": "123", + }, + "delta": { + "toolUse": { + "input": '{"key": "value"}', }, }, }, { - "callback": { - "event": { - "contentBlockStop": {}, - }, + "event": { + "contentBlockStop": {}, }, }, { - "callback": { - "event": { - "messageStop": { - "stopReason": "tool_use", - }, + "event": { + "messageStop": { + "stopReason": "tool_use", }, }, }, { - "callback": { - "event": { - "metadata": { - "metrics": { - "latencyMs": 1, - }, - "usage": { - "inputTokens": 1, - "outputTokens": 1, - "totalTokens": 1, - }, + "event": { + "metadata": { + "metrics": { + "latencyMs": 1, + }, + "usage": { + "inputTokens": 1, + "outputTokens": 1, + "totalTokens": 1, }, }, }, @@ -387,9 +376,7 @@ def test_extract_usage_metrics(): [{}], [ { - "callback": { - "event": {}, - }, + "event": {}, }, { "stop": ( @@ -433,80 +420,64 @@ def test_extract_usage_metrics(): ], [ { - "callback": { - "event": { - "messageStart": { - "role": "assistant", - }, + "event": { + "messageStart": { + "role": "assistant", }, }, }, { - "callback": { - "event": { - "contentBlockStart": { - "start": {}, - }, + "event": { + "contentBlockStart": { + "start": {}, }, }, }, { - "callback": { - "event": { - "contentBlockDelta": { - "delta": { - "text": "Hello!", - }, + "event": { + "contentBlockDelta": { + "delta": { + "text": "Hello!", }, }, }, }, { - "callback": { - "data": "Hello!", - "delta": { - "text": "Hello!", - }, + "data": "Hello!", + "delta": { + "text": "Hello!", }, }, { - "callback": { - "event": { - "contentBlockStop": {}, - }, + "event": { + "contentBlockStop": {}, }, }, { - "callback": { - "event": { - "messageStop": { - "stopReason": "guardrail_intervened", - }, + "event": { + "messageStop": { + "stopReason": "guardrail_intervened", }, }, }, { - "callback": { - "event": { - "redactContent": { - "redactAssistantContentMessage": "REDACTED.", - "redactUserContentMessage": "REDACTED", - }, + "event": { + "redactContent": { + "redactAssistantContentMessage": "REDACTED.", + "redactUserContentMessage": "REDACTED", }, }, }, { - "callback": { - "event": { - "metadata": { - "metrics": { - "latencyMs": 1, - }, - "usage": { - "inputTokens": 1, - "outputTokens": 1, - "totalTokens": 1, - }, + "event": { + "metadata": { + "metrics": { + "latencyMs": 1, + }, + "usage": { + "inputTokens": 1, + "outputTokens": 1, + "totalTokens": 1, }, }, }, @@ -554,29 +525,23 @@ async def test_stream_messages(agenerator, alist): tru_events = await alist(stream) exp_events = [ { - "callback": { - "event": { - "contentBlockDelta": { - "delta": { - "text": "test", - }, + "event": { + "contentBlockDelta": { + "delta": { + "text": "test", }, }, }, }, { - "callback": { - "data": "test", - "delta": { - "text": "test", - }, + "data": "test", + "delta": { + "text": "test", }, }, { - "callback": { - "event": { - "contentBlockStop": {}, - }, + "event": { + "contentBlockStop": {}, }, }, { diff --git a/tests/strands/tools/test_executor.py b/tests/strands/tools/test_executor.py index 04d4ea657..df7c710d2 100644 --- a/tests/strands/tools/test_executor.py +++ b/tests/strands/tools/test_executor.py @@ -6,6 +6,7 @@ import strands import strands.telemetry from strands.types.content import Message +from strands.types.events import ToolResultEvent, ToolStreamEvent @pytest.fixture(autouse=True) @@ -16,11 +17,13 @@ def moto_autouse(moto_env): @pytest.fixture def tool_handler(request): async def handler(tool_use): - yield {"event": "abc"} - yield { - **params, - "toolUseId": tool_use["toolUseId"], - } + yield ToolStreamEvent(tool_use, {"event": "abc"}) + yield ToolResultEvent( + { + **params, + "toolUseId": tool_use["toolUseId"], + } + ) params = { "content": [{"text": "test result"}], @@ -86,22 +89,25 @@ async def test_run_tools( tru_events = await alist(stream) exp_events = [ - {"event": "abc"}, - { - "content": [ - { - "text": "test result", - }, - ], - "status": "success", - "toolUseId": "t1", - }, + ToolStreamEvent(tool_uses[0], {"event": "abc"}), + ToolResultEvent( + { + "content": [ + { + "text": "test result", + }, + ], + "status": "success", + "toolUseId": "t1", + } + ), ] tru_results = tool_results - exp_results = [exp_events[-1]] + exp_results = [exp_events[-1].tool_result] - assert tru_events == exp_events and tru_results == exp_results + assert tru_events == exp_events + assert tru_results == exp_results @pytest.mark.parametrize("invalid_tool_use_ids", [["t1"]], indirect=True) @@ -319,9 +325,16 @@ async def test_run_tools_creates_and_ends_span_on_success( # Verify span was ended with the tool result mock_tracer.end_tool_call_span.assert_called_once() args, _ = mock_tracer.end_tool_call_span.call_args - assert args[0] == mock_span - assert args[1]["status"] == "success" - assert args[1]["content"][0]["text"] == "test result" + assert args == ( + mock_span, + { + "status": "success", + "content": [ + {"text": "test result"}, + ], + "toolUseId": "t1", + }, + ) @unittest.mock.patch("strands.tools.executor.get_tracer") From 175254da0d6bbdb93adf87a8fe13a884fbd39c0b Mon Sep 17 00:00:00 2001 From: Mackenzie Zastrow Date: Wed, 20 Aug 2025 18:37:31 -0400 Subject: [PATCH 5/7] Revamp events + Model events --- src/strands/agent/agent.py | 28 +- src/strands/event_loop/event_loop.py | 12 +- src/strands/event_loop/streaming.py | 35 +- src/strands/types/events.py | 426 +++++++++++++++----- tests/strands/agent/test_agent.py | 29 +- tests/strands/event_loop/test_event_loop.py | 79 ++-- 6 files changed, 420 insertions(+), 189 deletions(-) diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index 363cf4025..9a2635259 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -37,7 +37,7 @@ from ..tools.registry import ToolRegistry from ..tools.watcher import ToolWatcher from ..types.content import ContentBlock, Message, Messages -from ..types.events import InitEventLoopEvent, ResultEvent, StopEvent, StreamDeltaEvent, ToolResultEvent, TypedEvent +from ..types.events import InitEventLoopEvent, ModelStreamEvent, StopEvent, ToolResultEvent, TypedEvent from ..types.exceptions import ContextWindowOverflowException from ..types.tools import ToolResult, ToolUse from ..types.traces import AttributeValue @@ -480,8 +480,10 @@ async def structured_output_async( events = self.model.structured_output(output_model, temp_messages, system_prompt=self.system_prompt) async for event in events: if "stop" not in event: - typed_event = StreamDeltaEvent(delta_data=event) - typed_event.prepare_and_invoke(invocation_state={}, callback_handler=self.callback_handler) + typed_event = ModelStreamEvent(delta_data=event) + typed_event._prepare_and_invoke( + agent=self, invocation_state={}, callback_handler=self.callback_handler + ) structured_output_span.add_event( "gen_ai.choice", attributes={"message": serialize(event["output"].model_dump())} @@ -533,21 +535,15 @@ async def stream_async(self, prompt: Union[str, list[ContentBlock]], **kwargs: A invocation_state = kwargs events = self._run_loop(message, invocation_state=invocation_state) async for event in events: - if not isinstance(event, StopEvent): - event.prepare_and_invoke(invocation_state, callback_handler) - yield event + event._prepare_and_invoke( + agent=self, invocation_state=invocation_state, callback_handler=callback_handler + ) + yield event if not isinstance(event, StopEvent): raise ValueError(f"Last event was not a StopEvent; was: {event}") - event = cast(StopEvent, event) - - result = event.result - result_event = ResultEvent(result=result) - result_event.prepare_and_invoke(invocation_state, callback_handler) - yield result_event - - self._end_agent_trace_span(response=result) + self._end_agent_trace_span(response=event.result) except Exception as e: self._end_agent_trace_span(error=e) @@ -576,7 +572,7 @@ async def _run_loop(self, message: Message, invocation_state: dict[str, Any]) -> # Signal from the model provider that the message sent by the user should be redacted, # likely due to a guardrail. if ( - isinstance(event, StreamDeltaEvent) + isinstance(event, ModelStreamEvent) and event.delta_data.get("event") and event.delta_data["event"].get("redactContent") and event.delta_data["event"]["redactContent"].get("redactUserContentMessage") @@ -592,7 +588,7 @@ async def _run_loop(self, message: Message, invocation_state: dict[str, Any]) -> self.conversation_manager.apply_management(self) self.hooks.invoke_callbacks(AfterInvocationEvent(agent=self)) - async def _execute_event_loop_cycle(self, invocation_state: dict[str, Any]) -> AsyncGenerator[dict[str, Any], None]: + async def _execute_event_loop_cycle(self, invocation_state: dict[str, Any]) -> AsyncGenerator[TypedEvent, None]: """Execute the event loop cycle with retry logic for context window limits. This internal method handles the execution of the event loop cycle and implements diff --git a/src/strands/event_loop/event_loop.py b/src/strands/event_loop/event_loop.py index d51c70b8a..df65f44e8 100644 --- a/src/strands/event_loop/event_loop.py +++ b/src/strands/event_loop/event_loop.py @@ -32,10 +32,10 @@ EventLoopThrottleDelay, ForceStopEvent, MessageEvent, + ModelStreamEvent, StartEvent, StartEventLoopEvent, StopEvent, - StreamDeltaEvent, ToolResultEvent, ToolResultMessageEvent, ToolStreamEvent, @@ -146,13 +146,9 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) -> ) try: - # TODO: To maintain backwards compatibility, we need to combine the stream event with invocation_state - # before yielding to the callback handler. This will be revisited when migrating to strongly - # typed events. async for event in stream_messages(agent.model, agent.system_prompt, agent.messages, tool_specs): - # NEED TO DISCUSS - what events *shouldn't* be yield back? if "stop" not in event: - yield StreamDeltaEvent(delta_data=event) + yield ModelStreamEvent(delta_data=event) stop_reason, message, usage, metrics = event["stop"] invocation_state.setdefault("request_state", {}) @@ -248,13 +244,11 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) -> invocation_state=invocation_state, ) async for event in events: - print("evloop", event) # Note: Once we allow streaming of tool results, we need to do a check here for it if isinstance(event, TypedEvent): yield event else: - logger.warning(f"event=<{event} | non-typed event") - raise ValueError("Event was not a subclass of TypedEvent") + raise ValueError(f"Event was not a subclass of TypedEvent: {event}") return # End the cycle and return results diff --git a/src/strands/event_loop/streaming.py b/src/strands/event_loop/streaming.py index 749537627..13a3b1ff6 100644 --- a/src/strands/event_loop/streaming.py +++ b/src/strands/event_loop/streaming.py @@ -2,7 +2,7 @@ import json import logging -from typing import Any, AsyncGenerator, AsyncIterable, Optional +from typing import TYPE_CHECKING, Any, AsyncGenerator, AsyncIterable, Optional from ..models.model import Model from ..types.content import ContentBlock, Message, Messages @@ -21,6 +21,9 @@ ) from ..types.tools import ToolSpec, ToolUse +if TYPE_CHECKING: + from ..types.events import ModelStreamEvent + logger = logging.getLogger(__name__) @@ -103,7 +106,7 @@ def handle_content_block_start(event: ContentBlockStartEvent) -> dict[str, Any]: def handle_content_block_delta( event: ContentBlockDeltaEvent, state: dict[str, Any] -) -> tuple[dict[str, Any], dict[str, Any]]: +) -> tuple[dict[str, Any], "ModelStreamEvent"]: """Handles content block delta updates by appending text, tool input, or reasoning content to the state. Args: @@ -113,20 +116,28 @@ def handle_content_block_delta( Returns: Updated state with appended text or tool input. """ + from ..types.events import ( + ModelStreamEvent, + ReasoningSignatureStreamEvent, + ReasoningTextStreamEvent, + TextStreamEvent, + ToolUseStreamEvent, + ) + delta_content = event["delta"] - callback_event = {} + callback_event: ModelStreamEvent if "toolUse" in delta_content: if "input" not in state["current_tool_use"]: state["current_tool_use"]["input"] = "" state["current_tool_use"]["input"] += delta_content["toolUse"]["input"] - callback_event = {"delta": delta_content, "current_tool_use": state["current_tool_use"]} + callback_event = ToolUseStreamEvent(event["delta"], state["current_tool_use"]) elif "text" in delta_content: state["text"] += delta_content["text"] - callback_event = {"data": delta_content["text"], "delta": delta_content} + callback_event = TextStreamEvent(event["delta"], delta_content["text"]) elif "reasoningContent" in delta_content: if "text" in delta_content["reasoningContent"]: @@ -134,22 +145,16 @@ def handle_content_block_delta( state["reasoningText"] = "" state["reasoningText"] += delta_content["reasoningContent"]["text"] - callback_event = { - "reasoningText": delta_content["reasoningContent"]["text"], - "delta": delta_content, - "reasoning": True, - } + callback_event = ReasoningTextStreamEvent(event["delta"], delta_content["reasoningContent"]["text"]) elif "signature" in delta_content["reasoningContent"]: if "signature" not in state: state["signature"] = "" state["signature"] += delta_content["reasoningContent"]["signature"] - callback_event = { - "reasoning_signature": delta_content["reasoningContent"]["signature"], - "delta": delta_content, - "reasoning": True, - } + callback_event = ReasoningSignatureStreamEvent( + event["delta"], delta_content["reasoningContent"]["signature"] + ) return state, callback_event diff --git a/src/strands/types/events.py b/src/strands/types/events.py index 78a8229ea..2afebb983 100644 --- a/src/strands/types/events.py +++ b/src/strands/types/events.py @@ -1,30 +1,80 @@ +"""Event system for the Strands Agents framework. + +This module defines the event types that are emitted during agent execution, +providing a structured way to observe and respond to different stages of the +agent lifecycle including initialization, model invocation, tool execution, +and completion. +""" + from dataclasses import MISSING -from typing import Any, AsyncGenerator, Callable, Literal, override +from typing import TYPE_CHECKING, Any, AsyncGenerator, Callable, Literal + +from typing_extensions import override from ..telemetry import EventLoopMetrics from .content import Message from .event_loop import StopReason +from .streaming import ContentBlockDelta from .tools import ToolResult, ToolUse +if TYPE_CHECKING: + from ..agent import Agent, AgentResult + class TypedEvent(dict): - # Do I have to be careful about tool stream results? + """Base class for all typed events in the agent system. + + TypedEvent provides a framework for creating strongly-typed events that can be + processed by callback handlers. + """ + invocation_state: dict[str, Any] - def __init__(self, data: dict[str, Any] = None): + def __init__(self, data: dict[str, Any] = None) -> None: + """Initialize the typed event with optional data. + + Args: + data: Optional dictionary of event data to initialize with + """ super().__init__(data or {}) def _get_callback_fields(self) -> list[str] | Literal["all"] | None: + """When invoking a callback with this event, which fields should be exposed. + + This is for **backwards compatability only** and should not be implemented for new events. + + Can be an array of properties to pass to the callback handler, can be the literal "all" to + include all properties in the dict, or None to indicate that the callback should not be invoked. + """ return None - def set_invocation_state(self, invocation_state: dict): + def _set_invocation_state(self, invocation_state: dict) -> None: + """Sets the invocation state for this event instance. + + This is for **backwards compatability only** and should not be implemented for new events. + + Args: + invocation_state: Dictionary containing context and state information + for the current agent invocation + """ self.invocation_state = invocation_state - def prepare_and_invoke(self, invocation_state: dict, callback_handler: Callable): - self.set_invocation_state(invocation_state) + def _prepare_and_invoke(self, *, agent: "Agent", invocation_state: dict, callback_handler: Callable) -> bool: + """Internal API: Prepares the event and invokes the callback handler if applicable. + + This method sets the invocation state on the event and invokes the callback_handler with + the event if it is configured to be invoked - # invoking - args = TypedEvent.get_callback_fields(self) + Args: + agent: The agent instance (not currently used but will be in future versions) + invocation_state: Context and state information for the current invocation + callback_handler: The callback function to invoke with event data + + Returns: + bool: True if the callback was invoked, False otherwise + """ + self._set_invocation_state(invocation_state) + args = TypedEvent._get_callback_arguments(self) if args: callback_handler(**args) @@ -33,7 +83,7 @@ def prepare_and_invoke(self, invocation_state: dict, callback_handler: Callable) return False @staticmethod - def get_callback_fields(event: "TypedEvent") -> dict[str, Any] | None: + def _get_callback_arguments(event: "TypedEvent") -> dict[str, Any] | None: allow_listed = event._get_callback_fields() if allow_listed is None: @@ -45,176 +95,284 @@ def get_callback_fields(event: "TypedEvent") -> dict[str, Any] | None: return {k: v for k, v in event.items() if k in allow_listed} -class StartEventLoopEvent(TypedEvent): - def __init__(self): - super().__init__({"start_event_loop": True}) - - def _get_callback_fields(self): - return ["start_event_loop"] +class InitEventLoopEvent(TypedEvent): + """Event emitted at the very beginning of agent execution. + This event is fired before any processing begins and provides access to the + initial invocation state. + """ -class InitEventLoopEvent(TypedEvent): - def __init__(self): + def __init__(self) -> None: + """Initialize the event loop initialization event.""" super().__init__({"init_event_loop": True}) @override - def set_invocation_state(self, invocation_state: dict): - super().set_invocation_state(invocation_state) + def _set_invocation_state(self, invocation_state: dict) -> None: + super()._set_invocation_state(invocation_state) # For backwards compatability, make sure that we're merging the # invocation state as a readonly copy into ourselves self.update(**invocation_state) - def _get_callback_fields(self): + def _get_callback_fields(self) -> list[str]: return ["init_event_loop"] -class ForceStopEvent(TypedEvent): - @property - def reason(self) -> str: - return self.get("force_stop_reason") +class StartEvent(TypedEvent): + """Event emitted at the start of each event loop cycle. - @property - def reason_exception(self) -> Exception | None: - return self.get("force_stop_reason_exception") + ::deprecated:: + Use StartEventLoopEvent instead. - def __init__(self, reason: str | Exception): - super().__init__( - {"force_stop": True, "force_stop_reason": str(reason), "force_stop_reason_exception": reason or MISSING} - ) + This event signals the beginning of a new processing cycle within the agent's + event loop. It's fired before model invocation and tool execution begin. + """ - def _get_callback_fields(self): - return ["force_stop", "force_stop_reason"] + def __init__(self) -> None: + """Initialize the event loop start event.""" + super().__init__({"start": True}) + + def _get_callback_fields(self) -> list[str]: + return ["start"] + + +class StartEventLoopEvent(TypedEvent): + """Event emitted when the event loop cycle begins processing. + + This event is fired after StartEvent and indicates that the event loop + has begun its core processing logic, including model invocation preparation. + """ + + def __init__(self) -> None: + """Initialize the event loop processing start event.""" + super().__init__({"start_event_loop": True}) + + def _get_callback_fields(self) -> list[str]: + return ["start_event_loop"] class MessageEvent(TypedEvent): - def __init__(self, message: Message): + """Event emitted when a the model invocation has completed. + + This event is fired whenever the model generates a response message that + gets added to the conversation history. + """ + + def __init__(self, message: Message) -> None: + """Initialize with the model-generated message. + + Args: + message: The response message from the model + """ super().__init__({"message": message}) @property def message(self) -> Message: + """The model-generated response message.""" return self.get("message") - def _get_callback_fields(self): + def _get_callback_fields(self) -> list[str]: return ["message"] class EventLoopThrottleDelay(TypedEvent): - def __init__(self, delay: int): + """Event emitted when the event loop is throttled due to rate limiting.""" + + def __init__(self, delay: int) -> None: + """Initialize with the throttle delay duration. + + Args: + delay: Delay in seconds before the next retry attempt + """ super().__init__({"event_loop_throttled_delay": delay}) @property - def delay(self) -> Message: + def delay(self) -> int: + """Delay in seconds before the next retry attempt.""" return self.get("event_loop_throttled_delay") - def _get_callback_fields(self): + def _get_callback_fields(self) -> list[str]: return ["event_loop_throttled_delay"] +class ForceStopEvent(TypedEvent): + """Event emitted when the agent execution is forcibly stopped, either by a tool or by an exception. + + This event is fired when an unrecoverable error occurs during execution, + such as repeated throttling failures or critical system errors. It provides + the reason for the forced stop and any associated exception. + """ + + @property + def reason(self) -> str: + """Human-readable description of why execution was stopped.""" + return self.get("force_stop_reason") + + @property + def reason_exception(self) -> Exception | None: + """The original exception that caused the stop, if applicable.""" + return self.get("force_stop_reason_exception") + + def __init__(self, reason: str | Exception) -> None: + """Initialize with the reason for forced stop. + + Args: + reason: String description or exception that caused the forced stop + """ + super().__init__( + { + "force_stop": True, + "force_stop_reason": str(reason), + "force_stop_reason_exception": reason if reason and isinstance(reason, Exception) else MISSING, + } + ) + + def _get_callback_fields(self) -> list[str]: + return ["force_stop", "force_stop_reason"] + + class StopEvent(TypedEvent): + """Event emitted when the agent execution completes normally. + + This event is fired when the event loop cycle completes successfully, + providing the final result including the stop reason, final message, + execution metrics, and final state. It represents the successful + completion of an agent invocation. + """ + def __init__( self, stop_reason: StopReason, message: Message, metrics: "EventLoopMetrics", request_state: Any, - ): + ) -> None: + """Initialize with the final execution results. + + Args: + stop_reason: Why the agent execution stopped + message: Final message from the model + metrics: Execution metrics and performance data + request_state: Final state of the agent execution + """ + from strands.agent import AgentResult + super().__init__( { - "stop": (stop_reason, message, metrics, request_state), + "result": AgentResult( + stop_reason=stop_reason, + message=message, + metrics=metrics, + state=request_state, + ), } ) - @property - def stop_reason(self) -> StopReason: - return self.get("stop")[0] - - @property - def message(self) -> Message: - return self.get("stop")[1] - - @property - def metrics(self) -> "EventLoopMetrics": - return self.get("stop")[2] - - @property - def request_state(self) -> Any: - return self.get("stop")[3] - @property def result(self) -> "AgentResult": - from strands.agent import AgentResult - - return AgentResult( - stop_reason=self.stop_reason, - message=self.message, - metrics=self.metrics, - state=self.request_state, - ) + """Complete execution result with metrics and final state.""" + return self.get("result") - def _get_callback_fields(self): - return ["stop"] + @override + def _get_callback_fields(self) -> list[str]: + return ["result"] -class StartEvent(TypedEvent): - def __init__(self): - super().__init__({"start": True}) +class ToolStreamEvent(TypedEvent): + """Event emitted during tool execution streaming. - def _get_callback_fields(self): - return ["start"] + This event is fired when a tool produces streaming output during execution. + It provides access to the tool use information and the streaming data, + allowing callbacks to process tool execution progress in real-time. + """ + def __init__(self, tool_use: ToolUse, stream_data: Any) -> None: + """Initialize with tool streaming data. -class ToolStreamEvent(TypedEvent): - def __init__(self, tool_use: ToolUse, stream_data: any): + Args: + tool_use: The tool invocation producing the stream + stream_data: Incremental data from the tool execution + """ super().__init__({"tool_stream_tool_use": tool_use, "tool_stream_data": stream_data}) @property def tool_use(self) -> ToolUse: + """The tool invocation that is producing streaming output.""" return self.get("tool_stream_tool_use") @property def tool_stream_data(self) -> Any: + """Incremental data chunk from the streaming tool execution.""" return self.get("tool_stream_data") class ToolResultEvent(TypedEvent): - def __init__(self, tool_result: ToolResult): + """Event emitted when a tool execution completes.""" + + def __init__(self, tool_result: ToolResult) -> None: + """Initialize with the completed tool result. + + Args: + tool_result: Final result from the tool execution + """ super().__init__({"tool_result": tool_result}) @property def tool_result(self) -> ToolResult: + """Final result from the completed tool execution.""" return self.get("tool_result") class ToolResultMessageEvent(TypedEvent): - def __init__(self, message: Any): + """Event emitted when tool results are formatted as a message. + + This event is fired when tool execution results are converted into a + message format to be added to the conversation history. It provides + access to the formatted message containing tool results. + """ + + def __init__(self, message: Any) -> None: + """Initialize with the formatted tool result message. + + Args: + message: Message containing tool results for conversation history + """ super().__init__({"message": message}) @property def message(self) -> Any: + """Message containing formatted tool results for conversation history.""" return self.get("message") @override - def _get_callback_fields(self): + def _get_callback_fields(self) -> list[str]: return ["message"] -class StreamDeltaEvent(TypedEvent): - def __init__(self, delta_data: dict[str, Any]): +class ModelStreamEvent(TypedEvent): + """Event emitted during model response streaming. + + This event is fired when the model produces streaming output during response + generation. It provides access to incremental delta data from the model, + allowing callbacks to process and display streaming responses in real-time. + """ + + def __init__(self, delta_data: dict[str, Any]) -> None: + """Initialize with streaming delta data from the model. + + Args: + delta_data: Incremental streaming data from the model response + """ super().__init__(delta_data) @property def delta_data(self) -> Any: + """Streaming data from the model response.""" return self - def include_state(self): - return True - - def as_callback(self) -> dict[str, Any]: - return self.delta_data - - def set_invocation_state(self, invocation_state: dict): - super().set_invocation_state(invocation_state) + @override + def _set_invocation_state(self, invocation_state: dict) -> None: + super()._set_invocation_state(invocation_state) # For backwards compatability, make sure that we're merging the # invocation state as a readonly copy into ourselves @@ -222,22 +380,100 @@ def set_invocation_state(self, invocation_state: dict): self.update(**invocation_state) @override - def _get_callback_fields(self): + def _get_callback_fields(self) -> Any: return Any -class ResultEvent(TypedEvent): - def __init__(self, result: "AgentResult"): - super().__init__({"result": result}) +class ToolUseStreamEvent(ModelStreamEvent): + """Event emitted during tool use input streaming.""" + + def __init__(self, delta: ContentBlockDelta, current_tool_use: dict[str, Any]) -> None: + """Initialize with delta and current tool use state.""" + super().__init__({"delta": delta, "current_tool_use": current_tool_use}) @property - def result(self) -> "AgentResult": - return self.get("result") + def delta(self) -> ContentBlockDelta: + """Delta data from the model response.""" + return self.get("delta") - @override - def _get_callback_fields(self): - return ["result"] + @property + def current_tool_use(self) -> dict[str, Any]: + """Current tool use state.""" + return self.get("current_tool_use") + + +class TextStreamEvent(ModelStreamEvent): + """Event emitted during text content streaming.""" + + def __init__(self, delta: ContentBlockDelta, text: str) -> None: + """Initialize with delta and text content.""" + super().__init__({"data": text, "delta": delta}) + + @property + def text(self) -> str: + """Cumulative text content assembled from streaming deltas received thus far.""" + return self.get("data") + @property + def delta(self) -> ContentBlockDelta: + """Delta data from the model response.""" + return self.get("delta") + + +class ReasoningTextStreamEvent(ModelStreamEvent): + """Event emitted during reasoning text streaming.""" + + def __init__(self, delta: ContentBlockDelta, reasoning_text: str) -> None: + """Initialize with delta and reasoning text.""" + super().__init__({"reasoningText": reasoning_text, "delta": delta, "reasoning": True}) + + @property + def reasoningText(self) -> str: + """Cumulative reasoning text content assembled from streaming deltas received thus far.""" + return self.get("reasoningText") + + @property + def delta(self) -> ContentBlockDelta: + """Delta data from the model response.""" + return self.get("delta") + + +class ReasoningSignatureStreamEvent(ModelStreamEvent): + """Event emitted during reasoning signature streaming.""" + + def __init__(self, delta: ContentBlockDelta, reasoning_signature: str) -> None: + """Initialize with delta and reasoning signature.""" + super().__init__({"reasoning_signature": reasoning_signature, "delta": delta, "reasoning": True}) + + @property + def reasoning_signature(self) -> str: + """Cumulative reasoning signature content assembled from streaming deltas received thus far.""" + return self.get("reasoning_signature") + + @property + def delta(self) -> ContentBlockDelta: + """Delta data from the model response.""" + return self.get("delta") + + +AllTypedEvents = ( + InitEventLoopEvent + | StartEvent + | StartEventLoopEvent + | MessageEvent + | EventLoopThrottleDelay + | ForceStopEvent + | StopEvent + | ToolStreamEvent + | ToolResultEvent + | ToolResultMessageEvent + | ModelStreamEvent + | ToolUseStreamEvent + | TextStreamEvent + | ReasoningTextStreamEvent + | ReasoningSignatureStreamEvent + | TypedEvent +) TypedToolGenerator = AsyncGenerator[TypedEvent, None] """Generator of tool events where all events are typed as TypedEvents.""" diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index 125bbe718..fc310cb4e 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -22,9 +22,8 @@ from strands.types.content import Messages from strands.types.events import ( InitEventLoopEvent, - ResultEvent, + ModelStreamEvent, StopEvent, - StreamDeltaEvent, ToolResultEvent, ) from strands.types.exceptions import ContextWindowOverflowException, EventLoopException @@ -1166,9 +1165,9 @@ async def test_stream_async_returns_all_events(mock_event_loop_cycle, alist): # Define the side effect to simulate callback handler being called multiple times async def test_event_loop(*args, **kwargs): - yield StreamDeltaEvent(delta_data={"data": "First chunk"}) - yield StreamDeltaEvent(delta_data={"data": "Second chunk"}) - yield StreamDeltaEvent(delta_data={"data": "Final chunk", "complete": True}) + yield ModelStreamEvent(delta_data={"data": "First chunk"}) + yield ModelStreamEvent(delta_data={"data": "Second chunk"}) + yield ModelStreamEvent(delta_data={"data": "Final chunk", "complete": True}) # Return expected values from event_loop_cycle yield StopEvent( @@ -1190,16 +1189,14 @@ async def test_event_loop(*args, **kwargs): tru_events = await alist(stream) exp_events = [ init_event, - StreamDeltaEvent(delta_data={"data": "First chunk"}), - StreamDeltaEvent(delta_data={"data": "Second chunk"}), - StreamDeltaEvent(delta_data={"data": "Final chunk", "complete": True}), - ResultEvent( - result=AgentResult( - stop_reason="stop", - message={"role": "assistant", "content": [{"text": "Response"}]}, - metrics={}, - state={}, - ) + ModelStreamEvent(delta_data={"data": "First chunk"}), + ModelStreamEvent(delta_data={"data": "Second chunk"}), + ModelStreamEvent(delta_data={"data": "Final chunk", "complete": True}), + StopEvent( + stop_reason="stop", + message={"role": "assistant", "content": [{"text": "Response"}]}, + metrics={}, + request_state={}, ), ] assert tru_events == exp_events @@ -1278,7 +1275,7 @@ def fake_tool(agent: Agent): ] assert tru_events == exp_events - exp_calls = [unittest.mock.call(**event) for event in exp_events] + # exp_calls = [unittest.mock.call(**event) for event in exp_events] # mock_callback.assert_has_calls(exp_calls) diff --git a/tests/strands/event_loop/test_event_loop.py b/tests/strands/event_loop/test_event_loop.py index 20f9895d9..cbf621cdb 100644 --- a/tests/strands/event_loop/test_event_loop.py +++ b/tests/strands/event_loop/test_event_loop.py @@ -6,6 +6,7 @@ import strands import strands.telemetry +from strands.agent import AgentResult from strands.event_loop.event_loop import run_tool from strands.experimental.hooks import ( AfterModelInvocationEvent, @@ -173,13 +174,12 @@ async def test_event_loop_cycle_text_response( invocation_state={}, ) events = await alist(stream) - tru_stop_reason, tru_message, _, tru_request_state = events[-1]["stop"] - - exp_stop_reason = "end_turn" - exp_message = {"role": "assistant", "content": [{"text": "test text"}]} - exp_request_state = {} + tru_result = events[-1]["result"] + exp_result = AgentResult( + stop_reason="end_turn", message={"role": "assistant", "content": [{"text": "test text"}]}, metrics=ANY, state={} + ) - assert tru_stop_reason == exp_stop_reason and tru_message == exp_message and tru_request_state == exp_request_state + assert tru_result == exp_result @pytest.mark.asyncio @@ -205,13 +205,13 @@ async def test_event_loop_cycle_text_response_throttling( invocation_state={}, ) events = await alist(stream) - tru_stop_reason, tru_message, _, tru_request_state = events[-1]["stop"] - exp_stop_reason = "end_turn" - exp_message = {"role": "assistant", "content": [{"text": "test text"}]} - exp_request_state = {} + tru_result = events[-1]["result"] + exp_result = AgentResult( + stop_reason="end_turn", message={"role": "assistant", "content": [{"text": "test text"}]}, metrics=ANY, state={} + ) - assert tru_stop_reason == exp_stop_reason and tru_message == exp_message and tru_request_state == exp_request_state + assert tru_result == exp_result # Verify that sleep was called once with the initial delay mock_time.sleep.assert_called_once() @@ -243,12 +243,12 @@ async def test_event_loop_cycle_exponential_backoff( invocation_state={}, ) events = await alist(stream) - tru_stop_reason, tru_message, _, tru_request_state = events[-1]["stop"] + tru_result = events[-1]["result"] + exp_result = AgentResult( + stop_reason="end_turn", message={"role": "assistant", "content": [{"text": "test text"}]}, metrics=ANY, state={} + ) - # Verify the final response - assert tru_stop_reason == "end_turn" - assert tru_message == {"role": "assistant", "content": [{"text": "test text"}]} - assert tru_request_state == {} + assert tru_result == exp_result # Verify that sleep was called with increasing delays # Initial delay is 4, then 8, then 16 @@ -334,13 +334,13 @@ async def test_event_loop_cycle_tool_result( invocation_state={}, ) events = await alist(stream) - tru_stop_reason, tru_message, _, tru_request_state = events[-1]["stop"] - exp_stop_reason = "end_turn" - exp_message = {"role": "assistant", "content": [{"text": "test text"}]} - exp_request_state = {} + tru_result = events[-1]["result"] + exp_result = AgentResult( + stop_reason="end_turn", message={"role": "assistant", "content": [{"text": "test text"}]}, metrics=ANY, state={} + ) - assert tru_stop_reason == exp_stop_reason and tru_message == exp_message and tru_request_state == exp_request_state + assert tru_result == exp_result # Verify that recover_message_on_max_tokens_reached was NOT called for tool_use stop reason mock_recover_message.assert_not_called() @@ -449,24 +449,27 @@ async def test_event_loop_cycle_stop( invocation_state={"request_state": {"stop_event_loop": True}}, ) events = await alist(stream) - tru_stop_reason, tru_message, _, tru_request_state = events[-1]["stop"] - exp_stop_reason = "tool_use" - exp_message = { - "role": "assistant", - "content": [ - { - "toolUse": { - "input": {}, - "name": "tool_for_testing", - "toolUseId": "t1", + tru_result = events[-1]["result"] + exp_result = AgentResult( + stop_reason="tool_use", + message={ + "role": "assistant", + "content": [ + { + "toolUse": { + "input": {}, + "name": "tool_for_testing", + "toolUseId": "t1", + } } - } - ], - } - exp_request_state = {"stop_event_loop": True} + ], + }, + metrics=ANY, + state={"stop_event_loop": True}, + ) - assert tru_stop_reason == exp_stop_reason and tru_message == exp_message and tru_request_state == exp_request_state + assert tru_result == exp_result @pytest.mark.asyncio @@ -751,7 +754,7 @@ async def test_request_state_initialization(alist): invocation_state={}, ) events = await alist(stream) - _, _, _, tru_request_state = events[-1]["stop"] + tru_request_state = events[-1]["result"].state # Verify request_state was initialized to empty dict assert tru_request_state == {} @@ -763,7 +766,7 @@ async def test_request_state_initialization(alist): invocation_state={"request_state": initial_request_state}, ) events = await alist(stream) - _, _, _, tru_request_state = events[-1]["stop"] + tru_request_state = events[-1]["result"].state # Verify existing request_state was preserved assert tru_request_state == initial_request_state From f9f3603b32c088b9e480d87a78fc278c6a6692cb Mon Sep 17 00:00:00 2001 From: Mackenzie Zastrow Date: Wed, 20 Aug 2025 18:45:33 -0400 Subject: [PATCH 6/7] Fix typing errors --- src/strands/agent/agent.py | 2 +- src/strands/event_loop/streaming.py | 2 +- src/strands/types/events.py | 38 ++++++++++++++--------------- 3 files changed, 21 insertions(+), 21 deletions(-) diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index 9a2635259..dae9e03e4 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -158,7 +158,7 @@ def tcall() -> ToolResultEvent: if should_record_direct_tool_call: # Create a record of this tool execution in the message history - self._agent._record_tool_execution(tool_use, last_event, user_message_override) + self._agent._record_tool_execution(tool_use, last_event.tool_result, user_message_override) # Apply window management self._agent.conversation_manager.apply_management(self._agent) diff --git a/src/strands/event_loop/streaming.py b/src/strands/event_loop/streaming.py index 13a3b1ff6..65d7cc738 100644 --- a/src/strands/event_loop/streaming.py +++ b/src/strands/event_loop/streaming.py @@ -126,7 +126,7 @@ def handle_content_block_delta( delta_content = event["delta"] - callback_event: ModelStreamEvent + callback_event = ModelStreamEvent({}) if "toolUse" in delta_content: if "input" not in state["current_tool_use"]: diff --git a/src/strands/types/events.py b/src/strands/types/events.py index 2afebb983..5d547d132 100644 --- a/src/strands/types/events.py +++ b/src/strands/types/events.py @@ -30,7 +30,7 @@ class TypedEvent(dict): invocation_state: dict[str, Any] - def __init__(self, data: dict[str, Any] = None) -> None: + def __init__(self, data: dict[str, Any] | None = None) -> None: """Initialize the typed event with optional data. Args: @@ -152,7 +152,7 @@ def _get_callback_fields(self) -> list[str]: class MessageEvent(TypedEvent): - """Event emitted when a the model invocation has completed. + """Event emitted when the model invocation has completed. This event is fired whenever the model generates a response message that gets added to the conversation history. @@ -169,7 +169,7 @@ def __init__(self, message: Message) -> None: @property def message(self) -> Message: """The model-generated response message.""" - return self.get("message") + return self.get("message") # type: ignore[return-value] def _get_callback_fields(self) -> list[str]: return ["message"] @@ -189,7 +189,7 @@ def __init__(self, delay: int) -> None: @property def delay(self) -> int: """Delay in seconds before the next retry attempt.""" - return self.get("event_loop_throttled_delay") + return self.get("event_loop_throttled_delay") # type: ignore[return-value] def _get_callback_fields(self) -> list[str]: return ["event_loop_throttled_delay"] @@ -206,7 +206,7 @@ class ForceStopEvent(TypedEvent): @property def reason(self) -> str: """Human-readable description of why execution was stopped.""" - return self.get("force_stop_reason") + return self.get("force_stop_reason") # type: ignore[return-value] @property def reason_exception(self) -> Exception | None: @@ -255,7 +255,7 @@ def __init__( metrics: Execution metrics and performance data request_state: Final state of the agent execution """ - from strands.agent import AgentResult + from ..agent import AgentResult super().__init__( { @@ -271,7 +271,7 @@ def __init__( @property def result(self) -> "AgentResult": """Complete execution result with metrics and final state.""" - return self.get("result") + return self.get("result") # type: ignore[return-value] @override def _get_callback_fields(self) -> list[str]: @@ -298,7 +298,7 @@ def __init__(self, tool_use: ToolUse, stream_data: Any) -> None: @property def tool_use(self) -> ToolUse: """The tool invocation that is producing streaming output.""" - return self.get("tool_stream_tool_use") + return self.get("tool_stream_tool_use") # type: ignore[return-value] @property def tool_stream_data(self) -> Any: @@ -320,7 +320,7 @@ def __init__(self, tool_result: ToolResult) -> None: @property def tool_result(self) -> ToolResult: """Final result from the completed tool execution.""" - return self.get("tool_result") + return self.get("tool_result") # type: ignore[return-value] class ToolResultMessageEvent(TypedEvent): @@ -394,12 +394,12 @@ def __init__(self, delta: ContentBlockDelta, current_tool_use: dict[str, Any]) - @property def delta(self) -> ContentBlockDelta: """Delta data from the model response.""" - return self.get("delta") + return self.get("delta") # type: ignore[return-value] @property def current_tool_use(self) -> dict[str, Any]: """Current tool use state.""" - return self.get("current_tool_use") + return self.get("current_tool_use") # type: ignore[return-value] class TextStreamEvent(ModelStreamEvent): @@ -412,48 +412,48 @@ def __init__(self, delta: ContentBlockDelta, text: str) -> None: @property def text(self) -> str: """Cumulative text content assembled from streaming deltas received thus far.""" - return self.get("data") + return self.get("data") # type: ignore[return-value] @property def delta(self) -> ContentBlockDelta: """Delta data from the model response.""" - return self.get("delta") + return self.get("delta") # type: ignore[return-value] class ReasoningTextStreamEvent(ModelStreamEvent): """Event emitted during reasoning text streaming.""" - def __init__(self, delta: ContentBlockDelta, reasoning_text: str) -> None: + def __init__(self, delta: ContentBlockDelta, reasoning_text: str | None) -> None: """Initialize with delta and reasoning text.""" super().__init__({"reasoningText": reasoning_text, "delta": delta, "reasoning": True}) @property def reasoningText(self) -> str: """Cumulative reasoning text content assembled from streaming deltas received thus far.""" - return self.get("reasoningText") + return self.get("reasoningText") # type: ignore[return-value] @property def delta(self) -> ContentBlockDelta: """Delta data from the model response.""" - return self.get("delta") + return self.get("delta") # type: ignore[return-value] class ReasoningSignatureStreamEvent(ModelStreamEvent): """Event emitted during reasoning signature streaming.""" - def __init__(self, delta: ContentBlockDelta, reasoning_signature: str) -> None: + def __init__(self, delta: ContentBlockDelta, reasoning_signature: str | None) -> None: """Initialize with delta and reasoning signature.""" super().__init__({"reasoning_signature": reasoning_signature, "delta": delta, "reasoning": True}) @property def reasoning_signature(self) -> str: """Cumulative reasoning signature content assembled from streaming deltas received thus far.""" - return self.get("reasoning_signature") + return self.get("reasoning_signature") # type: ignore[return-value] @property def delta(self) -> ContentBlockDelta: """Delta data from the model response.""" - return self.get("delta") + return self.get("delta") # type: ignore[return-value] AllTypedEvents = ( From 4982332d26c43a8236930305961407ab84372e8a Mon Sep 17 00:00:00 2001 From: Mackenzie Zastrow Date: Wed, 20 Aug 2025 19:29:48 -0400 Subject: [PATCH 7/7] Switch to using Signal terminology instead of Event terminology --- src/strands/agent/agent.py | 90 +++++----- src/strands/event_loop/event_loop.py | 120 ++++++------- src/strands/event_loop/streaming.py | 38 ++-- src/strands/tools/executor.py | 20 +-- src/strands/types/{events.py => signals.py} | 182 +++++++++----------- tests/strands/agent/test_agent.py | 40 ++--- tests/strands/event_loop/test_event_loop.py | 18 +- tests/strands/tools/test_executor.py | 10 +- 8 files changed, 250 insertions(+), 268 deletions(-) rename src/strands/types/{events.py => signals.py} (70%) diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index dae9e03e4..3f2e796d0 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -37,8 +37,8 @@ from ..tools.registry import ToolRegistry from ..tools.watcher import ToolWatcher from ..types.content import ContentBlock, Message, Messages -from ..types.events import InitEventLoopEvent, ModelStreamEvent, StopEvent, ToolResultEvent, TypedEvent from ..types.exceptions import ContextWindowOverflowException +from ..types.signals import AgentSignal, InitSignalLoopSignal, ModelStreamSignal, StopSignal, ToolResultSignal from ..types.tools import ToolResult, ToolUse from ..types.traces import AttributeValue from .agent_result import AgentResult @@ -137,19 +137,19 @@ def caller( "input": kwargs.copy(), } - async def acall() -> ToolResultEvent: + async def acall() -> ToolResultSignal: # Pass kwargs as invocation_state - async for event in run_tool(self._agent, tool_use, kwargs): - _ = event + async for signal in run_tool(self._agent, tool_use, kwargs): + _ = signal - return cast(ToolResultEvent, event) + return cast(ToolResultSignal, signal) - def tcall() -> ToolResultEvent: + def tcall() -> ToolResultSignal: return asyncio.run(acall()) with ThreadPoolExecutor() as executor: future = executor.submit(tcall) - last_event = future.result() + last_signal = future.result() if record_direct_tool_call is not None: should_record_direct_tool_call = record_direct_tool_call @@ -158,12 +158,12 @@ def tcall() -> ToolResultEvent: if should_record_direct_tool_call: # Create a record of this tool execution in the message history - self._agent._record_tool_execution(tool_use, last_event.tool_result, user_message_override) + self._agent._record_tool_execution(tool_use, last_signal.tool_result, user_message_override) # Apply window management self._agent.conversation_manager.apply_management(self._agent) - return last_event.tool_result + return last_signal.tool_result return caller @@ -395,11 +395,11 @@ async def invoke_async(self, prompt: Union[str, list[ContentBlock]], **kwargs: A - metrics: Performance metrics from the event loop - state: The final state of the event loop """ - events = self.stream_async(prompt, **kwargs) - async for event in events: - _ = event + signals = self.stream_async(prompt, **kwargs) + async for signal in signals: + _ = signal - return cast(AgentResult, event["result"]) + return cast(AgentResult, signal["result"]) def structured_output(self, output_model: Type[T], prompt: Optional[Union[str, list[ContentBlock]]] = None) -> T: """This method allows you to get structured output from the agent. @@ -477,23 +477,23 @@ async def structured_output_async( "gen_ai.system.message", attributes={"role": "system", "content": serialize([{"text": self.system_prompt}])}, ) - events = self.model.structured_output(output_model, temp_messages, system_prompt=self.system_prompt) - async for event in events: - if "stop" not in event: - typed_event = ModelStreamEvent(delta_data=event) - typed_event._prepare_and_invoke( + signals = self.model.structured_output(output_model, temp_messages, system_prompt=self.system_prompt) + async for signal in signals: + if "stop" not in signal: + typed_signal = ModelStreamSignal(delta_data=signal) + typed_signal._prepare_and_invoke( agent=self, invocation_state={}, callback_handler=self.callback_handler ) structured_output_span.add_event( - "gen_ai.choice", attributes={"message": serialize(event["output"].model_dump())} + "gen_ai.choice", attributes={"message": serialize(signal["output"].model_dump())} ) - return event["output"] + return signal["output"] finally: self.hooks.invoke_callbacks(AfterInvocationEvent(agent=self)) - async def stream_async(self, prompt: Union[str, list[ContentBlock]], **kwargs: Any) -> AsyncIterator[Any]: + async def stream_async(self, prompt: Union[str, list[ContentBlock]], **kwargs: Any) -> AsyncIterator[AgentSignal]: """Process a natural language prompt and yield events as an async iterator. This method provides an asynchronous interface for streaming agent events, allowing @@ -533,23 +533,23 @@ async def stream_async(self, prompt: Union[str, list[ContentBlock]], **kwargs: A with trace_api.use_span(self.trace_span): try: invocation_state = kwargs - events = self._run_loop(message, invocation_state=invocation_state) - async for event in events: - event._prepare_and_invoke( + signals = self._run_loop(message, invocation_state=invocation_state) + async for signal in signals: + signal._prepare_and_invoke( agent=self, invocation_state=invocation_state, callback_handler=callback_handler ) - yield event + yield signal - if not isinstance(event, StopEvent): - raise ValueError(f"Last event was not a StopEvent; was: {event}") + if not isinstance(signal, StopSignal): + raise ValueError(f"Last signal was not a StopSignal; was: {signal}") - self._end_agent_trace_span(response=event.result) + self._end_agent_trace_span(response=signal.result) except Exception as e: self._end_agent_trace_span(error=e) raise - async def _run_loop(self, message: Message, invocation_state: dict[str, Any]) -> AsyncGenerator[TypedEvent, None]: + async def _run_loop(self, message: Message, invocation_state: dict[str, Any]) -> AsyncGenerator[AgentSignal, None]: """Execute the agent's event loop with the given message and parameters. Args: @@ -562,33 +562,33 @@ async def _run_loop(self, message: Message, invocation_state: dict[str, Any]) -> self.hooks.invoke_callbacks(BeforeInvocationEvent(agent=self)) try: - yield InitEventLoopEvent() + yield InitSignalLoopSignal() self._append_message(message) # Execute the event loop cycle with retry logic for context limits - events = self._execute_event_loop_cycle(invocation_state) - async for event in events: + signals = self._execute_event_loop_cycle(invocation_state) + async for signal in signals: # Signal from the model provider that the message sent by the user should be redacted, # likely due to a guardrail. if ( - isinstance(event, ModelStreamEvent) - and event.delta_data.get("event") - and event.delta_data["event"].get("redactContent") - and event.delta_data["event"]["redactContent"].get("redactUserContentMessage") + isinstance(signal, ModelStreamSignal) + and signal.delta_data.get("event") + and signal.delta_data["event"].get("redactContent") + and signal.delta_data["event"]["redactContent"].get("redactUserContentMessage") ): self.messages[-1]["content"] = [ - {"text": event.delta_data["event"]["redactContent"]["redactUserContentMessage"]} + {"text": signal.delta_data["event"]["redactContent"]["redactUserContentMessage"]} ] if self._session_manager: self._session_manager.redact_latest_message(self.messages[-1], self) - yield event + yield signal finally: self.conversation_manager.apply_management(self) self.hooks.invoke_callbacks(AfterInvocationEvent(agent=self)) - async def _execute_event_loop_cycle(self, invocation_state: dict[str, Any]) -> AsyncGenerator[TypedEvent, None]: + async def _execute_event_loop_cycle(self, invocation_state: dict[str, Any]) -> AsyncGenerator[AgentSignal, None]: """Execute the event loop cycle with retry logic for context window limits. This internal method handles the execution of the event loop cycle and implements @@ -603,12 +603,12 @@ async def _execute_event_loop_cycle(self, invocation_state: dict[str, Any]) -> A try: # Execute the main event loop cycle - events = event_loop_cycle( + signals = event_loop_cycle( agent=self, invocation_state=invocation_state, ) - async for event in events: - yield event + async for signal in signals: + yield signal except ContextWindowOverflowException as e: # Try reducing the context size and retrying @@ -618,9 +618,9 @@ async def _execute_event_loop_cycle(self, invocation_state: dict[str, Any]) -> A if self._session_manager: self._session_manager.sync_agent(self) - events = self._execute_event_loop_cycle(invocation_state) - async for event in events: - yield event + signals = self._execute_event_loop_cycle(invocation_state) + async for signal in signals: + yield signal def _record_tool_execution( self, diff --git a/src/strands/event_loop/event_loop.py b/src/strands/event_loop/event_loop.py index df65f44e8..adbbc0c82 100644 --- a/src/strands/event_loop/event_loop.py +++ b/src/strands/event_loop/event_loop.py @@ -28,26 +28,26 @@ from ..telemetry.tracer import get_tracer from ..tools.executor import run_tools, validate_and_prepare_tools from ..types.content import Message -from ..types.events import ( - EventLoopThrottleDelay, - ForceStopEvent, - MessageEvent, - ModelStreamEvent, - StartEvent, - StartEventLoopEvent, - StopEvent, - ToolResultEvent, - ToolResultMessageEvent, - ToolStreamEvent, - TypedEvent, - TypedToolGenerator, -) from ..types.exceptions import ( ContextWindowOverflowException, EventLoopException, MaxTokensReachedException, ModelThrottledException, ) +from ..types.signals import ( + EventLoopThrottleSignal, + ForceStopSignal, + MessageSignal, + ModelStreamSignal, + SignalToolGenerator, + StartEventLoopSignal, + StartSignal, + StopSignal, + ToolResultMessageSignal, + ToolResultSignal, + ToolStreamSignal, + TypedSignal, +) from ..types.streaming import Metrics, StopReason from ..types.tools import ToolChoice, ToolChoiceAuto, ToolConfig, ToolResult, ToolUse from ._recover_message_on_max_tokens_reached import recover_message_on_max_tokens_reached @@ -63,7 +63,7 @@ MAX_DELAY = 240 # 4 minutes -async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) -> AsyncGenerator[TypedEvent, None]: +async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) -> AsyncGenerator[TypedSignal, None]: """Execute a single cycle of the event loop. This core function processes a single conversation turn, handling model inference, tool execution, and error @@ -107,8 +107,8 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) -> cycle_start_time, cycle_trace = agent.event_loop_metrics.start_cycle(attributes=attributes) invocation_state["event_loop_cycle_trace"] = cycle_trace - yield StartEvent() - yield StartEventLoopEvent() + yield StartSignal() + yield StartEventLoopSignal() # Create tracer span for this event loop cycle tracer = get_tracer() @@ -146,11 +146,11 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) -> ) try: - async for event in stream_messages(agent.model, agent.system_prompt, agent.messages, tool_specs): - if "stop" not in event: - yield ModelStreamEvent(delta_data=event) + async for signal in stream_messages(agent.model, agent.system_prompt, agent.messages, tool_specs): + if "stop" not in signal: + yield ModelStreamSignal(delta_data=signal) - stop_reason, message, usage, metrics = event["stop"] + stop_reason, message, usage, metrics = signal["stop"] invocation_state.setdefault("request_state", {}) agent.hooks.invoke_callbacks( @@ -183,7 +183,7 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) -> if isinstance(e, ModelThrottledException): if attempt + 1 == MAX_ATTEMPTS: - yield ForceStopEvent(e) + yield ForceStopSignal(e) raise e logger.debug( @@ -197,7 +197,7 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) -> time.sleep(current_delay) current_delay = min(current_delay * 2, MAX_DELAY) - yield EventLoopThrottleDelay(delay=current_delay) + yield EventLoopThrottleSignal(delay=current_delay) else: raise e @@ -209,7 +209,7 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) -> # Add the response message to the conversation agent.messages.append(message) agent.hooks.invoke_callbacks(MessageAddedEvent(agent=agent, message=message)) - yield MessageEvent(message=message) + yield MessageSignal(message=message) # Update metrics agent.event_loop_metrics.update_usage(usage) @@ -234,7 +234,7 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) -> # If the model is requesting to use tools if stop_reason == "tool_use": # Handle tool execution - events = _handle_tool_execution( + signals = _handle_tool_execution( stop_reason, message, agent=agent, @@ -243,12 +243,12 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) -> cycle_start_time=cycle_start_time, invocation_state=invocation_state, ) - async for event in events: + async for signal in signals: # Note: Once we allow streaming of tool results, we need to do a check here for it - if isinstance(event, TypedEvent): - yield event + if isinstance(signal, TypedSignal): + yield signal else: - raise ValueError(f"Event was not a subclass of TypedEvent: {event}") + raise ValueError(f"Signal was not a subclass of TypedSignal: {signal}") return # End the cycle and return results @@ -275,11 +275,11 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) -> tracer.end_span_with_error(cycle_span, str(e), e) # Handle any other exceptions - yield ForceStopEvent(reason=str(e)) + yield ForceStopSignal(reason=str(e)) logger.exception("cycle failed") raise EventLoopException(e, invocation_state["request_state"]) from e - yield StopEvent( + yield StopSignal( stop_reason=stop_reason, message=message, metrics=agent.event_loop_metrics, @@ -287,7 +287,7 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) -> ) -async def recurse_event_loop(agent: "Agent", invocation_state: dict[str, Any]) -> AsyncGenerator[TypedEvent, None]: +async def recurse_event_loop(agent: "Agent", invocation_state: dict[str, Any]) -> AsyncGenerator[TypedSignal, None]: """Make a recursive call to event_loop_cycle with the current state. This function is used when the event loop needs to continue processing after tool execution. @@ -311,16 +311,16 @@ async def recurse_event_loop(agent: "Agent", invocation_state: dict[str, Any]) - recursive_trace = Trace("Recursive call", parent_id=cycle_trace.id) cycle_trace.add_child(recursive_trace) - yield StartEvent() + yield StartSignal() - events = event_loop_cycle(agent=agent, invocation_state=invocation_state) - async for event in events: - yield event + signals = event_loop_cycle(agent=agent, invocation_state=invocation_state) + async for signal in signals: + yield signal recursive_trace.end() -async def run_tool(agent: "Agent", tool_use: ToolUse, invocation_state: dict[str, Any]) -> TypedToolGenerator: +async def run_tool(agent: "Agent", tool_use: ToolUse, invocation_state: dict[str, Any]) -> SignalToolGenerator: """Process a tool invocation. Looks up the tool in the registry and streams it with the provided parameters. @@ -397,17 +397,17 @@ async def run_tool(agent: "Agent", tool_use: ToolUse, invocation_state: dict[str result=result, ) ) - yield ToolResultEvent(tool_result=after_event.result) + yield ToolResultSignal(tool_result=after_event.result) return - async for event in selected_tool.stream(tool_use, invocation_state): - # wrap all events in a ToolStreamEvent - print(event) - yield ToolStreamEvent(tool_use, event) + async for signal in selected_tool.stream(tool_use, invocation_state): + # wrap all signals in a ToolStreamSignal + print(signal) + yield ToolStreamSignal(tool_use, signal) - if not ("content" in event and "status" in event and "toolUseId" in event): - raise ValueError("Last event of tool invocation was not a tool result") - result = event + if not ("content" in signal and "status" in signal and "toolUseId" in signal): + raise ValueError("Last signal of tool invocation was not a tool result") + result = signal after_event = agent.hooks.invoke_callbacks( AfterToolInvocationEvent( @@ -419,7 +419,7 @@ async def run_tool(agent: "Agent", tool_use: ToolUse, invocation_state: dict[str ) ) - yield ToolResultEvent(tool_result=after_event.result) + yield ToolResultSignal(tool_result=after_event.result) except Exception as e: logger.exception("tool_name=<%s> | failed to process tool", tool_name) @@ -438,7 +438,7 @@ async def run_tool(agent: "Agent", tool_use: ToolUse, invocation_state: dict[str exception=e, ) ) - yield ToolResultEvent(tool_result=after_event.result) + yield ToolResultSignal(tool_result=after_event.result) async def _handle_tool_execution( @@ -449,7 +449,7 @@ async def _handle_tool_execution( cycle_span: Any, cycle_start_time: float, invocation_state: dict[str, Any], -) -> AsyncGenerator[TypedEvent, None]: +) -> AsyncGenerator[TypedSignal, None]: tool_uses: list[ToolUse] = [] tool_results: list[ToolResult] = [] invalid_tool_use_ids: list[str] = [] @@ -478,7 +478,7 @@ async def _handle_tool_execution( validate_and_prepare_tools(message, tool_uses, tool_results, invalid_tool_use_ids) if not tool_uses: - yield StopEvent( + yield StopSignal( stop_reason=stop_reason, message=message, metrics=agent.event_loop_metrics, @@ -486,10 +486,10 @@ async def _handle_tool_execution( ) return - def tool_handler(tool_use: ToolUse) -> TypedToolGenerator: + def tool_handler(tool_use: ToolUse) -> SignalToolGenerator: return run_tool(agent, tool_use, invocation_state) - tool_events = run_tools( + tool_signals = run_tools( handler=tool_handler, tool_uses=tool_uses, event_loop_metrics=agent.event_loop_metrics, @@ -498,12 +498,12 @@ def tool_handler(tool_use: ToolUse) -> TypedToolGenerator: cycle_trace=cycle_trace, parent_span=cycle_span, ) - async for tool_event in tool_events: - # For now, until we implement #543, supress tool stream events - if isinstance(tool_event, ToolStreamEvent): + async for tool_signal in tool_signals: + # For now, until we implement #543, supress tool stream signals + if isinstance(tool_signal, ToolStreamSignal): continue - yield tool_event + yield tool_signal # Store parent cycle ID for the next cycle invocation_state["event_loop_parent_cycle_id"] = invocation_state["event_loop_cycle_id"] @@ -515,7 +515,7 @@ def tool_handler(tool_use: ToolUse) -> TypedToolGenerator: agent.messages.append(tool_result_message) agent.hooks.invoke_callbacks(MessageAddedEvent(agent=agent, message=tool_result_message)) - yield ToolResultMessageEvent(message=tool_result_message) + yield ToolResultMessageSignal(message=tool_result_message) if cycle_span: tracer = get_tracer() @@ -523,7 +523,7 @@ def tool_handler(tool_use: ToolUse) -> TypedToolGenerator: if invocation_state["request_state"].get("stop_event_loop", False): agent.event_loop_metrics.end_cycle(cycle_start_time, cycle_trace) - yield StopEvent( + yield StopSignal( stop_reason=stop_reason, message=message, metrics=agent.event_loop_metrics, @@ -531,6 +531,6 @@ def tool_handler(tool_use: ToolUse) -> TypedToolGenerator: ) return - events = recurse_event_loop(agent=agent, invocation_state=invocation_state) - async for event in events: - yield event + signals = recurse_event_loop(agent=agent, invocation_state=invocation_state) + async for signal in signals: + yield signal diff --git a/src/strands/event_loop/streaming.py b/src/strands/event_loop/streaming.py index 65d7cc738..cac826666 100644 --- a/src/strands/event_loop/streaming.py +++ b/src/strands/event_loop/streaming.py @@ -22,7 +22,7 @@ from ..types.tools import ToolSpec, ToolUse if TYPE_CHECKING: - from ..types.events import ModelStreamEvent + from ..types.signals import ModelStreamSignal logger = logging.getLogger(__name__) @@ -106,7 +106,7 @@ def handle_content_block_start(event: ContentBlockStartEvent) -> dict[str, Any]: def handle_content_block_delta( event: ContentBlockDeltaEvent, state: dict[str, Any] -) -> tuple[dict[str, Any], "ModelStreamEvent"]: +) -> tuple[dict[str, Any], "ModelStreamSignal"]: """Handles content block delta updates by appending text, tool input, or reasoning content to the state. Args: @@ -116,28 +116,28 @@ def handle_content_block_delta( Returns: Updated state with appended text or tool input. """ - from ..types.events import ( - ModelStreamEvent, - ReasoningSignatureStreamEvent, - ReasoningTextStreamEvent, - TextStreamEvent, - ToolUseStreamEvent, + from ..types.signals import ( + ModelStreamSignal, + ReasoningSignatureStreamSignal, + ReasoningTextStreamSignal, + TextStreamSignal, + ToolUseStreamSignal, ) delta_content = event["delta"] - callback_event = ModelStreamEvent({}) + signal = ModelStreamSignal({}) if "toolUse" in delta_content: if "input" not in state["current_tool_use"]: state["current_tool_use"]["input"] = "" state["current_tool_use"]["input"] += delta_content["toolUse"]["input"] - callback_event = ToolUseStreamEvent(event["delta"], state["current_tool_use"]) + signal = ToolUseStreamSignal(event["delta"], state["current_tool_use"]) elif "text" in delta_content: state["text"] += delta_content["text"] - callback_event = TextStreamEvent(event["delta"], delta_content["text"]) + signal = TextStreamSignal(event["delta"], delta_content["text"]) elif "reasoningContent" in delta_content: if "text" in delta_content["reasoningContent"]: @@ -145,18 +145,16 @@ def handle_content_block_delta( state["reasoningText"] = "" state["reasoningText"] += delta_content["reasoningContent"]["text"] - callback_event = ReasoningTextStreamEvent(event["delta"], delta_content["reasoningContent"]["text"]) + signal = ReasoningTextStreamSignal(event["delta"], delta_content["reasoningContent"]["text"]) elif "signature" in delta_content["reasoningContent"]: if "signature" not in state: state["signature"] = "" state["signature"] += delta_content["reasoningContent"]["signature"] - callback_event = ReasoningSignatureStreamEvent( - event["delta"], delta_content["reasoningContent"]["signature"] - ) + signal = ReasoningSignatureStreamSignal(event["delta"], delta_content["reasoningContent"]["signature"]) - return state, callback_event + return state, signal def handle_content_block_stop(state: dict[str, Any]) -> dict[str, Any]: @@ -283,8 +281,8 @@ async def process_stream(chunks: AsyncIterable[StreamEvent]) -> AsyncGenerator[d elif "contentBlockStart" in chunk: state["current_tool_use"] = handle_content_block_start(chunk["contentBlockStart"]) elif "contentBlockDelta" in chunk: - state, callback_event = handle_content_block_delta(chunk["contentBlockDelta"], state) - yield callback_event + state, callback_signal = handle_content_block_delta(chunk["contentBlockDelta"], state) + yield callback_signal elif "contentBlockStop" in chunk: state = handle_content_block_stop(state) elif "messageStop" in chunk: @@ -320,5 +318,5 @@ async def stream_messages( chunks = model.stream(messages, tool_specs if tool_specs else None, system_prompt) - async for event in process_stream(chunks): - yield event + async for signal in process_stream(chunks): + yield signal diff --git a/src/strands/tools/executor.py b/src/strands/tools/executor.py index 1c3fc990e..47c4ccb8d 100644 --- a/src/strands/tools/executor.py +++ b/src/strands/tools/executor.py @@ -11,7 +11,7 @@ from ..telemetry.tracer import get_tracer from ..tools.tools import InvalidToolUseNameException, validate_tool_use from ..types.content import Message -from ..types.events import ToolResultEvent, TypedToolGenerator +from ..types.signals import SignalToolGenerator, ToolResultSignal from ..types.tools import RunToolHandler, ToolResult, ToolUse logger = logging.getLogger(__name__) @@ -25,7 +25,7 @@ async def run_tools( tool_results: list[ToolResult], cycle_trace: Trace, parent_span: Optional[trace_api.Span] = None, -) -> TypedToolGenerator: +) -> SignalToolGenerator: """Execute tools concurrently. Args: @@ -56,17 +56,17 @@ async def work( tool_start_time = time.time() with trace_api.use_span(tool_call_span): try: - async for event in handler(tool_use): - worker_queue.put_nowait((worker_id, event)) + async for signal in handler(tool_use): + worker_queue.put_nowait((worker_id, signal)) await worker_event.wait() worker_event.clear() - result_event = cast(ToolResultEvent, event) + result_signal = cast(ToolResultSignal, signal) finally: worker_queue.put_nowait((worker_id, stop_event)) - tool_success = result_event.get("status") == "success" - tool_result = result_event.tool_result + tool_result = result_signal.tool_result + tool_success = tool_result.get("status") == "success" tool_duration = time.time() - tool_start_time message = Message(role="user", content=[{"toolResult": tool_result}]) event_loop_metrics.add_tool_usage(tool_use, tool_duration, tool_trace, tool_success, message) @@ -88,12 +88,12 @@ async def work( worker_count = len(workers) while worker_count: - worker_id, event = await worker_queue.get() - if event is stop_event: + worker_id, signal = await worker_queue.get() + if signal is stop_event: worker_count -= 1 continue - yield event + yield signal worker_events[worker_id].set() tool_results.extend([worker.result() for worker in workers]) diff --git a/src/strands/types/events.py b/src/strands/types/signals.py similarity index 70% rename from src/strands/types/events.py rename to src/strands/types/signals.py index 5d547d132..7489ef023 100644 --- a/src/strands/types/events.py +++ b/src/strands/types/signals.py @@ -1,9 +1,8 @@ -"""Event system for the Strands Agents framework. +"""Signal system for the Strands Agents framework. -This module defines the event types that are emitted during agent execution, -providing a structured way to observe and respond to different stages of the -agent lifecycle including initialization, model invocation, tool execution, -and completion. +This module defines the signal types that are emitted during agent execution, +providing a structured way to observe to different events of the event loop and +agent lifecycle. """ from dataclasses import MISSING @@ -21,27 +20,23 @@ from ..agent import Agent, AgentResult -class TypedEvent(dict): - """Base class for all typed events in the agent system. - - TypedEvent provides a framework for creating strongly-typed events that can be - processed by callback handlers. - """ +class TypedSignal(dict): + """Base class for all typed signals in the agent system.""" invocation_state: dict[str, Any] def __init__(self, data: dict[str, Any] | None = None) -> None: - """Initialize the typed event with optional data. + """Initialize the typed signal with optional data. Args: - data: Optional dictionary of event data to initialize with + data: Optional dictionary of signal data to initialize with """ super().__init__(data or {}) def _get_callback_fields(self) -> list[str] | Literal["all"] | None: - """When invoking a callback with this event, which fields should be exposed. + """When invoking a callback with this signal, which fields should be exposed. - This is for **backwards compatability only** and should not be implemented for new events. + This is for **backwards compatability only** and should not be implemented for new signals. Can be an array of properties to pass to the callback handler, can be the literal "all" to include all properties in the dict, or None to indicate that the callback should not be invoked. @@ -49,9 +44,9 @@ def _get_callback_fields(self) -> list[str] | Literal["all"] | None: return None def _set_invocation_state(self, invocation_state: dict) -> None: - """Sets the invocation state for this event instance. + """Sets the invocation state for this signal instance. - This is for **backwards compatability only** and should not be implemented for new events. + This is for **backwards compatability only** and should not be implemented for new signals. Args: invocation_state: Dictionary containing context and state information @@ -60,21 +55,21 @@ def _set_invocation_state(self, invocation_state: dict) -> None: self.invocation_state = invocation_state def _prepare_and_invoke(self, *, agent: "Agent", invocation_state: dict, callback_handler: Callable) -> bool: - """Internal API: Prepares the event and invokes the callback handler if applicable. + """Internal API: Prepares the signal and invokes the callback handler if applicable. - This method sets the invocation state on the event and invokes the callback_handler with - the event if it is configured to be invoked + This method sets the invocation state on the signal and invokes the callback_handler with + the signal if it is configured to be invoked Args: agent: The agent instance (not currently used but will be in future versions) invocation_state: Context and state information for the current invocation - callback_handler: The callback function to invoke with event data + callback_handler: The callback function to invoke with signal data Returns: bool: True if the callback was invoked, False otherwise """ self._set_invocation_state(invocation_state) - args = TypedEvent._get_callback_arguments(self) + args = TypedSignal._get_callback_arguments(self) if args: callback_handler(**args) @@ -83,7 +78,7 @@ def _prepare_and_invoke(self, *, agent: "Agent", invocation_state: dict, callbac return False @staticmethod - def _get_callback_arguments(event: "TypedEvent") -> dict[str, Any] | None: + def _get_callback_arguments(event: "TypedSignal") -> dict[str, Any] | None: allow_listed = event._get_callback_fields() if allow_listed is None: @@ -95,15 +90,15 @@ def _get_callback_arguments(event: "TypedEvent") -> dict[str, Any] | None: return {k: v for k, v in event.items() if k in allow_listed} -class InitEventLoopEvent(TypedEvent): - """Event emitted at the very beginning of agent execution. +class InitSignalLoopSignal(TypedSignal): + """Signal emitted at the very beginning of agent execution. - This event is fired before any processing begins and provides access to the + This signal is fired before any processing begins and provides access to the initial invocation state. """ def __init__(self) -> None: - """Initialize the event loop initialization event.""" + """Initialize the event loop initialization signal.""" super().__init__({"init_event_loop": True}) @override @@ -118,43 +113,43 @@ def _get_callback_fields(self) -> list[str]: return ["init_event_loop"] -class StartEvent(TypedEvent): - """Event emitted at the start of each event loop cycle. +class StartSignal(TypedSignal): + """Signal emitted at the start of each event loop cycle. ::deprecated:: - Use StartEventLoopEvent instead. + Use StartSignalLoopSignal instead. - This event signals the beginning of a new processing cycle within the agent's + This signal signals the beginning of a new processing cycle within the agent's event loop. It's fired before model invocation and tool execution begin. """ def __init__(self) -> None: - """Initialize the event loop start event.""" + """Initialize the event loop start signal.""" super().__init__({"start": True}) def _get_callback_fields(self) -> list[str]: return ["start"] -class StartEventLoopEvent(TypedEvent): - """Event emitted when the event loop cycle begins processing. +class StartEventLoopSignal(TypedSignal): + """Signal emitted when the event loop cycle begins processing. - This event is fired after StartEvent and indicates that the event loop + This signal is fired after StartSignal and indicates that the event loop has begun its core processing logic, including model invocation preparation. """ def __init__(self) -> None: - """Initialize the event loop processing start event.""" + """Initialize the event loop processing start signal.""" super().__init__({"start_event_loop": True}) def _get_callback_fields(self) -> list[str]: return ["start_event_loop"] -class MessageEvent(TypedEvent): - """Event emitted when the model invocation has completed. +class MessageSignal(TypedSignal): + """Signal emitted when the model invocation has completed. - This event is fired whenever the model generates a response message that + This signal is fired whenever the model generates a response message that gets added to the conversation history. """ @@ -175,8 +170,8 @@ def _get_callback_fields(self) -> list[str]: return ["message"] -class EventLoopThrottleDelay(TypedEvent): - """Event emitted when the event loop is throttled due to rate limiting.""" +class EventLoopThrottleSignal(TypedSignal): + """Signal emitted when the event loop is throttled due to rate limiting.""" def __init__(self, delay: int) -> None: """Initialize with the throttle delay duration. @@ -195,10 +190,10 @@ def _get_callback_fields(self) -> list[str]: return ["event_loop_throttled_delay"] -class ForceStopEvent(TypedEvent): - """Event emitted when the agent execution is forcibly stopped, either by a tool or by an exception. +class ForceStopSignal(TypedSignal): + """Signal emitted when the agent execution is forcibly stopped, either by a tool or by an exception. - This event is fired when an unrecoverable error occurs during execution, + This signal is fired when an unrecoverable error occurs during execution, such as repeated throttling failures or critical system errors. It provides the reason for the forced stop and any associated exception. """ @@ -231,14 +226,8 @@ def _get_callback_fields(self) -> list[str]: return ["force_stop", "force_stop_reason"] -class StopEvent(TypedEvent): - """Event emitted when the agent execution completes normally. - - This event is fired when the event loop cycle completes successfully, - providing the final result including the stop reason, final message, - execution metrics, and final state. It represents the successful - completion of an agent invocation. - """ +class StopSignal(TypedSignal): + """Signal emitted when the agent execution completes normally.""" def __init__( self, @@ -278,22 +267,17 @@ def _get_callback_fields(self) -> list[str]: return ["result"] -class ToolStreamEvent(TypedEvent): - """Event emitted during tool execution streaming. - - This event is fired when a tool produces streaming output during execution. - It provides access to the tool use information and the streaming data, - allowing callbacks to process tool execution progress in real-time. - """ +class ToolStreamSignal(TypedSignal): + """Signal emitted when a tool yields signals as part of tool execution.""" - def __init__(self, tool_use: ToolUse, stream_data: Any) -> None: + def __init__(self, tool_use: ToolUse, signal: Any) -> None: """Initialize with tool streaming data. Args: tool_use: The tool invocation producing the stream - stream_data: Incremental data from the tool execution + signal: The yielded signal from the tool execution """ - super().__init__({"tool_stream_tool_use": tool_use, "tool_stream_data": stream_data}) + super().__init__({"tool_stream_tool_use": tool_use, "tool_stream_signal": signal}) @property def tool_use(self) -> ToolUse: @@ -302,12 +286,12 @@ def tool_use(self) -> ToolUse: @property def tool_stream_data(self) -> Any: - """Incremental data chunk from the streaming tool execution.""" - return self.get("tool_stream_data") + """The yielded signal from the tool execution.""" + return self.get("tool_stream_signal") -class ToolResultEvent(TypedEvent): - """Event emitted when a tool execution completes.""" +class ToolResultSignal(TypedSignal): + """Signal emitted when a tool execution completes.""" def __init__(self, tool_result: ToolResult) -> None: """Initialize with the completed tool result. @@ -323,10 +307,10 @@ def tool_result(self) -> ToolResult: return self.get("tool_result") # type: ignore[return-value] -class ToolResultMessageEvent(TypedEvent): - """Event emitted when tool results are formatted as a message. +class ToolResultMessageSignal(TypedSignal): + """Signal emitted when tool results are formatted as a message. - This event is fired when tool execution results are converted into a + This signal is fired when tool execution results are converted into a message format to be added to the conversation history. It provides access to the formatted message containing tool results. """ @@ -349,10 +333,10 @@ def _get_callback_fields(self) -> list[str]: return ["message"] -class ModelStreamEvent(TypedEvent): - """Event emitted during model response streaming. +class ModelStreamSignal(TypedSignal): + """Signal emitted during model response streaming. - This event is fired when the model produces streaming output during response + This signal is fired when the model produces streaming output during response generation. It provides access to incremental delta data from the model, allowing callbacks to process and display streaming responses in real-time. """ @@ -384,8 +368,8 @@ def _get_callback_fields(self) -> Any: return Any -class ToolUseStreamEvent(ModelStreamEvent): - """Event emitted during tool use input streaming.""" +class ToolUseStreamSignal(ModelStreamSignal): + """Signal emitted during tool use input streaming.""" def __init__(self, delta: ContentBlockDelta, current_tool_use: dict[str, Any]) -> None: """Initialize with delta and current tool use state.""" @@ -402,8 +386,8 @@ def current_tool_use(self) -> dict[str, Any]: return self.get("current_tool_use") # type: ignore[return-value] -class TextStreamEvent(ModelStreamEvent): - """Event emitted during text content streaming.""" +class TextStreamSignal(ModelStreamSignal): + """Signal emitted during text content streaming.""" def __init__(self, delta: ContentBlockDelta, text: str) -> None: """Initialize with delta and text content.""" @@ -420,8 +404,8 @@ def delta(self) -> ContentBlockDelta: return self.get("delta") # type: ignore[return-value] -class ReasoningTextStreamEvent(ModelStreamEvent): - """Event emitted during reasoning text streaming.""" +class ReasoningTextStreamSignal(ModelStreamSignal): + """Signal emitted during reasoning text streaming.""" def __init__(self, delta: ContentBlockDelta, reasoning_text: str | None) -> None: """Initialize with delta and reasoning text.""" @@ -438,8 +422,8 @@ def delta(self) -> ContentBlockDelta: return self.get("delta") # type: ignore[return-value] -class ReasoningSignatureStreamEvent(ModelStreamEvent): - """Event emitted during reasoning signature streaming.""" +class ReasoningSignatureStreamSignal(ModelStreamSignal): + """Signal emitted during reasoning signature streaming.""" def __init__(self, delta: ContentBlockDelta, reasoning_signature: str | None) -> None: """Initialize with delta and reasoning signature.""" @@ -456,24 +440,24 @@ def delta(self) -> ContentBlockDelta: return self.get("delta") # type: ignore[return-value] -AllTypedEvents = ( - InitEventLoopEvent - | StartEvent - | StartEventLoopEvent - | MessageEvent - | EventLoopThrottleDelay - | ForceStopEvent - | StopEvent - | ToolStreamEvent - | ToolResultEvent - | ToolResultMessageEvent - | ModelStreamEvent - | ToolUseStreamEvent - | TextStreamEvent - | ReasoningTextStreamEvent - | ReasoningSignatureStreamEvent - | TypedEvent +AgentSignal = ( + InitSignalLoopSignal + | StartSignal + | StartEventLoopSignal + | MessageSignal + | EventLoopThrottleSignal + | ForceStopSignal + | StopSignal + | ToolStreamSignal + | ToolResultSignal + | ToolResultMessageSignal + | ModelStreamSignal + | ToolUseStreamSignal + | TextStreamSignal + | ReasoningTextStreamSignal + | ReasoningSignatureStreamSignal + | TypedSignal ) -TypedToolGenerator = AsyncGenerator[TypedEvent, None] -"""Generator of tool events where all events are typed as TypedEvents.""" +SignalToolGenerator = AsyncGenerator[TypedSignal, None] +"""Generator of tool signals where all signals are typed as TypedSignals.""" diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index fc310cb4e..6cd7ee0db 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -20,14 +20,14 @@ from strands.models.bedrock import DEFAULT_BEDROCK_MODEL_ID, BedrockModel from strands.session.repository_session_manager import RepositorySessionManager from strands.types.content import Messages -from strands.types.events import ( - InitEventLoopEvent, - ModelStreamEvent, - StopEvent, - ToolResultEvent, -) from strands.types.exceptions import ContextWindowOverflowException, EventLoopException from strands.types.session import Session, SessionAgent, SessionMessage, SessionType +from strands.types.signals import ( + InitSignalLoopSignal, + ModelStreamSignal, + StopSignal, + ToolResultSignal, +) from tests.fixtures.mock_session_repository import MockedSessionRepository from tests.fixtures.mocked_model_provider import MockedModelProvider @@ -408,7 +408,7 @@ async def check_invocation_state(**kwargs): assert invocation_state["agent"] == agent # Return expected values from event_loop_cycle - yield StopEvent( + yield StopSignal( stop_reason="stop", message={"role": "assistant", "content": [{"text": "Response"}]}, request_state={}, @@ -901,7 +901,7 @@ def test_agent_init_with_no_model_or_model_id(): def test_agent_tool_no_parameter_conflict(agent, tool_registry, mock_randint, mock_run_tool, agenerator): - mock_run_tool.return_value = agenerator([ToolResultEvent({})]) + mock_run_tool.return_value = agenerator([ToolResultSignal({})]) @strands.tools.tool(name="system_prompter") def function(system_prompt: str) -> str: @@ -925,7 +925,7 @@ def function(system_prompt: str) -> str: def test_agent_tool_with_name_normalization(agent, tool_registry, mock_randint, mock_run_tool, agenerator): - mock_run_tool.return_value = agenerator([ToolResultEvent({})]) + mock_run_tool.return_value = agenerator([ToolResultSignal({})]) tool_name = "system-prompter" @@ -1165,12 +1165,12 @@ async def test_stream_async_returns_all_events(mock_event_loop_cycle, alist): # Define the side effect to simulate callback handler being called multiple times async def test_event_loop(*args, **kwargs): - yield ModelStreamEvent(delta_data={"data": "First chunk"}) - yield ModelStreamEvent(delta_data={"data": "Second chunk"}) - yield ModelStreamEvent(delta_data={"data": "Final chunk", "complete": True}) + yield ModelStreamSignal(delta_data={"data": "First chunk"}) + yield ModelStreamSignal(delta_data={"data": "Second chunk"}) + yield ModelStreamSignal(delta_data={"data": "Final chunk", "complete": True}) # Return expected values from event_loop_cycle - yield StopEvent( + yield StopSignal( stop_reason="stop", message={"role": "assistant", "content": [{"text": "Response"}]}, metrics={}, @@ -1183,16 +1183,16 @@ async def test_event_loop(*args, **kwargs): stream = agent.stream_async("test message", callback_handler=mock_callback) # TODO - init_event = InitEventLoopEvent() + init_event = InitSignalLoopSignal() init_event.update({"callback_handler": mock_callback}) tru_events = await alist(stream) exp_events = [ init_event, - ModelStreamEvent(delta_data={"data": "First chunk"}), - ModelStreamEvent(delta_data={"data": "Second chunk"}), - ModelStreamEvent(delta_data={"data": "Final chunk", "complete": True}), - StopEvent( + ModelStreamSignal(delta_data={"data": "First chunk"}), + ModelStreamSignal(delta_data={"data": "Second chunk"}), + ModelStreamSignal(delta_data={"data": "Final chunk", "complete": True}), + StopSignal( stop_reason="stop", message={"role": "assistant", "content": [{"text": "Response"}]}, metrics={}, @@ -1336,7 +1336,7 @@ async def check_invocation_state(**kwargs): invocation_state = kwargs["invocation_state"] assert invocation_state["some_value"] == "a_value" # Return expected values from event_loop_cycle - yield StopEvent( + yield StopSignal( stop_reason="stop", message={"role": "assistant", "content": [{"text": "Response"}]}, metrics={}, @@ -1473,7 +1473,7 @@ async def test_agent_stream_async_creates_and_ends_span_on_success(mock_get_trac mock_get_tracer.return_value = mock_tracer async def test_event_loop(*args, **kwargs): - yield StopEvent( + yield StopSignal( stop_reason="stop", message={"role": "assistant", "content": [{"text": "Agent Response"}]}, metrics={}, diff --git a/tests/strands/event_loop/test_event_loop.py b/tests/strands/event_loop/test_event_loop.py index cbf621cdb..7f96f85d7 100644 --- a/tests/strands/event_loop/test_event_loop.py +++ b/tests/strands/event_loop/test_event_loop.py @@ -20,13 +20,13 @@ ) from strands.telemetry.metrics import EventLoopMetrics from strands.tools.registry import ToolRegistry -from strands.types.events import ForceStopEvent, StopEvent, ToolResultEvent from strands.types.exceptions import ( ContextWindowOverflowException, EventLoopException, MaxTokensReachedException, ModelThrottledException, ) +from strands.types.signals import ForceStopSignal, StopSignal, ToolResultSignal from tests.fixtures.mock_hook_provider import MockHookProvider @@ -487,7 +487,7 @@ async def test_cycle_exception( ] tru_stop_event = None - exp_stop_event = ForceStopEvent(reason="Invalid error presented") + exp_stop_event = ForceStopSignal(reason="Invalid error presented") with pytest.raises(EventLoopException): stream = strands.event_loop.event_loop.event_loop_cycle( @@ -789,7 +789,7 @@ async def test_prepare_next_cycle_in_tool_execution(agent, model, tool_stream, a # Set up mock to return a valid response mock_recurse.return_value = agenerator( [ - StopEvent( + StopSignal( stop_reason="end_turn", message={"role": "assistant", "content": [{"text": "test text"}]}, metrics=strands.telemetry.metrics.EventLoopMetrics(), @@ -825,7 +825,7 @@ async def test_run_tool(agent, tool, alist): ) tru_result = (await alist(process))[-1] - exp_result = ToolResultEvent({"toolUseId": "tool_use_id", "status": "success", "content": [{"text": "a_string"}]}) + exp_result = ToolResultSignal({"toolUseId": "tool_use_id", "status": "success", "content": [{"text": "a_string"}]}) assert tru_result == exp_result @@ -840,7 +840,7 @@ async def test_run_tool_missing_tool(agent, alist): tru_events = await alist(process) exp_events = [ - ToolResultEvent( + ToolResultSignal( tool_result={ "toolUseId": "missing", "status": "error", @@ -962,7 +962,7 @@ def modify_hook(event: BeforeToolInvocationEvent): result = (await alist(process))[-1] # Should use replacement_tool (5 * 3 = 15) instead of original_tool (1 * 2 = 2) - assert result == ToolResultEvent({"toolUseId": "modified", "status": "success", "content": [{"text": "15"}]}) + assert result == ToolResultSignal({"toolUseId": "modified", "status": "success", "content": [{"text": "15"}]}) assert hook_provider.events_received[1] == AfterToolInvocationEvent( agent=agent, @@ -993,7 +993,7 @@ def modify_hook(event: AfterToolInvocationEvent): ) result = (await alist(process))[-1] - assert result == ToolResultEvent(updated_result) + assert result == ToolResultSignal(updated_result) @pytest.mark.asyncio @@ -1015,7 +1015,7 @@ def modify_hook(event: AfterToolInvocationEvent): ) result = (await alist(process))[-1] - assert result == ToolResultEvent(tool_result=updated_result) + assert result == ToolResultSignal(tool_result=updated_result) @pytest.mark.asyncio @@ -1056,7 +1056,7 @@ def after_tool_call(self, event: AfterToolInvocationEvent): result = (await alist(process))[-1] - assert result == ToolResultEvent( + assert result == ToolResultSignal( tool_result={ "status": "error", "toolUseId": "test", diff --git a/tests/strands/tools/test_executor.py b/tests/strands/tools/test_executor.py index df7c710d2..e058b66c2 100644 --- a/tests/strands/tools/test_executor.py +++ b/tests/strands/tools/test_executor.py @@ -6,7 +6,7 @@ import strands import strands.telemetry from strands.types.content import Message -from strands.types.events import ToolResultEvent, ToolStreamEvent +from strands.types.signals import ToolResultSignal, ToolStreamSignal @pytest.fixture(autouse=True) @@ -17,8 +17,8 @@ def moto_autouse(moto_env): @pytest.fixture def tool_handler(request): async def handler(tool_use): - yield ToolStreamEvent(tool_use, {"event": "abc"}) - yield ToolResultEvent( + yield ToolStreamSignal(tool_use, {"event": "abc"}) + yield ToolResultSignal( { **params, "toolUseId": tool_use["toolUseId"], @@ -89,8 +89,8 @@ async def test_run_tools( tru_events = await alist(stream) exp_events = [ - ToolStreamEvent(tool_uses[0], {"event": "abc"}), - ToolResultEvent( + ToolStreamSignal(tool_uses[0], {"event": "abc"}), + ToolResultSignal( { "content": [ {