Skip to content

Commit 23f6080

Browse files
committed
feat: expose messages to BeforeInvocationEvent
Add incoming messages to BeforeInvocationEvent so this hook can be used for guardrails aiming to check (and perhaps redact) user inputs before kicking off the event loop.
1 parent 2a26ffa commit 23f6080

File tree

5 files changed

+54
-28
lines changed

5 files changed

+54
-28
lines changed

src/strands/agent/agent.py

Lines changed: 15 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -484,24 +484,23 @@ async def structured_output_async(self, output_model: Type[T], prompt: AgentInpu
484484
Raises:
485485
ValueError: If no conversation history or prompt is provided.
486486
"""
487-
self.hooks.invoke_callbacks(BeforeInvocationEvent(agent=self))
487+
if not self.messages and not prompt:
488+
raise ValueError("No conversation history or prompt provided")
489+
temp_messages: Messages = self.messages + self._convert_prompt_to_messages(prompt)
490+
491+
self.hooks.invoke_callbacks(BeforeInvocationEvent(agent=self, messages=temp_messages))
492+
488493
with self.tracer.tracer.start_as_current_span(
489-
"execute_structured_output", kind=trace_api.SpanKind.CLIENT
494+
"execute_structured_output",
495+
attributes={
496+
"gen_ai.system": "strands-agents",
497+
"gen_ai.agent.name": self.name,
498+
"gen_ai.agent.id": self.agent_id,
499+
"gen_ai.operation.name": "execute_structured_output",
500+
},
501+
kind=trace_api.SpanKind.CLIENT,
490502
) as structured_output_span:
491503
try:
492-
if not self.messages and not prompt:
493-
raise ValueError("No conversation history or prompt provided")
494-
495-
temp_messages: Messages = self.messages + self._convert_prompt_to_messages(prompt)
496-
497-
structured_output_span.set_attributes(
498-
{
499-
"gen_ai.system": "strands-agents",
500-
"gen_ai.agent.name": self.name,
501-
"gen_ai.agent.id": self.agent_id,
502-
"gen_ai.operation.name": "execute_structured_output",
503-
}
504-
)
505504
if self.system_prompt:
506505
structured_output_span.add_event(
507506
"gen_ai.system.message",
@@ -606,7 +605,7 @@ async def _run_loop(self, messages: Messages, invocation_state: dict[str, Any])
606605
Yields:
607606
Events from the event loop cycle.
608607
"""
609-
self.hooks.invoke_callbacks(BeforeInvocationEvent(agent=self))
608+
self.hooks.invoke_callbacks(BeforeInvocationEvent(agent=self, messages=messages))
610609

611610
try:
612611
yield InitEventLoopEvent()

src/strands/hooks/events.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from dataclasses import dataclass
77
from typing import Any, Optional
88

9-
from ..types.content import Message
9+
from ..types.content import Message, Messages
1010
from ..types.streaming import StopReason
1111
from ..types.tools import AgentTool, ToolResult, ToolUse
1212
from .registry import HookEvent
@@ -36,9 +36,13 @@ class BeforeInvocationEvent(HookEvent):
3636
- Agent.__call__
3737
- Agent.stream_async
3838
- Agent.structured_output
39+
40+
Attributes:
41+
message: The input to the agent request (including current message and
42+
any history if present)
3943
"""
4044

41-
pass
45+
messages: Messages
4246

4347

4448
@dataclass

tests/strands/agent/hooks/test_events.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def initialized_event(agent):
4747

4848
@pytest.fixture
4949
def start_request_event(agent):
50-
return BeforeInvocationEvent(agent=agent)
50+
return BeforeInvocationEvent(agent=agent, messages=Mock())
5151

5252

5353
@pytest.fixture

tests/strands/agent/test_agent.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1007,13 +1007,15 @@ def test_agent_structured_output(agent, system_prompt, user, agenerator):
10071007
type(user), [{"role": "user", "content": [{"text": prompt}]}], system_prompt=system_prompt
10081008
)
10091009

1010-
mock_span.set_attributes.assert_called_once_with(
1011-
{
1010+
mock_otel_tracer.start_as_current_span.assert_called_once_with(
1011+
"execute_structured_output",
1012+
attributes={
10121013
"gen_ai.system": "strands-agents",
10131014
"gen_ai.agent.name": "Strands Agents",
10141015
"gen_ai.agent.id": "default",
10151016
"gen_ai.operation.name": "execute_structured_output",
1016-
}
1017+
},
1018+
kind=unittest.mock.ANY,
10171019
)
10181020

10191021
# ensure correct otel event messages are emitted

tests/strands/agent/test_agent_hooks.py

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,10 @@ def test_agent__call__hooks(agent, hook_provider, agent_tool, mock_model, tool_u
150150

151151
assert length == 12
152152

153-
assert next(events) == BeforeInvocationEvent(agent=agent)
153+
assert next(events) == BeforeInvocationEvent(
154+
agent=agent,
155+
messages=agent.messages[0:1],
156+
)
154157
assert next(events) == MessageAddedEvent(
155158
agent=agent,
156159
message=agent.messages[0],
@@ -199,9 +202,11 @@ def test_agent__call__hooks(agent, hook_provider, agent_tool, mock_model, tool_u
199202
@pytest.mark.asyncio
200203
async def test_agent_stream_async_hooks(agent, hook_provider, agent_tool, mock_model, tool_use, agenerator):
201204
"""Verify that the correct hook events are emitted as part of stream_async."""
202-
iterator = agent.stream_async("test message")
205+
input_prompt = "test message"
206+
input_messages: Messages = [{"role": "user", "content": [{"text": input_prompt}]}]
207+
iterator = agent.stream_async(input_prompt)
203208
await anext(iterator)
204-
assert hook_provider.events_received == [BeforeInvocationEvent(agent=agent)]
209+
assert hook_provider.events_received == [BeforeInvocationEvent(agent=agent, messages=input_messages)]
205210

206211
# iterate the rest
207212
async for _ in iterator:
@@ -211,7 +216,7 @@ async def test_agent_stream_async_hooks(agent, hook_provider, agent_tool, mock_m
211216

212217
assert length == 12
213218

214-
assert next(events) == BeforeInvocationEvent(agent=agent)
219+
assert next(events) == BeforeInvocationEvent(agent=agent, messages=input_messages)
215220
assert next(events) == MessageAddedEvent(
216221
agent=agent,
217222
message=agent.messages[0],
@@ -267,7 +272,15 @@ def test_agent_structured_output_hooks(agent, hook_provider, user, agenerator):
267272

268273
assert length == 2
269274

270-
assert next(events) == BeforeInvocationEvent(agent=agent)
275+
assert next(events) == BeforeInvocationEvent(
276+
agent=agent,
277+
messages=[
278+
{
279+
"content": [{"text": "example prompt"}],
280+
"role": "user",
281+
}
282+
],
283+
)
271284
assert next(events) == AfterInvocationEvent(agent=agent)
272285

273286
assert len(agent.messages) == 0 # no new messages added
@@ -284,7 +297,15 @@ async def test_agent_structured_async_output_hooks(agent, hook_provider, user, a
284297

285298
assert length == 2
286299

287-
assert next(events) == BeforeInvocationEvent(agent=agent)
300+
assert next(events) == BeforeInvocationEvent(
301+
agent=agent,
302+
messages=[
303+
{
304+
"content": [{"text": "example prompt"}],
305+
"role": "user",
306+
}
307+
],
308+
)
288309
assert next(events) == AfterInvocationEvent(agent=agent)
289310

290311
assert len(agent.messages) == 0 # no new messages added

0 commit comments

Comments
 (0)