Skip to content

Commit 97a062e

Browse files
committed
tool executors
1 parent 72709cf commit 97a062e

File tree

22 files changed

+997
-1024
lines changed

22 files changed

+997
-1024
lines changed

src/strands/agent/agent.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@
1919
from opentelemetry import trace as trace_api
2020
from pydantic import BaseModel
2121

22-
from ..event_loop.event_loop import event_loop_cycle, run_tool
22+
from ..event_loop.event_loop import event_loop_cycle
23+
from ..experimental.tools.executors import Executor as ToolExecutor
2324
from ..handlers.callback_handler import PrintingCallbackHandler, null_callback_handler
2425
from ..hooks import (
2526
AfterInvocationEvent,
@@ -34,6 +35,7 @@
3435
from ..session.session_manager import SessionManager
3536
from ..telemetry.metrics import EventLoopMetrics
3637
from ..telemetry.tracer import get_tracer
38+
from ..tools import executors as tool_executors
3739
from ..tools.registry import ToolRegistry
3840
from ..tools.watcher import ToolWatcher
3941
from ..types.content import ContentBlock, Message, Messages
@@ -135,10 +137,12 @@ def caller(
135137
"name": normalized_name,
136138
"input": kwargs.copy(),
137139
}
140+
invocation_state = {**kwargs, "skip_tracing": True, "tool_results": []}
141+
142+
tool_executor = tool_executors.sequential.Executor()
138143

139144
async def acall() -> ToolResult:
140-
# Pass kwargs as invocation_state
141-
async for event in run_tool(self._agent, tool_use, kwargs):
145+
async for event in tool_executor.stream(self._agent, tool_use, invocation_state):
142146
_ = event
143147

144148
return cast(ToolResult, event)
@@ -207,6 +211,7 @@ def __init__(
207211
state: Optional[Union[AgentState, dict]] = None,
208212
hooks: Optional[list[HookProvider]] = None,
209213
session_manager: Optional[SessionManager] = None,
214+
tool_executor: Optional[ToolExecutor] = None,
210215
):
211216
"""Initialize the Agent with the specified configuration.
212217
@@ -249,6 +254,7 @@ def __init__(
249254
Defaults to None.
250255
session_manager: Manager for handling agent sessions including conversation history and state.
251256
If provided, enables session-based persistence and state management.
257+
tool_executor: Definition of tool execution stragety (e.g., sequential, concurrent, etc.).
252258
"""
253259
self.model = BedrockModel() if not model else BedrockModel(model_id=model) if isinstance(model, str) else model
254260
self.messages = messages if messages is not None else []
@@ -320,6 +326,8 @@ def __init__(
320326
if self._session_manager:
321327
self.hooks.add_hook(self._session_manager)
322328

329+
self.tool_executor = tool_executor or tool_executors.concurrent.Executor()
330+
323331
if hooks:
324332
for hook in hooks:
325333
self.hooks.add_hook(hook)

src/strands/event_loop/event_loop.py

Lines changed: 19 additions & 142 deletions
Original file line numberDiff line numberDiff line change
@@ -11,22 +11,20 @@
1111
import logging
1212
import time
1313
import uuid
14-
from typing import TYPE_CHECKING, Any, AsyncGenerator, cast
14+
from typing import TYPE_CHECKING, Any, AsyncGenerator
1515

1616
from opentelemetry import trace as trace_api
1717

1818
from ..experimental.hooks import (
1919
AfterModelInvocationEvent,
20-
AfterToolInvocationEvent,
2120
BeforeModelInvocationEvent,
22-
BeforeToolInvocationEvent,
2321
)
2422
from ..hooks import (
2523
MessageAddedEvent,
2624
)
2725
from ..telemetry.metrics import Trace
2826
from ..telemetry.tracer import get_tracer
29-
from ..tools.executor import run_tools, validate_and_prepare_tools
27+
from ..tools.validator import validate_and_prepare_tools
3028
from ..types.content import Message
3129
from ..types.exceptions import (
3230
ContextWindowOverflowException,
@@ -35,7 +33,7 @@
3533
ModelThrottledException,
3634
)
3735
from ..types.streaming import Metrics, StopReason
38-
from ..types.tools import ToolChoice, ToolChoiceAuto, ToolConfig, ToolGenerator, ToolResult, ToolUse
36+
from ..types.tools import ToolResult, ToolUse
3937
from ._recover_message_on_max_tokens_reached import recover_message_on_max_tokens_reached
4038
from .streaming import stream_messages
4139

@@ -212,7 +210,7 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) ->
212210
if stop_reason == "max_tokens":
213211
"""
214212
Handle max_tokens limit reached by the model.
215-
213+
216214
When the model reaches its maximum token limit, this represents a potentially unrecoverable
217215
state where the model's response was truncated. By default, Strands fails hard with an
218216
MaxTokensReachedException to maintain consistency with other failure types.
@@ -306,122 +304,6 @@ async def recurse_event_loop(agent: "Agent", invocation_state: dict[str, Any]) -
306304
recursive_trace.end()
307305

308306

309-
async def run_tool(agent: "Agent", tool_use: ToolUse, invocation_state: dict[str, Any]) -> ToolGenerator:
310-
"""Process a tool invocation.
311-
312-
Looks up the tool in the registry and streams it with the provided parameters.
313-
314-
Args:
315-
agent: The agent for which the tool is being executed.
316-
tool_use: The tool object to process, containing name and parameters.
317-
invocation_state: Context for the tool invocation, including agent state.
318-
319-
Yields:
320-
Tool events with the last being the tool result.
321-
"""
322-
logger.debug("tool_use=<%s> | streaming", tool_use)
323-
tool_name = tool_use["name"]
324-
325-
# Get the tool info
326-
tool_info = agent.tool_registry.dynamic_tools.get(tool_name)
327-
tool_func = tool_info if tool_info is not None else agent.tool_registry.registry.get(tool_name)
328-
329-
# Add standard arguments to invocation_state for Python tools
330-
invocation_state.update(
331-
{
332-
"model": agent.model,
333-
"system_prompt": agent.system_prompt,
334-
"messages": agent.messages,
335-
"tool_config": ToolConfig( # for backwards compatability
336-
tools=[{"toolSpec": tool_spec} for tool_spec in agent.tool_registry.get_all_tool_specs()],
337-
toolChoice=cast(ToolChoice, {"auto": ToolChoiceAuto()}),
338-
),
339-
}
340-
)
341-
342-
before_event = agent.hooks.invoke_callbacks(
343-
BeforeToolInvocationEvent(
344-
agent=agent,
345-
selected_tool=tool_func,
346-
tool_use=tool_use,
347-
invocation_state=invocation_state,
348-
)
349-
)
350-
351-
try:
352-
selected_tool = before_event.selected_tool
353-
tool_use = before_event.tool_use
354-
invocation_state = before_event.invocation_state # Get potentially modified invocation_state from hook
355-
356-
# Check if tool exists
357-
if not selected_tool:
358-
if tool_func == selected_tool:
359-
logger.error(
360-
"tool_name=<%s>, available_tools=<%s> | tool not found in registry",
361-
tool_name,
362-
list(agent.tool_registry.registry.keys()),
363-
)
364-
else:
365-
logger.debug(
366-
"tool_name=<%s>, tool_use_id=<%s> | a hook resulted in a non-existing tool call",
367-
tool_name,
368-
str(tool_use.get("toolUseId")),
369-
)
370-
371-
result: ToolResult = {
372-
"toolUseId": str(tool_use.get("toolUseId")),
373-
"status": "error",
374-
"content": [{"text": f"Unknown tool: {tool_name}"}],
375-
}
376-
# for every Before event call, we need to have an AfterEvent call
377-
after_event = agent.hooks.invoke_callbacks(
378-
AfterToolInvocationEvent(
379-
agent=agent,
380-
selected_tool=selected_tool,
381-
tool_use=tool_use,
382-
invocation_state=invocation_state, # Keep as invocation_state for backward compatibility with hooks
383-
result=result,
384-
)
385-
)
386-
yield after_event.result
387-
return
388-
389-
async for event in selected_tool.stream(tool_use, invocation_state):
390-
yield event
391-
392-
result = event
393-
394-
after_event = agent.hooks.invoke_callbacks(
395-
AfterToolInvocationEvent(
396-
agent=agent,
397-
selected_tool=selected_tool,
398-
tool_use=tool_use,
399-
invocation_state=invocation_state, # Keep as invocation_state for backward compatibility with hooks
400-
result=result,
401-
)
402-
)
403-
yield after_event.result
404-
405-
except Exception as e:
406-
logger.exception("tool_name=<%s> | failed to process tool", tool_name)
407-
error_result: ToolResult = {
408-
"toolUseId": str(tool_use.get("toolUseId")),
409-
"status": "error",
410-
"content": [{"text": f"Error: {str(e)}"}],
411-
}
412-
after_event = agent.hooks.invoke_callbacks(
413-
AfterToolInvocationEvent(
414-
agent=agent,
415-
selected_tool=selected_tool,
416-
tool_use=tool_use,
417-
invocation_state=invocation_state, # Keep as invocation_state for backward compatibility with hooks
418-
result=error_result,
419-
exception=e,
420-
)
421-
)
422-
yield after_event.result
423-
424-
425307
async def _handle_tool_execution(
426308
stop_reason: StopReason,
427309
message: Message,
@@ -431,18 +313,12 @@ async def _handle_tool_execution(
431313
cycle_start_time: float,
432314
invocation_state: dict[str, Any],
433315
) -> AsyncGenerator[dict[str, Any], None]:
434-
tool_uses: list[ToolUse] = []
435-
tool_results: list[ToolResult] = []
436-
invalid_tool_use_ids: list[str] = []
437-
438-
"""
439-
Handles the execution of tools requested by the model during an event loop cycle.
316+
"""Handles the execution of tools requested by the model during an event loop cycle.
440317
441318
Args:
442319
stop_reason: The reason the model stopped generating.
443320
message: The message from the model that may contain tool use requests.
444-
event_loop_metrics: Metrics tracking object for the event loop.
445-
event_loop_parent_span: Span for the parent of this event loop.
321+
agent: Agent for which tools are being executed.
446322
cycle_trace: Trace object for the current event loop cycle.
447323
cycle_span: Span object for tracing the cycle (type may vary).
448324
cycle_start_time: Start time of the current cycle.
@@ -456,24 +332,25 @@ async def _handle_tool_execution(
456332
- The updated event loop metrics,
457333
- The updated request state.
458334
"""
459-
validate_and_prepare_tools(message, tool_uses, tool_results, invalid_tool_use_ids)
335+
tool_uses: list[ToolUse] = []
336+
tool_results: list[ToolResult] = []
337+
invalid_tool_use_ids: list[str] = []
460338

339+
validate_and_prepare_tools(message, tool_uses, tool_results, invalid_tool_use_ids)
340+
tool_uses = [tool_use for tool_use in tool_uses if tool_use.get("toolUseId") not in invalid_tool_use_ids]
461341
if not tool_uses:
462342
yield {"stop": (stop_reason, message, agent.event_loop_metrics, invocation_state["request_state"])}
463343
return
464344

465-
def tool_handler(tool_use: ToolUse) -> ToolGenerator:
466-
return run_tool(agent, tool_use, invocation_state)
467-
468-
tool_events = run_tools(
469-
handler=tool_handler,
470-
tool_uses=tool_uses,
471-
event_loop_metrics=agent.event_loop_metrics,
472-
invalid_tool_use_ids=invalid_tool_use_ids,
473-
tool_results=tool_results,
474-
cycle_trace=cycle_trace,
475-
parent_span=cycle_span,
345+
invocation_state.update(
346+
{
347+
"event_loop_cycle_trace": cycle_trace,
348+
"event_loop_cycle_span": cycle_span,
349+
"tool_results": tool_results,
350+
}
476351
)
352+
353+
tool_events = agent.tool_executor.execute(agent, tool_uses, invocation_state)
477354
async for tool_event in tool_events:
478355
yield tool_event
479356

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
"""Experimental tool features."""
2+
3+
from . import executors
4+
5+
__all__ = ["executors"]
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
"""Experimental tool executors."""
2+
3+
from . import executor
4+
from .executor import Executor
5+
6+
__all__ = ["executor", "Executor"]

0 commit comments

Comments
 (0)