Skip to content

Commit 7fbc9dc

Browse files
JackYPCOnlineUnshurepgrayyposhinchenvamgan
authored
feat: replace kwargs with invocation_state in agent APIs (#966)
* feat: replace kwargs with invocation_state in agent APIs * fix: handle **kwargs in stream_async. * feat: add a unit test for the change * Update src/strands/agent/agent.py Co-authored-by: Nick Clegg <[email protected]> * tool - executors - concurrent - remove no-op gather (#954) * feat(telemetry): updated traces to match OTEL v1.37 semantic conventions (#952) * event loop - handle model execution (#958) * feat: implement concurrent message reading for session managers (#897) Replace sequential message loading with async concurrent reading in both S3SessionManager and FileSessionManager to improve performance for long conversations. Uses asyncio.gather() with run_in_executor() to read multiple messages simultaneously while maintaining proper ordering. Resolves: #874 Co-authored-by: Vamil Gandhi <[email protected]> * hooks - before tool call event - cancel tool (#964) * fix(telemetry): removed double serialization for events (#977) * fix(litellm): map LiteLLM context-window errors to ContextWindowOverflowException (#994) * feat: add more tests and adjust invocation_state dic structure * Apply suggestion from @Unshure Co-authored-by: Nick Clegg <[email protected]> * fix: adjust **kwargs in multiagent primitives --------- Co-authored-by: Nick Clegg <[email protected]> Co-authored-by: Patrick Gray <[email protected]> Co-authored-by: poshinchen <[email protected]> Co-authored-by: Vamil Gandhi <[email protected]> Co-authored-by: Vamil Gandhi <[email protected]> Co-authored-by: ratish <[email protected]>
1 parent 419de19 commit 7fbc9dc

File tree

8 files changed

+109
-26
lines changed

8 files changed

+109
-26
lines changed

src/strands/agent/agent.py

Lines changed: 31 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import json
1414
import logging
1515
import random
16+
import warnings
1617
from concurrent.futures import ThreadPoolExecutor
1718
from typing import (
1819
Any,
@@ -374,7 +375,9 @@ def tool_names(self) -> list[str]:
374375
all_tools = self.tool_registry.get_all_tools_config()
375376
return list(all_tools.keys())
376377

377-
def __call__(self, prompt: AgentInput = None, **kwargs: Any) -> AgentResult:
378+
def __call__(
379+
self, prompt: AgentInput = None, *, invocation_state: dict[str, Any] | None = None, **kwargs: Any
380+
) -> AgentResult:
378381
"""Process a natural language prompt through the agent's event loop.
379382
380383
This method implements the conversational interface with multiple input patterns:
@@ -389,7 +392,8 @@ def __call__(self, prompt: AgentInput = None, **kwargs: Any) -> AgentResult:
389392
- list[ContentBlock]: Multi-modal content blocks
390393
- list[Message]: Complete messages with roles
391394
- None: Use existing conversation history
392-
**kwargs: Additional parameters to pass through the event loop.
395+
invocation_state: Additional parameters to pass through the event loop.
396+
**kwargs: Additional parameters to pass through the event loop.[Deprecating]
393397
394398
Returns:
395399
Result object containing:
@@ -401,13 +405,15 @@ def __call__(self, prompt: AgentInput = None, **kwargs: Any) -> AgentResult:
401405
"""
402406

403407
def execute() -> AgentResult:
404-
return asyncio.run(self.invoke_async(prompt, **kwargs))
408+
return asyncio.run(self.invoke_async(prompt, invocation_state=invocation_state, **kwargs))
405409

406410
with ThreadPoolExecutor() as executor:
407411
future = executor.submit(execute)
408412
return future.result()
409413

410-
async def invoke_async(self, prompt: AgentInput = None, **kwargs: Any) -> AgentResult:
414+
async def invoke_async(
415+
self, prompt: AgentInput = None, *, invocation_state: dict[str, Any] | None = None, **kwargs: Any
416+
) -> AgentResult:
411417
"""Process a natural language prompt through the agent's event loop.
412418
413419
This method implements the conversational interface with multiple input patterns:
@@ -422,7 +428,8 @@ async def invoke_async(self, prompt: AgentInput = None, **kwargs: Any) -> AgentR
422428
- list[ContentBlock]: Multi-modal content blocks
423429
- list[Message]: Complete messages with roles
424430
- None: Use existing conversation history
425-
**kwargs: Additional parameters to pass through the event loop.
431+
invocation_state: Additional parameters to pass through the event loop.
432+
**kwargs: Additional parameters to pass through the event loop.[Deprecating]
426433
427434
Returns:
428435
Result: object containing:
@@ -432,7 +439,7 @@ async def invoke_async(self, prompt: AgentInput = None, **kwargs: Any) -> AgentR
432439
- metrics: Performance metrics from the event loop
433440
- state: The final state of the event loop
434441
"""
435-
events = self.stream_async(prompt, **kwargs)
442+
events = self.stream_async(prompt, invocation_state=invocation_state, **kwargs)
436443
async for event in events:
437444
_ = event
438445

@@ -528,9 +535,7 @@ async def structured_output_async(self, output_model: Type[T], prompt: AgentInpu
528535
self.hooks.invoke_callbacks(AfterInvocationEvent(agent=self))
529536

530537
async def stream_async(
531-
self,
532-
prompt: AgentInput = None,
533-
**kwargs: Any,
538+
self, prompt: AgentInput = None, *, invocation_state: dict[str, Any] | None = None, **kwargs: Any
534539
) -> AsyncIterator[Any]:
535540
"""Process a natural language prompt and yield events as an async iterator.
536541
@@ -546,7 +551,8 @@ async def stream_async(
546551
- list[ContentBlock]: Multi-modal content blocks
547552
- list[Message]: Complete messages with roles
548553
- None: Use existing conversation history
549-
**kwargs: Additional parameters to pass to the event loop.
554+
invocation_state: Additional parameters to pass through the event loop.
555+
**kwargs: Additional parameters to pass to the event loop.[Deprecating]
550556
551557
Yields:
552558
An async iterator that yields events. Each event is a dictionary containing
@@ -567,7 +573,19 @@ async def stream_async(
567573
yield event["data"]
568574
```
569575
"""
570-
callback_handler = kwargs.get("callback_handler", self.callback_handler)
576+
merged_state = {}
577+
if kwargs:
578+
warnings.warn("`**kwargs` parameter is deprecating, use `invocation_state` instead.", stacklevel=2)
579+
merged_state.update(kwargs)
580+
if invocation_state is not None:
581+
merged_state["invocation_state"] = invocation_state
582+
else:
583+
if invocation_state is not None:
584+
merged_state = invocation_state
585+
586+
callback_handler = self.callback_handler
587+
if kwargs:
588+
callback_handler = kwargs.get("callback_handler", self.callback_handler)
571589

572590
# Process input and get message to add (if any)
573591
messages = self._convert_prompt_to_messages(prompt)
@@ -576,10 +594,10 @@ async def stream_async(
576594

577595
with trace_api.use_span(self.trace_span):
578596
try:
579-
events = self._run_loop(messages, invocation_state=kwargs)
597+
events = self._run_loop(messages, invocation_state=merged_state)
580598

581599
async for event in events:
582-
event.prepare(invocation_state=kwargs)
600+
event.prepare(invocation_state=merged_state)
583601

584602
if event.is_callback_event:
585603
as_dict = event.as_dict()

src/strands/multiagent/base.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
"""
55

66
import asyncio
7+
import warnings
78
from abc import ABC, abstractmethod
89
from concurrent.futures import ThreadPoolExecutor
910
from dataclasses import dataclass, field
@@ -111,8 +112,12 @@ def __call__(
111112
if invocation_state is None:
112113
invocation_state = {}
113114

115+
if kwargs:
116+
invocation_state.update(kwargs)
117+
warnings.warn("`**kwargs` parameter is deprecating, use `invocation_state` instead.", stacklevel=2)
118+
114119
def execute() -> MultiAgentResult:
115-
return asyncio.run(self.invoke_async(task, invocation_state, **kwargs))
120+
return asyncio.run(self.invoke_async(task, invocation_state))
116121

117122
with ThreadPoolExecutor() as executor:
118123
future = executor.submit(execute)

src/strands/multiagent/graph.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -572,11 +572,11 @@ async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any])
572572
elif isinstance(node.executor, Agent):
573573
if self.node_timeout is not None:
574574
agent_response = await asyncio.wait_for(
575-
node.executor.invoke_async(node_input, **invocation_state),
575+
node.executor.invoke_async(node_input, invocation_state=invocation_state),
576576
timeout=self.node_timeout,
577577
)
578578
else:
579-
agent_response = await node.executor.invoke_async(node_input, **invocation_state)
579+
agent_response = await node.executor.invoke_async(node_input, invocation_state=invocation_state)
580580

581581
# Extract metrics from agent response
582582
usage = Usage(inputTokens=0, outputTokens=0, totalTokens=0)

src/strands/multiagent/swarm.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -635,8 +635,7 @@ async def _execute_node(
635635
# Execute node
636636
result = None
637637
node.reset_executor_state()
638-
# Unpacking since this is the agent class. Other executors should not unpack
639-
result = await node.executor.invoke_async(node_input, **invocation_state)
638+
result = await node.executor.invoke_async(node_input, invocation_state=invocation_state)
640639

641640
execution_time = round((time.time() - start_time) * 1000)
642641

tests/strands/agent/test_agent.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import os
55
import textwrap
66
import unittest.mock
7+
import warnings
78
from uuid import uuid4
89

910
import pytest
@@ -1877,3 +1878,58 @@ def test_tool(action: str) -> str:
18771878
assert '"action": "test_value"' in tool_call_text
18781879
assert '"agent"' not in tool_call_text
18791880
assert '"extra_param"' not in tool_call_text
1881+
1882+
1883+
def test_agent__call__handles_none_invocation_state(mock_model, agent):
1884+
"""Test that agent handles None invocation_state without AttributeError."""
1885+
mock_model.mock_stream.return_value = [
1886+
{"contentBlockDelta": {"delta": {"text": "test response"}}},
1887+
{"contentBlockStop": {}},
1888+
]
1889+
1890+
# This should not raise AttributeError: 'NoneType' object has no attribute 'get'
1891+
result = agent("test", invocation_state=None)
1892+
1893+
assert result.message["content"][0]["text"] == "test response"
1894+
assert result.stop_reason == "end_turn"
1895+
1896+
1897+
def test_agent__call__invocation_state_with_kwargs_deprecation_warning(agent, mock_event_loop_cycle):
1898+
"""Test that kwargs trigger deprecation warning and are merged correctly with invocation_state."""
1899+
1900+
async def check_invocation_state(**kwargs):
1901+
invocation_state = kwargs["invocation_state"]
1902+
# Should have nested structure when both invocation_state and kwargs are provided
1903+
assert invocation_state["invocation_state"] == {"my": "state"}
1904+
assert invocation_state["other_kwarg"] == "foobar"
1905+
yield EventLoopStopEvent("stop", {"role": "assistant", "content": [{"text": "Response"}]}, {}, {})
1906+
1907+
mock_event_loop_cycle.side_effect = check_invocation_state
1908+
1909+
with warnings.catch_warnings(record=True) as captured_warnings:
1910+
warnings.simplefilter("always")
1911+
agent("hello!", invocation_state={"my": "state"}, other_kwarg="foobar")
1912+
1913+
# Verify deprecation warning was issued
1914+
assert len(captured_warnings) == 1
1915+
assert issubclass(captured_warnings[0].category, UserWarning)
1916+
assert "`**kwargs` parameter is deprecating, use `invocation_state` instead." in str(captured_warnings[0].message)
1917+
1918+
1919+
def test_agent__call__invocation_state_only_no_warning(agent, mock_event_loop_cycle):
1920+
"""Test that using only invocation_state does not trigger warning and passes state directly."""
1921+
1922+
async def check_invocation_state(**kwargs):
1923+
invocation_state = kwargs["invocation_state"]
1924+
1925+
assert invocation_state["my"] == "state"
1926+
assert "agent" in invocation_state
1927+
yield EventLoopStopEvent("stop", {"role": "assistant", "content": [{"text": "Response"}]}, {}, {})
1928+
1929+
mock_event_loop_cycle.side_effect = check_invocation_state
1930+
1931+
with warnings.catch_warnings(record=True) as captured_warnings:
1932+
warnings.simplefilter("always")
1933+
agent("hello!", invocation_state={"my": "state"})
1934+
1935+
assert len(captured_warnings) == 0

tests/strands/multiagent/test_base.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -159,17 +159,18 @@ async def invoke_async(self, task, invocation_state, **kwargs):
159159
self.invoke_async_called = True
160160
self.received_task = task
161161
self.received_kwargs = kwargs
162+
self.received_invocation_state = invocation_state
162163
return MultiAgentResult(
163164
status=Status.COMPLETED, results={"test": NodeResult(result=Exception("test"), status=Status.COMPLETED)}
164165
)
165166

166167
agent = TestMultiAgent()
167168

168169
# Test with string task
169-
result = agent("test task", param1="value1", param2="value2")
170+
result = agent("test task", param1="value1", param2="value2", invocation_state={"value3": "value4"})
170171

171172
assert agent.invoke_async_called
172173
assert agent.received_task == "test task"
173-
assert agent.received_kwargs == {"param1": "value1", "param2": "value2"}
174+
assert agent.received_invocation_state == {"param1": "value1", "param2": "value2", "value3": "value4"}
174175
assert isinstance(result, MultiAgentResult)
175176
assert result.status == Status.COMPLETED

tests/strands/multiagent/test_graph.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -310,7 +310,7 @@ async def test_graph_edge_cases(mock_strands_tracer, mock_use_span):
310310
result = await graph.invoke_async([{"text": "Original task"}])
311311

312312
# Verify entry node was called with original task
313-
entry_agent.invoke_async.assert_called_once_with([{"text": "Original task"}])
313+
entry_agent.invoke_async.assert_called_once_with([{"text": "Original task"}], invocation_state={})
314314
assert result.status == Status.COMPLETED
315315
mock_strands_tracer.start_multiagent_span.assert_called()
316316
mock_use_span.assert_called_once()
@@ -906,7 +906,7 @@ def __init__(self, name):
906906
self._session_manager = None
907907
self.hooks = HookRegistry()
908908

909-
async def invoke_async(self, input_data):
909+
async def invoke_async(self, input_data, invocation_state=None):
910910
# Increment execution count in state
911911
count = self.state.get("execution_count") or 0
912912
self.state.set("execution_count", count + 1)
@@ -1300,7 +1300,9 @@ async def test_graph_kwargs_passing_agent(mock_strands_tracer, mock_use_span):
13001300
test_invocation_state = {"custom_param": "test_value", "another_param": 42}
13011301
result = await graph.invoke_async("Test kwargs passing", test_invocation_state)
13021302

1303-
kwargs_agent.invoke_async.assert_called_once_with([{"text": "Test kwargs passing"}], **test_invocation_state)
1303+
kwargs_agent.invoke_async.assert_called_once_with(
1304+
[{"text": "Test kwargs passing"}], invocation_state=test_invocation_state
1305+
)
13041306
assert result.status == Status.COMPLETED
13051307

13061308

@@ -1335,5 +1337,7 @@ def test_graph_kwargs_passing_sync(mock_strands_tracer, mock_use_span):
13351337
test_invocation_state = {"custom_param": "test_value", "another_param": 42}
13361338
result = graph("Test kwargs passing sync", test_invocation_state)
13371339

1338-
kwargs_agent.invoke_async.assert_called_once_with([{"text": "Test kwargs passing sync"}], **test_invocation_state)
1340+
kwargs_agent.invoke_async.assert_called_once_with(
1341+
[{"text": "Test kwargs passing sync"}], invocation_state=test_invocation_state
1342+
)
13391343
assert result.status == Status.COMPLETED

tests/strands/multiagent/test_swarm.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -558,7 +558,7 @@ async def test_swarm_kwargs_passing(mock_strands_tracer, mock_use_span):
558558
test_kwargs = {"custom_param": "test_value", "another_param": 42}
559559
result = await swarm.invoke_async("Test kwargs passing", test_kwargs)
560560

561-
assert kwargs_agent.invoke_async.call_args.kwargs == test_kwargs
561+
assert kwargs_agent.invoke_async.call_args.kwargs == {"invocation_state": test_kwargs}
562562
assert result.status == Status.COMPLETED
563563

564564

@@ -572,5 +572,5 @@ def test_swarm_kwargs_passing_sync(mock_strands_tracer, mock_use_span):
572572
test_kwargs = {"custom_param": "test_value", "another_param": 42}
573573
result = swarm("Test kwargs passing sync", test_kwargs)
574574

575-
assert kwargs_agent.invoke_async.call_args.kwargs == test_kwargs
575+
assert kwargs_agent.invoke_async.call_args.kwargs == {"invocation_state": test_kwargs}
576576
assert result.status == Status.COMPLETED

0 commit comments

Comments
 (0)