diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index 4579ebacf..5b32019bc 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -484,24 +484,23 @@ async def structured_output_async(self, output_model: Type[T], prompt: AgentInpu Raises: ValueError: If no conversation history or prompt is provided. """ - self.hooks.invoke_callbacks(BeforeInvocationEvent(agent=self)) + if not self.messages and not prompt: + raise ValueError("No conversation history or prompt provided") + temp_messages: Messages = self.messages + self._convert_prompt_to_messages(prompt) + + self.hooks.invoke_callbacks(BeforeInvocationEvent(agent=self, messages=temp_messages)) + with self.tracer.tracer.start_as_current_span( - "execute_structured_output", kind=trace_api.SpanKind.CLIENT + "execute_structured_output", + attributes={ + "gen_ai.system": "strands-agents", + "gen_ai.agent.name": self.name, + "gen_ai.agent.id": self.agent_id, + "gen_ai.operation.name": "execute_structured_output", + }, + kind=trace_api.SpanKind.CLIENT, ) as structured_output_span: try: - if not self.messages and not prompt: - raise ValueError("No conversation history or prompt provided") - - temp_messages: Messages = self.messages + self._convert_prompt_to_messages(prompt) - - structured_output_span.set_attributes( - { - "gen_ai.system": "strands-agents", - "gen_ai.agent.name": self.name, - "gen_ai.agent.id": self.agent_id, - "gen_ai.operation.name": "execute_structured_output", - } - ) if self.system_prompt: structured_output_span.add_event( "gen_ai.system.message", @@ -606,7 +605,7 @@ async def _run_loop(self, messages: Messages, invocation_state: dict[str, Any]) Yields: Events from the event loop cycle. """ - self.hooks.invoke_callbacks(BeforeInvocationEvent(agent=self)) + self.hooks.invoke_callbacks(BeforeInvocationEvent(agent=self, messages=messages)) try: yield InitEventLoopEvent() diff --git a/src/strands/hooks/events.py b/src/strands/hooks/events.py index 8f611e4e2..b2e305074 100644 --- a/src/strands/hooks/events.py +++ b/src/strands/hooks/events.py @@ -6,7 +6,7 @@ from dataclasses import dataclass from typing import Any, Optional -from ..types.content import Message +from ..types.content import Message, Messages from ..types.streaming import StopReason from ..types.tools import AgentTool, ToolResult, ToolUse from .registry import HookEvent @@ -36,9 +36,13 @@ class BeforeInvocationEvent(HookEvent): - Agent.__call__ - Agent.stream_async - Agent.structured_output + + Attributes: + message: The input to the agent request (including current message and + any history if present) """ - pass + messages: Messages @dataclass diff --git a/tests/strands/agent/hooks/test_events.py b/tests/strands/agent/hooks/test_events.py index 8bbd89c17..55b8280ba 100644 --- a/tests/strands/agent/hooks/test_events.py +++ b/tests/strands/agent/hooks/test_events.py @@ -47,7 +47,7 @@ def initialized_event(agent): @pytest.fixture def start_request_event(agent): - return BeforeInvocationEvent(agent=agent) + return BeforeInvocationEvent(agent=agent, messages=Mock()) @pytest.fixture diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index 2cd87c26d..582f3f160 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -1007,13 +1007,15 @@ def test_agent_structured_output(agent, system_prompt, user, agenerator): type(user), [{"role": "user", "content": [{"text": prompt}]}], system_prompt=system_prompt ) - mock_span.set_attributes.assert_called_once_with( - { + mock_otel_tracer.start_as_current_span.assert_called_once_with( + "execute_structured_output", + attributes={ "gen_ai.system": "strands-agents", "gen_ai.agent.name": "Strands Agents", "gen_ai.agent.id": "default", "gen_ai.operation.name": "execute_structured_output", - } + }, + kind=unittest.mock.ANY, ) # ensure correct otel event messages are emitted diff --git a/tests/strands/agent/test_agent_hooks.py b/tests/strands/agent/test_agent_hooks.py index 6c5625e0b..4190fdd42 100644 --- a/tests/strands/agent/test_agent_hooks.py +++ b/tests/strands/agent/test_agent_hooks.py @@ -150,7 +150,10 @@ def test_agent__call__hooks(agent, hook_provider, agent_tool, mock_model, tool_u assert length == 12 - assert next(events) == BeforeInvocationEvent(agent=agent) + assert next(events) == BeforeInvocationEvent( + agent=agent, + messages=agent.messages[0:1], + ) assert next(events) == MessageAddedEvent( agent=agent, message=agent.messages[0], @@ -199,9 +202,11 @@ def test_agent__call__hooks(agent, hook_provider, agent_tool, mock_model, tool_u @pytest.mark.asyncio async def test_agent_stream_async_hooks(agent, hook_provider, agent_tool, mock_model, tool_use, agenerator): """Verify that the correct hook events are emitted as part of stream_async.""" - iterator = agent.stream_async("test message") + input_prompt = "test message" + input_messages: Messages = [{"role": "user", "content": [{"text": input_prompt}]}] + iterator = agent.stream_async(input_prompt) await anext(iterator) - assert hook_provider.events_received == [BeforeInvocationEvent(agent=agent)] + assert hook_provider.events_received == [BeforeInvocationEvent(agent=agent, messages=input_messages)] # iterate the rest async for _ in iterator: @@ -211,7 +216,7 @@ async def test_agent_stream_async_hooks(agent, hook_provider, agent_tool, mock_m assert length == 12 - assert next(events) == BeforeInvocationEvent(agent=agent) + assert next(events) == BeforeInvocationEvent(agent=agent, messages=input_messages) assert next(events) == MessageAddedEvent( agent=agent, message=agent.messages[0], @@ -267,7 +272,15 @@ def test_agent_structured_output_hooks(agent, hook_provider, user, agenerator): assert length == 2 - assert next(events) == BeforeInvocationEvent(agent=agent) + assert next(events) == BeforeInvocationEvent( + agent=agent, + messages=[ + { + "content": [{"text": "example prompt"}], + "role": "user", + } + ], + ) assert next(events) == AfterInvocationEvent(agent=agent) assert len(agent.messages) == 0 # no new messages added @@ -284,7 +297,15 @@ async def test_agent_structured_async_output_hooks(agent, hook_provider, user, a assert length == 2 - assert next(events) == BeforeInvocationEvent(agent=agent) + assert next(events) == BeforeInvocationEvent( + agent=agent, + messages=[ + { + "content": [{"text": "example prompt"}], + "role": "user", + } + ], + ) assert next(events) == AfterInvocationEvent(agent=agent) assert len(agent.messages) == 0 # no new messages added