Skip to content

Commit c18ef93

Browse files
authored
tool executors (#658)
1 parent e4879e1 commit c18ef93

File tree

18 files changed

+985
-1020
lines changed

18 files changed

+985
-1020
lines changed

src/strands/agent/agent.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from pydantic import BaseModel
2121

2222
from .. import _identifier
23-
from ..event_loop.event_loop import event_loop_cycle, run_tool
23+
from ..event_loop.event_loop import event_loop_cycle
2424
from ..handlers.callback_handler import PrintingCallbackHandler, null_callback_handler
2525
from ..hooks import (
2626
AfterInvocationEvent,
@@ -35,6 +35,8 @@
3535
from ..session.session_manager import SessionManager
3636
from ..telemetry.metrics import EventLoopMetrics
3737
from ..telemetry.tracer import get_tracer, serialize
38+
from ..tools.executors import ConcurrentToolExecutor
39+
from ..tools.executors._executor import ToolExecutor
3840
from ..tools.registry import ToolRegistry
3941
from ..tools.watcher import ToolWatcher
4042
from ..types.content import ContentBlock, Message, Messages
@@ -136,13 +138,14 @@ def caller(
136138
"name": normalized_name,
137139
"input": kwargs.copy(),
138140
}
141+
tool_results: list[ToolResult] = []
142+
invocation_state = kwargs
139143

140144
async def acall() -> ToolResult:
141-
# Pass kwargs as invocation_state
142-
async for event in run_tool(self._agent, tool_use, kwargs):
145+
async for event in ToolExecutor._stream(self._agent, tool_use, tool_results, invocation_state):
143146
_ = event
144147

145-
return cast(ToolResult, event)
148+
return tool_results[0]
146149

147150
def tcall() -> ToolResult:
148151
return asyncio.run(acall())
@@ -208,6 +211,7 @@ def __init__(
208211
state: Optional[Union[AgentState, dict]] = None,
209212
hooks: Optional[list[HookProvider]] = None,
210213
session_manager: Optional[SessionManager] = None,
214+
tool_executor: Optional[ToolExecutor] = None,
211215
):
212216
"""Initialize the Agent with the specified configuration.
213217
@@ -250,6 +254,7 @@ def __init__(
250254
Defaults to None.
251255
session_manager: Manager for handling agent sessions including conversation history and state.
252256
If provided, enables session-based persistence and state management.
257+
tool_executor: Definition of tool execution stragety (e.g., sequential, concurrent, etc.).
253258
254259
Raises:
255260
ValueError: If agent id contains path separators.
@@ -324,6 +329,8 @@ def __init__(
324329
if self._session_manager:
325330
self.hooks.add_hook(self._session_manager)
326331

332+
self.tool_executor = tool_executor or ConcurrentToolExecutor()
333+
327334
if hooks:
328335
for hook in hooks:
329336
self.hooks.add_hook(hook)

src/strands/event_loop/event_loop.py

Lines changed: 13 additions & 142 deletions
Original file line numberDiff line numberDiff line change
@@ -11,22 +11,20 @@
1111
import logging
1212
import time
1313
import uuid
14-
from typing import TYPE_CHECKING, Any, AsyncGenerator, cast
14+
from typing import TYPE_CHECKING, Any, AsyncGenerator
1515

1616
from opentelemetry import trace as trace_api
1717

1818
from ..experimental.hooks import (
1919
AfterModelInvocationEvent,
20-
AfterToolInvocationEvent,
2120
BeforeModelInvocationEvent,
22-
BeforeToolInvocationEvent,
2321
)
2422
from ..hooks import (
2523
MessageAddedEvent,
2624
)
2725
from ..telemetry.metrics import Trace
2826
from ..telemetry.tracer import get_tracer
29-
from ..tools.executor import run_tools, validate_and_prepare_tools
27+
from ..tools._validator import validate_and_prepare_tools
3028
from ..types.content import Message
3129
from ..types.exceptions import (
3230
ContextWindowOverflowException,
@@ -35,7 +33,7 @@
3533
ModelThrottledException,
3634
)
3735
from ..types.streaming import Metrics, StopReason
38-
from ..types.tools import ToolChoice, ToolChoiceAuto, ToolConfig, ToolGenerator, ToolResult, ToolUse
36+
from ..types.tools import ToolResult, ToolUse
3937
from ._recover_message_on_max_tokens_reached import recover_message_on_max_tokens_reached
4038
from .streaming import stream_messages
4139

@@ -212,7 +210,7 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) ->
212210
if stop_reason == "max_tokens":
213211
"""
214212
Handle max_tokens limit reached by the model.
215-
213+
216214
When the model reaches its maximum token limit, this represents a potentially unrecoverable
217215
state where the model's response was truncated. By default, Strands fails hard with an
218216
MaxTokensReachedException to maintain consistency with other failure types.
@@ -306,122 +304,6 @@ async def recurse_event_loop(agent: "Agent", invocation_state: dict[str, Any]) -
306304
recursive_trace.end()
307305

308306

309-
async def run_tool(agent: "Agent", tool_use: ToolUse, invocation_state: dict[str, Any]) -> ToolGenerator:
310-
"""Process a tool invocation.
311-
312-
Looks up the tool in the registry and streams it with the provided parameters.
313-
314-
Args:
315-
agent: The agent for which the tool is being executed.
316-
tool_use: The tool object to process, containing name and parameters.
317-
invocation_state: Context for the tool invocation, including agent state.
318-
319-
Yields:
320-
Tool events with the last being the tool result.
321-
"""
322-
logger.debug("tool_use=<%s> | streaming", tool_use)
323-
tool_name = tool_use["name"]
324-
325-
# Get the tool info
326-
tool_info = agent.tool_registry.dynamic_tools.get(tool_name)
327-
tool_func = tool_info if tool_info is not None else agent.tool_registry.registry.get(tool_name)
328-
329-
# Add standard arguments to invocation_state for Python tools
330-
invocation_state.update(
331-
{
332-
"model": agent.model,
333-
"system_prompt": agent.system_prompt,
334-
"messages": agent.messages,
335-
"tool_config": ToolConfig( # for backwards compatability
336-
tools=[{"toolSpec": tool_spec} for tool_spec in agent.tool_registry.get_all_tool_specs()],
337-
toolChoice=cast(ToolChoice, {"auto": ToolChoiceAuto()}),
338-
),
339-
}
340-
)
341-
342-
before_event = agent.hooks.invoke_callbacks(
343-
BeforeToolInvocationEvent(
344-
agent=agent,
345-
selected_tool=tool_func,
346-
tool_use=tool_use,
347-
invocation_state=invocation_state,
348-
)
349-
)
350-
351-
try:
352-
selected_tool = before_event.selected_tool
353-
tool_use = before_event.tool_use
354-
invocation_state = before_event.invocation_state # Get potentially modified invocation_state from hook
355-
356-
# Check if tool exists
357-
if not selected_tool:
358-
if tool_func == selected_tool:
359-
logger.error(
360-
"tool_name=<%s>, available_tools=<%s> | tool not found in registry",
361-
tool_name,
362-
list(agent.tool_registry.registry.keys()),
363-
)
364-
else:
365-
logger.debug(
366-
"tool_name=<%s>, tool_use_id=<%s> | a hook resulted in a non-existing tool call",
367-
tool_name,
368-
str(tool_use.get("toolUseId")),
369-
)
370-
371-
result: ToolResult = {
372-
"toolUseId": str(tool_use.get("toolUseId")),
373-
"status": "error",
374-
"content": [{"text": f"Unknown tool: {tool_name}"}],
375-
}
376-
# for every Before event call, we need to have an AfterEvent call
377-
after_event = agent.hooks.invoke_callbacks(
378-
AfterToolInvocationEvent(
379-
agent=agent,
380-
selected_tool=selected_tool,
381-
tool_use=tool_use,
382-
invocation_state=invocation_state, # Keep as invocation_state for backward compatibility with hooks
383-
result=result,
384-
)
385-
)
386-
yield after_event.result
387-
return
388-
389-
async for event in selected_tool.stream(tool_use, invocation_state):
390-
yield event
391-
392-
result = event
393-
394-
after_event = agent.hooks.invoke_callbacks(
395-
AfterToolInvocationEvent(
396-
agent=agent,
397-
selected_tool=selected_tool,
398-
tool_use=tool_use,
399-
invocation_state=invocation_state, # Keep as invocation_state for backward compatibility with hooks
400-
result=result,
401-
)
402-
)
403-
yield after_event.result
404-
405-
except Exception as e:
406-
logger.exception("tool_name=<%s> | failed to process tool", tool_name)
407-
error_result: ToolResult = {
408-
"toolUseId": str(tool_use.get("toolUseId")),
409-
"status": "error",
410-
"content": [{"text": f"Error: {str(e)}"}],
411-
}
412-
after_event = agent.hooks.invoke_callbacks(
413-
AfterToolInvocationEvent(
414-
agent=agent,
415-
selected_tool=selected_tool,
416-
tool_use=tool_use,
417-
invocation_state=invocation_state, # Keep as invocation_state for backward compatibility with hooks
418-
result=error_result,
419-
exception=e,
420-
)
421-
)
422-
yield after_event.result
423-
424-
425307
async def _handle_tool_execution(
426308
stop_reason: StopReason,
427309
message: Message,
@@ -431,18 +313,12 @@ async def _handle_tool_execution(
431313
cycle_start_time: float,
432314
invocation_state: dict[str, Any],
433315
) -> AsyncGenerator[dict[str, Any], None]:
434-
tool_uses: list[ToolUse] = []
435-
tool_results: list[ToolResult] = []
436-
invalid_tool_use_ids: list[str] = []
437-
438-
"""
439-
Handles the execution of tools requested by the model during an event loop cycle.
316+
"""Handles the execution of tools requested by the model during an event loop cycle.
440317
441318
Args:
442319
stop_reason: The reason the model stopped generating.
443320
message: The message from the model that may contain tool use requests.
444-
event_loop_metrics: Metrics tracking object for the event loop.
445-
event_loop_parent_span: Span for the parent of this event loop.
321+
agent: Agent for which tools are being executed.
446322
cycle_trace: Trace object for the current event loop cycle.
447323
cycle_span: Span object for tracing the cycle (type may vary).
448324
cycle_start_time: Start time of the current cycle.
@@ -456,23 +332,18 @@ async def _handle_tool_execution(
456332
- The updated event loop metrics,
457333
- The updated request state.
458334
"""
459-
validate_and_prepare_tools(message, tool_uses, tool_results, invalid_tool_use_ids)
335+
tool_uses: list[ToolUse] = []
336+
tool_results: list[ToolResult] = []
337+
invalid_tool_use_ids: list[str] = []
460338

339+
validate_and_prepare_tools(message, tool_uses, tool_results, invalid_tool_use_ids)
340+
tool_uses = [tool_use for tool_use in tool_uses if tool_use.get("toolUseId") not in invalid_tool_use_ids]
461341
if not tool_uses:
462342
yield {"stop": (stop_reason, message, agent.event_loop_metrics, invocation_state["request_state"])}
463343
return
464344

465-
def tool_handler(tool_use: ToolUse) -> ToolGenerator:
466-
return run_tool(agent, tool_use, invocation_state)
467-
468-
tool_events = run_tools(
469-
handler=tool_handler,
470-
tool_uses=tool_uses,
471-
event_loop_metrics=agent.event_loop_metrics,
472-
invalid_tool_use_ids=invalid_tool_use_ids,
473-
tool_results=tool_results,
474-
cycle_trace=cycle_trace,
475-
parent_span=cycle_span,
345+
tool_events = agent.tool_executor._execute(
346+
agent, tool_uses, tool_results, cycle_trace, cycle_span, invocation_state
476347
)
477348
async for tool_event in tool_events:
478349
yield tool_event

src/strands/tools/_validator.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
"""Tool validation utilities."""
2+
3+
from ..tools.tools import InvalidToolUseNameException, validate_tool_use
4+
from ..types.content import Message
5+
from ..types.tools import ToolResult, ToolUse
6+
7+
8+
def validate_and_prepare_tools(
9+
message: Message,
10+
tool_uses: list[ToolUse],
11+
tool_results: list[ToolResult],
12+
invalid_tool_use_ids: list[str],
13+
) -> None:
14+
"""Validate tool uses and prepare them for execution.
15+
16+
Args:
17+
message: Current message.
18+
tool_uses: List to populate with tool uses.
19+
tool_results: List to populate with tool results for invalid tools.
20+
invalid_tool_use_ids: List to populate with invalid tool use IDs.
21+
"""
22+
# Extract tool uses from message
23+
for content in message["content"]:
24+
if isinstance(content, dict) and "toolUse" in content:
25+
tool_uses.append(content["toolUse"])
26+
27+
# Validate tool uses
28+
# Avoid modifying original `tool_uses` variable during iteration
29+
tool_uses_copy = tool_uses.copy()
30+
for tool in tool_uses_copy:
31+
try:
32+
validate_tool_use(tool)
33+
except InvalidToolUseNameException as e:
34+
# Replace the invalid toolUse name and return invalid name error as ToolResult to the LLM as context
35+
tool_uses.remove(tool)
36+
tool["name"] = "INVALID_TOOL_NAME"
37+
invalid_tool_use_ids.append(tool["toolUseId"])
38+
tool_uses.append(tool)
39+
tool_results.append(
40+
{
41+
"toolUseId": tool["toolUseId"],
42+
"status": "error",
43+
"content": [{"text": f"Error: {str(e)}"}],
44+
}
45+
)

0 commit comments

Comments
 (0)