Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 15 additions & 16 deletions src/strands/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,24 +484,23 @@ async def structured_output_async(self, output_model: Type[T], prompt: AgentInpu
Raises:
ValueError: If no conversation history or prompt is provided.
"""
self.hooks.invoke_callbacks(BeforeInvocationEvent(agent=self))
if not self.messages and not prompt:
raise ValueError("No conversation history or prompt provided")
temp_messages: Messages = self.messages + self._convert_prompt_to_messages(prompt)

self.hooks.invoke_callbacks(BeforeInvocationEvent(agent=self, messages=temp_messages))

with self.tracer.tracer.start_as_current_span(
"execute_structured_output", kind=trace_api.SpanKind.CLIENT
"execute_structured_output",
attributes={
"gen_ai.system": "strands-agents",
"gen_ai.agent.name": self.name,
"gen_ai.agent.id": self.agent_id,
"gen_ai.operation.name": "execute_structured_output",
},
kind=trace_api.SpanKind.CLIENT,
) as structured_output_span:
try:
if not self.messages and not prompt:
raise ValueError("No conversation history or prompt provided")

temp_messages: Messages = self.messages + self._convert_prompt_to_messages(prompt)

structured_output_span.set_attributes(
{
"gen_ai.system": "strands-agents",
"gen_ai.agent.name": self.name,
"gen_ai.agent.id": self.agent_id,
"gen_ai.operation.name": "execute_structured_output",
}
)
if self.system_prompt:
structured_output_span.add_event(
"gen_ai.system.message",
Expand Down Expand Up @@ -606,7 +605,7 @@ async def _run_loop(self, messages: Messages, invocation_state: dict[str, Any])
Yields:
Events from the event loop cycle.
"""
self.hooks.invoke_callbacks(BeforeInvocationEvent(agent=self))
self.hooks.invoke_callbacks(BeforeInvocationEvent(agent=self, messages=messages))

try:
yield InitEventLoopEvent()
Expand Down
8 changes: 6 additions & 2 deletions src/strands/hooks/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from dataclasses import dataclass
from typing import Any, Optional

from ..types.content import Message
from ..types.content import Message, Messages
from ..types.streaming import StopReason
from ..types.tools import AgentTool, ToolResult, ToolUse
from .registry import HookEvent
Expand Down Expand Up @@ -36,9 +36,13 @@ class BeforeInvocationEvent(HookEvent):
- Agent.__call__
- Agent.stream_async
- Agent.structured_output

Attributes:
message: The input to the agent request (including current message and
any history if present)
"""

pass
messages: Messages


@dataclass
Expand Down
2 changes: 1 addition & 1 deletion tests/strands/agent/hooks/test_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def initialized_event(agent):

@pytest.fixture
def start_request_event(agent):
return BeforeInvocationEvent(agent=agent)
return BeforeInvocationEvent(agent=agent, messages=Mock())


@pytest.fixture
Expand Down
8 changes: 5 additions & 3 deletions tests/strands/agent/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -1007,13 +1007,15 @@ def test_agent_structured_output(agent, system_prompt, user, agenerator):
type(user), [{"role": "user", "content": [{"text": prompt}]}], system_prompt=system_prompt
)

mock_span.set_attributes.assert_called_once_with(
{
mock_otel_tracer.start_as_current_span.assert_called_once_with(
"execute_structured_output",
attributes={
"gen_ai.system": "strands-agents",
"gen_ai.agent.name": "Strands Agents",
"gen_ai.agent.id": "default",
"gen_ai.operation.name": "execute_structured_output",
}
},
kind=unittest.mock.ANY,
)

# ensure correct otel event messages are emitted
Expand Down
33 changes: 27 additions & 6 deletions tests/strands/agent/test_agent_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,10 @@ def test_agent__call__hooks(agent, hook_provider, agent_tool, mock_model, tool_u

assert length == 12

assert next(events) == BeforeInvocationEvent(agent=agent)
assert next(events) == BeforeInvocationEvent(
agent=agent,
messages=agent.messages[0:1],
)
assert next(events) == MessageAddedEvent(
agent=agent,
message=agent.messages[0],
Expand Down Expand Up @@ -199,9 +202,11 @@ def test_agent__call__hooks(agent, hook_provider, agent_tool, mock_model, tool_u
@pytest.mark.asyncio
async def test_agent_stream_async_hooks(agent, hook_provider, agent_tool, mock_model, tool_use, agenerator):
"""Verify that the correct hook events are emitted as part of stream_async."""
iterator = agent.stream_async("test message")
input_prompt = "test message"
input_messages: Messages = [{"role": "user", "content": [{"text": input_prompt}]}]
iterator = agent.stream_async(input_prompt)
await anext(iterator)
assert hook_provider.events_received == [BeforeInvocationEvent(agent=agent)]
assert hook_provider.events_received == [BeforeInvocationEvent(agent=agent, messages=input_messages)]

# iterate the rest
async for _ in iterator:
Expand All @@ -211,7 +216,7 @@ async def test_agent_stream_async_hooks(agent, hook_provider, agent_tool, mock_m

assert length == 12

assert next(events) == BeforeInvocationEvent(agent=agent)
assert next(events) == BeforeInvocationEvent(agent=agent, messages=input_messages)
assert next(events) == MessageAddedEvent(
agent=agent,
message=agent.messages[0],
Expand Down Expand Up @@ -267,7 +272,15 @@ def test_agent_structured_output_hooks(agent, hook_provider, user, agenerator):

assert length == 2

assert next(events) == BeforeInvocationEvent(agent=agent)
assert next(events) == BeforeInvocationEvent(
agent=agent,
messages=[
{
"content": [{"text": "example prompt"}],
"role": "user",
}
],
)
assert next(events) == AfterInvocationEvent(agent=agent)

assert len(agent.messages) == 0 # no new messages added
Expand All @@ -284,7 +297,15 @@ async def test_agent_structured_async_output_hooks(agent, hook_provider, user, a

assert length == 2

assert next(events) == BeforeInvocationEvent(agent=agent)
assert next(events) == BeforeInvocationEvent(
agent=agent,
messages=[
{
"content": [{"text": "example prompt"}],
"role": "user",
}
],
)
assert next(events) == AfterInvocationEvent(agent=agent)

assert len(agent.messages) == 0 # no new messages added