-
Notifications
You must be signed in to change notification settings - Fork 314
tool executors #658
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
tool executors #658
Changes from all commits
97a062e
8a86a29
45198c3
c95154f
01327c2
7cc16e6
4398086
b3249cc
72e4238
946035c
5740118
56b51de
43af1ae
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -20,7 +20,7 @@ | |
from pydantic import BaseModel | ||
|
||
from .. import _identifier | ||
from ..event_loop.event_loop import event_loop_cycle, run_tool | ||
from ..event_loop.event_loop import event_loop_cycle | ||
from ..handlers.callback_handler import PrintingCallbackHandler, null_callback_handler | ||
from ..hooks import ( | ||
AfterInvocationEvent, | ||
|
@@ -35,6 +35,8 @@ | |
from ..session.session_manager import SessionManager | ||
from ..telemetry.metrics import EventLoopMetrics | ||
from ..telemetry.tracer import get_tracer, serialize | ||
from ..tools.executors import ConcurrentToolExecutor | ||
from ..tools.executors._executor import ToolExecutor | ||
from ..tools.registry import ToolRegistry | ||
from ..tools.watcher import ToolWatcher | ||
from ..types.content import ContentBlock, Message, Messages | ||
|
@@ -136,13 +138,14 @@ def caller( | |
"name": normalized_name, | ||
"input": kwargs.copy(), | ||
} | ||
tool_results: list[ToolResult] = [] | ||
invocation_state = kwargs | ||
|
||
async def acall() -> ToolResult: | ||
# Pass kwargs as invocation_state | ||
async for event in run_tool(self._agent, tool_use, kwargs): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Moved |
||
async for event in ToolExecutor._stream(self._agent, tool_use, tool_results, invocation_state): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Per discussion, member methods of the executor classes are also marked private to discourage customer use. We want to first cement the interface before exposing. |
||
_ = event | ||
|
||
return cast(ToolResult, event) | ||
return tool_results[0] | ||
|
||
def tcall() -> ToolResult: | ||
return asyncio.run(acall()) | ||
|
@@ -208,6 +211,7 @@ def __init__( | |
state: Optional[Union[AgentState, dict]] = None, | ||
hooks: Optional[list[HookProvider]] = None, | ||
session_manager: Optional[SessionManager] = None, | ||
tool_executor: Optional[ToolExecutor] = None, | ||
): | ||
"""Initialize the Agent with the specified configuration. | ||
|
||
|
@@ -250,6 +254,7 @@ def __init__( | |
Defaults to None. | ||
session_manager: Manager for handling agent sessions including conversation history and state. | ||
If provided, enables session-based persistence and state management. | ||
tool_executor: Definition of tool execution stragety (e.g., sequential, concurrent, etc.). | ||
|
||
Raises: | ||
ValueError: If agent id contains path separators. | ||
|
@@ -324,6 +329,8 @@ def __init__( | |
if self._session_manager: | ||
self.hooks.add_hook(self._session_manager) | ||
|
||
self.tool_executor = tool_executor or ConcurrentToolExecutor() | ||
pgrayy marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
if hooks: | ||
for hook in hooks: | ||
self.hooks.add_hook(hook) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -11,22 +11,20 @@ | |
import logging | ||
import time | ||
import uuid | ||
from typing import TYPE_CHECKING, Any, AsyncGenerator, cast | ||
from typing import TYPE_CHECKING, Any, AsyncGenerator | ||
|
||
from opentelemetry import trace as trace_api | ||
|
||
from ..experimental.hooks import ( | ||
AfterModelInvocationEvent, | ||
AfterToolInvocationEvent, | ||
BeforeModelInvocationEvent, | ||
BeforeToolInvocationEvent, | ||
) | ||
from ..hooks import ( | ||
MessageAddedEvent, | ||
) | ||
from ..telemetry.metrics import Trace | ||
from ..telemetry.tracer import get_tracer | ||
from ..tools.executor import run_tools, validate_and_prepare_tools | ||
from ..tools._validator import validate_and_prepare_tools | ||
from ..types.content import Message | ||
from ..types.exceptions import ( | ||
ContextWindowOverflowException, | ||
|
@@ -35,7 +33,7 @@ | |
ModelThrottledException, | ||
) | ||
from ..types.streaming import Metrics, StopReason | ||
from ..types.tools import ToolChoice, ToolChoiceAuto, ToolConfig, ToolGenerator, ToolResult, ToolUse | ||
from ..types.tools import ToolResult, ToolUse | ||
from ._recover_message_on_max_tokens_reached import recover_message_on_max_tokens_reached | ||
from .streaming import stream_messages | ||
|
||
|
@@ -212,7 +210,7 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) -> | |
if stop_reason == "max_tokens": | ||
""" | ||
Handle max_tokens limit reached by the model. | ||
|
||
When the model reaches its maximum token limit, this represents a potentially unrecoverable | ||
state where the model's response was truncated. By default, Strands fails hard with an | ||
MaxTokensReachedException to maintain consistency with other failure types. | ||
|
@@ -306,122 +304,6 @@ async def recurse_event_loop(agent: "Agent", invocation_state: dict[str, Any]) - | |
recursive_trace.end() | ||
|
||
|
||
async def run_tool(agent: "Agent", tool_use: ToolUse, invocation_state: dict[str, Any]) -> ToolGenerator: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Moved logic into |
||
"""Process a tool invocation. | ||
|
||
Looks up the tool in the registry and streams it with the provided parameters. | ||
|
||
Args: | ||
agent: The agent for which the tool is being executed. | ||
tool_use: The tool object to process, containing name and parameters. | ||
invocation_state: Context for the tool invocation, including agent state. | ||
|
||
Yields: | ||
Tool events with the last being the tool result. | ||
""" | ||
logger.debug("tool_use=<%s> | streaming", tool_use) | ||
tool_name = tool_use["name"] | ||
|
||
# Get the tool info | ||
tool_info = agent.tool_registry.dynamic_tools.get(tool_name) | ||
tool_func = tool_info if tool_info is not None else agent.tool_registry.registry.get(tool_name) | ||
|
||
# Add standard arguments to invocation_state for Python tools | ||
invocation_state.update( | ||
{ | ||
"model": agent.model, | ||
"system_prompt": agent.system_prompt, | ||
"messages": agent.messages, | ||
"tool_config": ToolConfig( # for backwards compatability | ||
tools=[{"toolSpec": tool_spec} for tool_spec in agent.tool_registry.get_all_tool_specs()], | ||
toolChoice=cast(ToolChoice, {"auto": ToolChoiceAuto()}), | ||
), | ||
} | ||
) | ||
|
||
before_event = agent.hooks.invoke_callbacks( | ||
BeforeToolInvocationEvent( | ||
agent=agent, | ||
selected_tool=tool_func, | ||
tool_use=tool_use, | ||
invocation_state=invocation_state, | ||
) | ||
) | ||
|
||
try: | ||
selected_tool = before_event.selected_tool | ||
tool_use = before_event.tool_use | ||
invocation_state = before_event.invocation_state # Get potentially modified invocation_state from hook | ||
|
||
# Check if tool exists | ||
if not selected_tool: | ||
if tool_func == selected_tool: | ||
logger.error( | ||
"tool_name=<%s>, available_tools=<%s> | tool not found in registry", | ||
tool_name, | ||
list(agent.tool_registry.registry.keys()), | ||
) | ||
else: | ||
logger.debug( | ||
"tool_name=<%s>, tool_use_id=<%s> | a hook resulted in a non-existing tool call", | ||
tool_name, | ||
str(tool_use.get("toolUseId")), | ||
) | ||
|
||
result: ToolResult = { | ||
"toolUseId": str(tool_use.get("toolUseId")), | ||
"status": "error", | ||
"content": [{"text": f"Unknown tool: {tool_name}"}], | ||
} | ||
# for every Before event call, we need to have an AfterEvent call | ||
after_event = agent.hooks.invoke_callbacks( | ||
AfterToolInvocationEvent( | ||
agent=agent, | ||
selected_tool=selected_tool, | ||
tool_use=tool_use, | ||
invocation_state=invocation_state, # Keep as invocation_state for backward compatibility with hooks | ||
result=result, | ||
) | ||
) | ||
yield after_event.result | ||
return | ||
|
||
async for event in selected_tool.stream(tool_use, invocation_state): | ||
yield event | ||
|
||
result = event | ||
|
||
after_event = agent.hooks.invoke_callbacks( | ||
AfterToolInvocationEvent( | ||
agent=agent, | ||
selected_tool=selected_tool, | ||
tool_use=tool_use, | ||
invocation_state=invocation_state, # Keep as invocation_state for backward compatibility with hooks | ||
result=result, | ||
) | ||
) | ||
yield after_event.result | ||
|
||
except Exception as e: | ||
logger.exception("tool_name=<%s> | failed to process tool", tool_name) | ||
error_result: ToolResult = { | ||
"toolUseId": str(tool_use.get("toolUseId")), | ||
"status": "error", | ||
"content": [{"text": f"Error: {str(e)}"}], | ||
} | ||
after_event = agent.hooks.invoke_callbacks( | ||
AfterToolInvocationEvent( | ||
agent=agent, | ||
selected_tool=selected_tool, | ||
tool_use=tool_use, | ||
invocation_state=invocation_state, # Keep as invocation_state for backward compatibility with hooks | ||
result=error_result, | ||
exception=e, | ||
) | ||
) | ||
yield after_event.result | ||
|
||
|
||
async def _handle_tool_execution( | ||
stop_reason: StopReason, | ||
message: Message, | ||
|
@@ -431,18 +313,12 @@ async def _handle_tool_execution( | |
cycle_start_time: float, | ||
invocation_state: dict[str, Any], | ||
) -> AsyncGenerator[dict[str, Any], None]: | ||
tool_uses: list[ToolUse] = [] | ||
tool_results: list[ToolResult] = [] | ||
invalid_tool_use_ids: list[str] = [] | ||
|
||
""" | ||
Handles the execution of tools requested by the model during an event loop cycle. | ||
"""Handles the execution of tools requested by the model during an event loop cycle. | ||
|
||
Args: | ||
stop_reason: The reason the model stopped generating. | ||
message: The message from the model that may contain tool use requests. | ||
event_loop_metrics: Metrics tracking object for the event loop. | ||
event_loop_parent_span: Span for the parent of this event loop. | ||
agent: Agent for which tools are being executed. | ||
cycle_trace: Trace object for the current event loop cycle. | ||
cycle_span: Span object for tracing the cycle (type may vary). | ||
cycle_start_time: Start time of the current cycle. | ||
|
@@ -456,23 +332,18 @@ async def _handle_tool_execution( | |
- The updated event loop metrics, | ||
- The updated request state. | ||
""" | ||
validate_and_prepare_tools(message, tool_uses, tool_results, invalid_tool_use_ids) | ||
tool_uses: list[ToolUse] = [] | ||
tool_results: list[ToolResult] = [] | ||
invalid_tool_use_ids: list[str] = [] | ||
|
||
validate_and_prepare_tools(message, tool_uses, tool_results, invalid_tool_use_ids) | ||
tool_uses = [tool_use for tool_use in tool_uses if tool_use.get("toolUseId") not in invalid_tool_use_ids] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. optional nit: any reason this needs to be modified in place? I know we do it in the code base but I find it harder to follow There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We could follow up on this (and a lot of code in this PR). For now though, I figured I would just try to copy and paste as much as possible rather than fully alter the existing logic to simplify the PR. In other words, this PR more so focuses on code organization and placement rather than the logic. |
||
if not tool_uses: | ||
yield {"stop": (stop_reason, message, agent.event_loop_metrics, invocation_state["request_state"])} | ||
return | ||
|
||
def tool_handler(tool_use: ToolUse) -> ToolGenerator: | ||
return run_tool(agent, tool_use, invocation_state) | ||
|
||
tool_events = run_tools( | ||
handler=tool_handler, | ||
tool_uses=tool_uses, | ||
event_loop_metrics=agent.event_loop_metrics, | ||
invalid_tool_use_ids=invalid_tool_use_ids, | ||
tool_results=tool_results, | ||
cycle_trace=cycle_trace, | ||
parent_span=cycle_span, | ||
tool_events = agent.tool_executor._execute( | ||
agent, tool_uses, tool_results, cycle_trace, cycle_span, invocation_state | ||
) | ||
async for tool_event in tool_events: | ||
yield tool_event | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
"""Tool validation utilities.""" | ||
|
||
from ..tools.tools import InvalidToolUseNameException, validate_tool_use | ||
from ..types.content import Message | ||
from ..types.tools import ToolResult, ToolUse | ||
|
||
|
||
def validate_and_prepare_tools( | ||
message: Message, | ||
tool_uses: list[ToolUse], | ||
tool_results: list[ToolResult], | ||
invalid_tool_use_ids: list[str], | ||
) -> None: | ||
"""Validate tool uses and prepare them for execution. | ||
|
||
Args: | ||
message: Current message. | ||
tool_uses: List to populate with tool uses. | ||
tool_results: List to populate with tool results for invalid tools. | ||
invalid_tool_use_ids: List to populate with invalid tool use IDs. | ||
""" | ||
# Extract tool uses from message | ||
for content in message["content"]: | ||
if isinstance(content, dict) and "toolUse" in content: | ||
tool_uses.append(content["toolUse"]) | ||
|
||
# Validate tool uses | ||
# Avoid modifying original `tool_uses` variable during iteration | ||
tool_uses_copy = tool_uses.copy() | ||
for tool in tool_uses_copy: | ||
try: | ||
validate_tool_use(tool) | ||
except InvalidToolUseNameException as e: | ||
# Replace the invalid toolUse name and return invalid name error as ToolResult to the LLM as context | ||
tool_uses.remove(tool) | ||
tool["name"] = "INVALID_TOOL_NAME" | ||
invalid_tool_use_ids.append(tool["toolUseId"]) | ||
tool_uses.append(tool) | ||
tool_results.append( | ||
{ | ||
"toolUseId": tool["toolUseId"], | ||
"status": "error", | ||
"content": [{"text": f"Error: {str(e)}"}], | ||
} | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you update the PR with:
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also add a section about the expected follow-ups that are decided on - if this is a stepping stone to the long-term solution
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added "Usage" and "Follow Up" sections to the overview.