Skip to content

Commit f814458

Browse files
authored
feat(hooks): Add invocation state (#1550)
1 parent e8fc991 commit f814458

File tree

8 files changed

+94
-29
lines changed

8 files changed

+94
-29
lines changed

src/strands/agent/agent.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -474,7 +474,7 @@ async def structured_output_async(self, output_model: type[T], prompt: AgentInpu
474474
category=DeprecationWarning,
475475
stacklevel=2,
476476
)
477-
await self.hooks.invoke_callbacks_async(BeforeInvocationEvent(agent=self))
477+
await self.hooks.invoke_callbacks_async(BeforeInvocationEvent(agent=self, invocation_state={}))
478478
with self.tracer.tracer.start_as_current_span(
479479
"execute_structured_output", kind=trace_api.SpanKind.CLIENT
480480
) as structured_output_span:
@@ -515,7 +515,7 @@ async def structured_output_async(self, output_model: type[T], prompt: AgentInpu
515515
return event["output"]
516516

517517
finally:
518-
await self.hooks.invoke_callbacks_async(AfterInvocationEvent(agent=self))
518+
await self.hooks.invoke_callbacks_async(AfterInvocationEvent(agent=self, invocation_state={}))
519519

520520
def cleanup(self) -> None:
521521
"""Clean up resources used by the agent.
@@ -657,7 +657,7 @@ async def _run_loop(
657657
Events from the event loop cycle.
658658
"""
659659
before_invocation_event, _interrupts = await self.hooks.invoke_callbacks_async(
660-
BeforeInvocationEvent(agent=self, messages=messages)
660+
BeforeInvocationEvent(agent=self, invocation_state=invocation_state, messages=messages)
661661
)
662662
messages = before_invocation_event.messages if before_invocation_event.messages is not None else messages
663663

@@ -695,7 +695,9 @@ async def _run_loop(
695695

696696
finally:
697697
self.conversation_manager.apply_management(self)
698-
await self.hooks.invoke_callbacks_async(AfterInvocationEvent(agent=self, result=agent_result))
698+
await self.hooks.invoke_callbacks_async(
699+
AfterInvocationEvent(agent=self, invocation_state=invocation_state, result=agent_result)
700+
)
699701

700702
async def _execute_event_loop_cycle(
701703
self, invocation_state: dict[str, Any], structured_output_context: StructuredOutputContext | None = None

src/strands/event_loop/event_loop.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -318,6 +318,7 @@ async def _handle_model_execution(
318318
await agent.hooks.invoke_callbacks_async(
319319
BeforeModelCallEvent(
320320
agent=agent,
321+
invocation_state=invocation_state,
321322
)
322323
)
323324

@@ -343,6 +344,7 @@ async def _handle_model_execution(
343344

344345
after_model_call_event = AfterModelCallEvent(
345346
agent=agent,
347+
invocation_state=invocation_state,
346348
stop_response=AfterModelCallEvent.ModelStopResponse(
347349
stop_reason=stop_reason,
348350
message=message,
@@ -370,6 +372,7 @@ async def _handle_model_execution(
370372
# Exception is automatically recorded by use_span with end_on_exit=True
371373
after_model_call_event = AfterModelCallEvent(
372374
agent=agent,
375+
invocation_state=invocation_state,
373376
exception=e,
374377
)
375378
await agent.hooks.invoke_callbacks_async(after_model_call_event)

src/strands/hooks/events.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
"""
55

66
import uuid
7-
from dataclasses import dataclass
7+
from dataclasses import dataclass, field
88
from typing import TYPE_CHECKING, Any
99

1010
from typing_extensions import override
@@ -48,10 +48,14 @@ class BeforeInvocationEvent(HookEvent):
4848
- Agent.structured_output
4949
5050
Attributes:
51+
invocation_state: State and configuration passed through the agent invocation.
52+
This can include shared context for multi-agent coordination, request tracking,
53+
and dynamic configuration.
5154
messages: The input messages for this invocation. Can be modified by hooks
5255
to redact or transform content before processing.
5356
"""
5457

58+
invocation_state: dict[str, Any] = field(default_factory=dict)
5559
messages: Messages | None = None
5660

5761
def _can_write(self, name: str) -> bool:
@@ -75,11 +79,15 @@ class AfterInvocationEvent(HookEvent):
7579
- Agent.structured_output
7680
7781
Attributes:
82+
invocation_state: State and configuration passed through the agent invocation.
83+
This can include shared context for multi-agent coordination, request tracking,
84+
and dynamic configuration.
7885
result: The result of the agent invocation, if available.
7986
This will be None when invoked from structured_output methods, as those return typed output directly rather
8087
than AgentResult.
8188
"""
8289

90+
invocation_state: dict[str, Any] = field(default_factory=dict)
8391
result: "AgentResult | None" = None
8492

8593
@property
@@ -208,9 +216,14 @@ class BeforeModelCallEvent(HookEvent):
208216
that will be sent to the model.
209217
210218
Note: This event is not fired for invocations to structured_output.
219+
220+
Attributes:
221+
invocation_state: State and configuration passed through the agent invocation.
222+
This can include shared context for multi-agent coordination, request tracking,
223+
and dynamic configuration.
211224
"""
212225

213-
pass
226+
invocation_state: dict[str, Any] = field(default_factory=dict)
214227

215228

216229
@dataclass
@@ -239,6 +252,9 @@ class AfterModelCallEvent(HookEvent):
239252
conversation history
240253
241254
Attributes:
255+
invocation_state: State and configuration passed through the agent invocation.
256+
This can include shared context for multi-agent coordination, request tracking,
257+
and dynamic configuration.
242258
stop_response: The model response data if invocation was successful, None if failed.
243259
exception: Exception if the model invocation failed, None if successful.
244260
retry: Whether to retry the model invocation. Can be set by hook callbacks
@@ -258,6 +274,7 @@ class ModelStopResponse:
258274
message: Message
259275
stop_reason: StopReason
260276

277+
invocation_state: dict[str, Any] = field(default_factory=dict)
261278
stop_response: ModelStopResponse | None = None
262279
exception: Exception | None = None
263280
retry: bool = False

tests/strands/agent/hooks/test_events.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,11 @@
55
from strands.agent.agent_result import AgentResult
66
from strands.hooks import (
77
AfterInvocationEvent,
8+
AfterModelCallEvent,
89
AfterToolCallEvent,
910
AgentInitializedEvent,
1011
BeforeInvocationEvent,
12+
BeforeModelCallEvent,
1113
BeforeToolCallEvent,
1214
MessageAddedEvent,
1315
)
@@ -170,6 +172,41 @@ def test_after_invocation_event_properties_not_writable(agent):
170172
with pytest.raises(AttributeError, match="Property agent is not writable"):
171173
event.agent = Mock()
172174

175+
with pytest.raises(AttributeError, match="Property invocation_state is not writable"):
176+
event.invocation_state = {}
177+
178+
179+
def test_invocation_state_is_available_in_invocation_events(agent):
180+
"""Test that invocation_state is accessible in BeforeInvocationEvent and AfterInvocationEvent."""
181+
invocation_state = {"session_id": "test-123", "request_id": "req-456"}
182+
183+
before_event = BeforeInvocationEvent(agent=agent, invocation_state=invocation_state)
184+
assert before_event.invocation_state == invocation_state
185+
assert before_event.invocation_state["session_id"] == "test-123"
186+
assert before_event.invocation_state["request_id"] == "req-456"
187+
188+
after_event = AfterInvocationEvent(agent=agent, invocation_state=invocation_state, result=None)
189+
assert after_event.invocation_state == invocation_state
190+
assert after_event.invocation_state["session_id"] == "test-123"
191+
assert after_event.invocation_state["request_id"] == "req-456"
192+
193+
194+
def test_invocation_state_is_available_in_model_call_events(agent):
195+
"""Test that invocation_state is accessible in BeforeModelCallEvent and AfterModelCallEvent."""
196+
invocation_state = {"session_id": "test-123", "request_id": "req-456"}
197+
198+
before_event = BeforeModelCallEvent(agent=agent, invocation_state=invocation_state)
199+
assert before_event.invocation_state == invocation_state
200+
assert before_event.invocation_state["session_id"] == "test-123"
201+
assert before_event.invocation_state["request_id"] == "req-456"
202+
203+
after_event = AfterModelCallEvent(agent=agent, invocation_state=invocation_state)
204+
assert after_event.invocation_state == invocation_state
205+
assert after_event.invocation_state["session_id"] == "test-123"
206+
assert after_event.invocation_state["request_id"] == "req-456"
207+
208+
209+
173210

174211
def test_before_invocation_event_messages_default_none(agent):
175212
"""Test that BeforeInvocationEvent.messages defaults to None for backward compatibility."""

tests/strands/agent/test_agent_hooks.py

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -160,14 +160,15 @@ def test_agent__call__hooks(agent, hook_provider, agent_tool, mock_model, tool_u
160160

161161
assert length == 12
162162

163-
assert next(events) == BeforeInvocationEvent(agent=agent, messages=agent.messages[0:1])
163+
assert next(events) == BeforeInvocationEvent(agent=agent, invocation_state=ANY, messages=agent.messages[0:1])
164164
assert next(events) == MessageAddedEvent(
165165
agent=agent,
166166
message=agent.messages[0],
167167
)
168-
assert next(events) == BeforeModelCallEvent(agent=agent)
168+
assert next(events) == BeforeModelCallEvent(agent=agent, invocation_state=ANY)
169169
assert next(events) == AfterModelCallEvent(
170170
agent=agent,
171+
invocation_state=ANY,
171172
stop_response=AfterModelCallEvent.ModelStopResponse(
172173
message={
173174
"content": [{"toolUse": tool_use}],
@@ -193,9 +194,10 @@ def test_agent__call__hooks(agent, hook_provider, agent_tool, mock_model, tool_u
193194
result={"content": [{"text": "!loot a dekovni I"}], "status": "success", "toolUseId": "123"},
194195
)
195196
assert next(events) == MessageAddedEvent(agent=agent, message=agent.messages[2])
196-
assert next(events) == BeforeModelCallEvent(agent=agent)
197+
assert next(events) == BeforeModelCallEvent(agent=agent, invocation_state=ANY)
197198
assert next(events) == AfterModelCallEvent(
198199
agent=agent,
200+
invocation_state=ANY,
199201
stop_response=AfterModelCallEvent.ModelStopResponse(
200202
message=mock_model.agent_responses[1],
201203
stop_reason="end_turn",
@@ -204,7 +206,7 @@ def test_agent__call__hooks(agent, hook_provider, agent_tool, mock_model, tool_u
204206
)
205207
assert next(events) == MessageAddedEvent(agent=agent, message=agent.messages[3])
206208

207-
assert next(events) == AfterInvocationEvent(agent=agent, result=result)
209+
assert next(events) == AfterInvocationEvent(agent=agent, invocation_state=ANY, result=result)
208210

209211
assert len(agent.messages) == 4
210212

@@ -215,8 +217,9 @@ async def test_agent_stream_async_hooks(agent, hook_provider, agent_tool, mock_m
215217
iterator = agent.stream_async("test message")
216218
await anext(iterator)
217219

218-
# Verify first event is BeforeInvocationEvent with messages
220+
# Verify first event is BeforeInvocationEvent with invocation_state and messages
219221
assert len(hook_provider.events_received) == 1
222+
assert hook_provider.events_received[0].invocation_state is not None
220223
assert hook_provider.events_received[0].messages is not None
221224
assert hook_provider.events_received[0].messages[0]["role"] == "user"
222225

@@ -230,14 +233,15 @@ async def test_agent_stream_async_hooks(agent, hook_provider, agent_tool, mock_m
230233

231234
assert length == 12
232235

233-
assert next(events) == BeforeInvocationEvent(agent=agent, messages=agent.messages[0:1])
236+
assert next(events) == BeforeInvocationEvent(agent=agent, invocation_state=ANY, messages=agent.messages[0:1])
234237
assert next(events) == MessageAddedEvent(
235238
agent=agent,
236239
message=agent.messages[0],
237240
)
238-
assert next(events) == BeforeModelCallEvent(agent=agent)
241+
assert next(events) == BeforeModelCallEvent(agent=agent, invocation_state=ANY)
239242
assert next(events) == AfterModelCallEvent(
240243
agent=agent,
244+
invocation_state=ANY,
241245
stop_response=AfterModelCallEvent.ModelStopResponse(
242246
message={
243247
"content": [{"toolUse": tool_use}],
@@ -263,9 +267,10 @@ async def test_agent_stream_async_hooks(agent, hook_provider, agent_tool, mock_m
263267
result={"content": [{"text": "!loot a dekovni I"}], "status": "success", "toolUseId": "123"},
264268
)
265269
assert next(events) == MessageAddedEvent(agent=agent, message=agent.messages[2])
266-
assert next(events) == BeforeModelCallEvent(agent=agent)
270+
assert next(events) == BeforeModelCallEvent(agent=agent, invocation_state=ANY)
267271
assert next(events) == AfterModelCallEvent(
268272
agent=agent,
273+
invocation_state=ANY,
269274
stop_response=AfterModelCallEvent.ModelStopResponse(
270275
message=mock_model.agent_responses[1],
271276
stop_reason="end_turn",
@@ -274,7 +279,7 @@ async def test_agent_stream_async_hooks(agent, hook_provider, agent_tool, mock_m
274279
)
275280
assert next(events) == MessageAddedEvent(agent=agent, message=agent.messages[3])
276281

277-
assert next(events) == AfterInvocationEvent(agent=agent, result=result)
282+
assert next(events) == AfterInvocationEvent(agent=agent, invocation_state=ANY, result=result)
278283

279284
assert len(agent.messages) == 4
280285

@@ -289,8 +294,8 @@ def test_agent_structured_output_hooks(agent, hook_provider, user, agenerator):
289294

290295
assert length == 2
291296

292-
assert next(events) == BeforeInvocationEvent(agent=agent)
293-
assert next(events) == AfterInvocationEvent(agent=agent)
297+
assert next(events) == BeforeInvocationEvent(agent=agent, invocation_state=ANY)
298+
assert next(events) == AfterInvocationEvent(agent=agent, invocation_state=ANY)
294299

295300
assert len(agent.messages) == 0 # no new messages added
296301

@@ -306,8 +311,8 @@ async def test_agent_structured_async_output_hooks(agent, hook_provider, user, a
306311

307312
assert length == 2
308313

309-
assert next(events) == BeforeInvocationEvent(agent=agent)
310-
assert next(events) == AfterInvocationEvent(agent=agent)
314+
assert next(events) == BeforeInvocationEvent(agent=agent, invocation_state=ANY)
315+
assert next(events) == AfterInvocationEvent(agent=agent, invocation_state=ANY)
311316

312317
assert len(agent.messages) == 0 # no new messages added
313318

tests/strands/agent/test_conversation_manager.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -362,7 +362,7 @@ def test_per_turn_dynamic_change():
362362

363363
mock_agent = MagicMock()
364364
mock_agent.messages = []
365-
event = BeforeModelCallEvent(agent=mock_agent)
365+
event = BeforeModelCallEvent(agent=mock_agent, invocation_state={})
366366

367367
# Initially disabled
368368
with patch.object(manager, "apply_management") as mock_apply:

tests/strands/event_loop/test_event_loop.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -855,27 +855,28 @@ async def test_event_loop_cycle_exception_model_hooks(mock_sleep, agent, model,
855855
assert count == 9
856856

857857
# 1st call - throttled
858-
assert next(events) == BeforeModelCallEvent(agent=agent)
859-
expected_after = AfterModelCallEvent(agent=agent, stop_response=None, exception=exception)
858+
assert next(events) == BeforeModelCallEvent(agent=agent, invocation_state=ANY)
859+
expected_after = AfterModelCallEvent(agent=agent, invocation_state=ANY, stop_response=None, exception=exception)
860860
expected_after.retry = True
861861
assert next(events) == expected_after
862862

863863
# 2nd call - throttled
864-
assert next(events) == BeforeModelCallEvent(agent=agent)
865-
expected_after = AfterModelCallEvent(agent=agent, stop_response=None, exception=exception)
864+
assert next(events) == BeforeModelCallEvent(agent=agent, invocation_state=ANY)
865+
expected_after = AfterModelCallEvent(agent=agent, invocation_state=ANY, stop_response=None, exception=exception)
866866
expected_after.retry = True
867867
assert next(events) == expected_after
868868

869869
# 3rd call - throttled
870-
assert next(events) == BeforeModelCallEvent(agent=agent)
871-
expected_after = AfterModelCallEvent(agent=agent, stop_response=None, exception=exception)
870+
assert next(events) == BeforeModelCallEvent(agent=agent, invocation_state=ANY)
871+
expected_after = AfterModelCallEvent(agent=agent, invocation_state=ANY, stop_response=None, exception=exception)
872872
expected_after.retry = True
873873
assert next(events) == expected_after
874874

875875
# 4th call - successful
876-
assert next(events) == BeforeModelCallEvent(agent=agent)
876+
assert next(events) == BeforeModelCallEvent(agent=agent, invocation_state=ANY)
877877
assert next(events) == AfterModelCallEvent(
878878
agent=agent,
879+
invocation_state=ANY,
879880
stop_response=AfterModelCallEvent.ModelStopResponse(
880881
message={"content": [{"text": "test text"}], "role": "assistant"}, stop_reason="end_turn"
881882
),

tests/strands/experimental/hooks/test_hook_aliases.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,15 +68,15 @@ def test_after_tool_call_event_type_equality():
6868

6969
def test_before_model_call_event_type_equality():
7070
"""Verify that BeforeModelInvocationEvent alias has the same type identity."""
71-
before_model_event = BeforeModelCallEvent(agent=Mock())
71+
before_model_event = BeforeModelCallEvent(agent=Mock(), invocation_state={})
7272

7373
assert isinstance(before_model_event, BeforeModelInvocationEvent)
7474
assert isinstance(before_model_event, BeforeModelCallEvent)
7575

7676

7777
def test_after_model_call_event_type_equality():
7878
"""Verify that AfterModelInvocationEvent alias has the same type identity."""
79-
after_model_event = AfterModelCallEvent(agent=Mock())
79+
after_model_event = AfterModelCallEvent(agent=Mock(), invocation_state={})
8080

8181
assert isinstance(after_model_event, AfterModelInvocationEvent)
8282
assert isinstance(after_model_event, AfterModelCallEvent)

0 commit comments

Comments
 (0)