-
Notifications
You must be signed in to change notification settings - Fork 314
tool executors #658
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
tool executors #658
Changes from 4 commits
97a062e
8a86a29
45198c3
c95154f
01327c2
7cc16e6
4398086
b3249cc
72e4238
946035c
5740118
56b51de
43af1ae
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -19,7 +19,8 @@ | |
from opentelemetry import trace as trace_api | ||
from pydantic import BaseModel | ||
|
||
from ..event_loop.event_loop import event_loop_cycle, run_tool | ||
from ..event_loop.event_loop import event_loop_cycle | ||
from ..experimental.tools.executors import Executor as ToolExecutor | ||
pgrayy marked this conversation as resolved.
Show resolved
Hide resolved
|
||
from ..handlers.callback_handler import PrintingCallbackHandler, null_callback_handler | ||
from ..hooks import ( | ||
AfterInvocationEvent, | ||
|
@@ -34,6 +35,7 @@ | |
from ..session.session_manager import SessionManager | ||
from ..telemetry.metrics import EventLoopMetrics | ||
from ..telemetry.tracer import get_tracer | ||
from ..tools import executors as tool_executors | ||
pgrayy marked this conversation as resolved.
Show resolved
Hide resolved
|
||
from ..tools.registry import ToolRegistry | ||
from ..tools.watcher import ToolWatcher | ||
from ..types.content import ContentBlock, Message, Messages | ||
|
@@ -135,13 +137,16 @@ def caller( | |
"name": normalized_name, | ||
"input": kwargs.copy(), | ||
} | ||
tool_results: list[ToolResult] = [] | ||
invocation_state = kwargs | ||
|
||
tool_executor = tool_executors.sequential.Executor(skip_tracing=True) | ||
pgrayy marked this conversation as resolved.
Show resolved
Hide resolved
pgrayy marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
async def acall() -> ToolResult: | ||
# Pass kwargs as invocation_state | ||
async for event in run_tool(self._agent, tool_use, kwargs): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Moved |
||
async for event in tool_executor.stream(self._agent, tool_use, tool_results, invocation_state): | ||
_ = event | ||
|
||
return cast(ToolResult, event) | ||
return tool_results[0] | ||
|
||
def tcall() -> ToolResult: | ||
return asyncio.run(acall()) | ||
|
@@ -207,6 +212,7 @@ def __init__( | |
state: Optional[Union[AgentState, dict]] = None, | ||
hooks: Optional[list[HookProvider]] = None, | ||
session_manager: Optional[SessionManager] = None, | ||
tool_executor: Optional[ToolExecutor] = None, | ||
): | ||
"""Initialize the Agent with the specified configuration. | ||
|
@@ -249,6 +255,7 @@ def __init__( | |
Defaults to None. | ||
session_manager: Manager for handling agent sessions including conversation history and state. | ||
If provided, enables session-based persistence and state management. | ||
tool_executor: Definition of tool execution stragety (e.g., sequential, concurrent, etc.). | ||
""" | ||
self.model = BedrockModel() if not model else BedrockModel(model_id=model) if isinstance(model, str) else model | ||
self.messages = messages if messages is not None else [] | ||
|
@@ -320,6 +327,8 @@ def __init__( | |
if self._session_manager: | ||
self.hooks.add_hook(self._session_manager) | ||
|
||
self.tool_executor = tool_executor or tool_executors.concurrent.Executor() | ||
pgrayy marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
if hooks: | ||
for hook in hooks: | ||
self.hooks.add_hook(hook) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -11,22 +11,20 @@ | |
import logging | ||
import time | ||
import uuid | ||
from typing import TYPE_CHECKING, Any, AsyncGenerator, cast | ||
from typing import TYPE_CHECKING, Any, AsyncGenerator | ||
|
||
from opentelemetry import trace as trace_api | ||
|
||
from ..experimental.hooks import ( | ||
AfterModelInvocationEvent, | ||
AfterToolInvocationEvent, | ||
BeforeModelInvocationEvent, | ||
BeforeToolInvocationEvent, | ||
) | ||
from ..hooks import ( | ||
MessageAddedEvent, | ||
) | ||
from ..telemetry.metrics import Trace | ||
from ..telemetry.tracer import get_tracer | ||
from ..tools.executor import run_tools, validate_and_prepare_tools | ||
from ..tools.validator import validate_and_prepare_tools | ||
from ..types.content import Message | ||
from ..types.exceptions import ( | ||
ContextWindowOverflowException, | ||
|
@@ -35,7 +33,7 @@ | |
ModelThrottledException, | ||
) | ||
from ..types.streaming import Metrics, StopReason | ||
from ..types.tools import ToolChoice, ToolChoiceAuto, ToolConfig, ToolGenerator, ToolResult, ToolUse | ||
from ..types.tools import ToolResult, ToolUse | ||
from ._recover_message_on_max_tokens_reached import recover_message_on_max_tokens_reached | ||
from .streaming import stream_messages | ||
|
||
|
@@ -212,7 +210,7 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) -> | |
if stop_reason == "max_tokens": | ||
""" | ||
Handle max_tokens limit reached by the model. | ||
|
||
When the model reaches its maximum token limit, this represents a potentially unrecoverable | ||
state where the model's response was truncated. By default, Strands fails hard with an | ||
MaxTokensReachedException to maintain consistency with other failure types. | ||
|
@@ -306,122 +304,6 @@ async def recurse_event_loop(agent: "Agent", invocation_state: dict[str, Any]) - | |
recursive_trace.end() | ||
|
||
|
||
async def run_tool(agent: "Agent", tool_use: ToolUse, invocation_state: dict[str, Any]) -> ToolGenerator: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Moved logic into |
||
"""Process a tool invocation. | ||
|
||
Looks up the tool in the registry and streams it with the provided parameters. | ||
|
||
Args: | ||
agent: The agent for which the tool is being executed. | ||
tool_use: The tool object to process, containing name and parameters. | ||
invocation_state: Context for the tool invocation, including agent state. | ||
|
||
Yields: | ||
Tool events with the last being the tool result. | ||
""" | ||
logger.debug("tool_use=<%s> | streaming", tool_use) | ||
tool_name = tool_use["name"] | ||
|
||
# Get the tool info | ||
tool_info = agent.tool_registry.dynamic_tools.get(tool_name) | ||
tool_func = tool_info if tool_info is not None else agent.tool_registry.registry.get(tool_name) | ||
|
||
# Add standard arguments to invocation_state for Python tools | ||
invocation_state.update( | ||
{ | ||
"model": agent.model, | ||
"system_prompt": agent.system_prompt, | ||
"messages": agent.messages, | ||
"tool_config": ToolConfig( # for backwards compatability | ||
tools=[{"toolSpec": tool_spec} for tool_spec in agent.tool_registry.get_all_tool_specs()], | ||
toolChoice=cast(ToolChoice, {"auto": ToolChoiceAuto()}), | ||
), | ||
} | ||
) | ||
|
||
before_event = agent.hooks.invoke_callbacks( | ||
BeforeToolInvocationEvent( | ||
agent=agent, | ||
selected_tool=tool_func, | ||
tool_use=tool_use, | ||
invocation_state=invocation_state, | ||
) | ||
) | ||
|
||
try: | ||
selected_tool = before_event.selected_tool | ||
tool_use = before_event.tool_use | ||
invocation_state = before_event.invocation_state # Get potentially modified invocation_state from hook | ||
|
||
# Check if tool exists | ||
if not selected_tool: | ||
if tool_func == selected_tool: | ||
logger.error( | ||
"tool_name=<%s>, available_tools=<%s> | tool not found in registry", | ||
tool_name, | ||
list(agent.tool_registry.registry.keys()), | ||
) | ||
else: | ||
logger.debug( | ||
"tool_name=<%s>, tool_use_id=<%s> | a hook resulted in a non-existing tool call", | ||
tool_name, | ||
str(tool_use.get("toolUseId")), | ||
) | ||
|
||
result: ToolResult = { | ||
"toolUseId": str(tool_use.get("toolUseId")), | ||
"status": "error", | ||
"content": [{"text": f"Unknown tool: {tool_name}"}], | ||
} | ||
# for every Before event call, we need to have an AfterEvent call | ||
after_event = agent.hooks.invoke_callbacks( | ||
AfterToolInvocationEvent( | ||
agent=agent, | ||
selected_tool=selected_tool, | ||
tool_use=tool_use, | ||
invocation_state=invocation_state, # Keep as invocation_state for backward compatibility with hooks | ||
result=result, | ||
) | ||
) | ||
yield after_event.result | ||
return | ||
|
||
async for event in selected_tool.stream(tool_use, invocation_state): | ||
yield event | ||
|
||
result = event | ||
|
||
after_event = agent.hooks.invoke_callbacks( | ||
AfterToolInvocationEvent( | ||
agent=agent, | ||
selected_tool=selected_tool, | ||
tool_use=tool_use, | ||
invocation_state=invocation_state, # Keep as invocation_state for backward compatibility with hooks | ||
result=result, | ||
) | ||
) | ||
yield after_event.result | ||
|
||
except Exception as e: | ||
logger.exception("tool_name=<%s> | failed to process tool", tool_name) | ||
error_result: ToolResult = { | ||
"toolUseId": str(tool_use.get("toolUseId")), | ||
"status": "error", | ||
"content": [{"text": f"Error: {str(e)}"}], | ||
} | ||
after_event = agent.hooks.invoke_callbacks( | ||
AfterToolInvocationEvent( | ||
agent=agent, | ||
selected_tool=selected_tool, | ||
tool_use=tool_use, | ||
invocation_state=invocation_state, # Keep as invocation_state for backward compatibility with hooks | ||
result=error_result, | ||
exception=e, | ||
) | ||
) | ||
yield after_event.result | ||
|
||
|
||
async def _handle_tool_execution( | ||
stop_reason: StopReason, | ||
message: Message, | ||
|
@@ -431,18 +313,12 @@ async def _handle_tool_execution( | |
cycle_start_time: float, | ||
invocation_state: dict[str, Any], | ||
) -> AsyncGenerator[dict[str, Any], None]: | ||
tool_uses: list[ToolUse] = [] | ||
tool_results: list[ToolResult] = [] | ||
invalid_tool_use_ids: list[str] = [] | ||
|
||
""" | ||
Handles the execution of tools requested by the model during an event loop cycle. | ||
"""Handles the execution of tools requested by the model during an event loop cycle. | ||
|
||
Args: | ||
stop_reason: The reason the model stopped generating. | ||
message: The message from the model that may contain tool use requests. | ||
event_loop_metrics: Metrics tracking object for the event loop. | ||
event_loop_parent_span: Span for the parent of this event loop. | ||
agent: Agent for which tools are being executed. | ||
cycle_trace: Trace object for the current event loop cycle. | ||
cycle_span: Span object for tracing the cycle (type may vary). | ||
cycle_start_time: Start time of the current cycle. | ||
|
@@ -456,24 +332,24 @@ async def _handle_tool_execution( | |
- The updated event loop metrics, | ||
- The updated request state. | ||
""" | ||
validate_and_prepare_tools(message, tool_uses, tool_results, invalid_tool_use_ids) | ||
tool_uses: list[ToolUse] = [] | ||
tool_results: list[ToolResult] = [] | ||
invalid_tool_use_ids: list[str] = [] | ||
|
||
validate_and_prepare_tools(message, tool_uses, tool_results, invalid_tool_use_ids) | ||
tool_uses = [tool_use for tool_use in tool_uses if tool_use.get("toolUseId") not in invalid_tool_use_ids] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. optional nit: any reason this needs to be modified in place? I know we do it in the code base but I find it harder to follow There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We could follow up on this (and a lot of code in this PR). For now though, I figured I would just try to copy and paste as much as possible rather than fully alter the existing logic to simplify the PR. In other words, this PR more so focuses on code organization and placement rather than the logic. |
||
if not tool_uses: | ||
yield {"stop": (stop_reason, message, agent.event_loop_metrics, invocation_state["request_state"])} | ||
return | ||
|
||
def tool_handler(tool_use: ToolUse) -> ToolGenerator: | ||
return run_tool(agent, tool_use, invocation_state) | ||
|
||
tool_events = run_tools( | ||
handler=tool_handler, | ||
tool_uses=tool_uses, | ||
event_loop_metrics=agent.event_loop_metrics, | ||
invalid_tool_use_ids=invalid_tool_use_ids, | ||
tool_results=tool_results, | ||
cycle_trace=cycle_trace, | ||
parent_span=cycle_span, | ||
invocation_state.update( | ||
pgrayy marked this conversation as resolved.
Show resolved
Hide resolved
|
||
{ | ||
"event_loop_cycle_trace": cycle_trace, | ||
"event_loop_cycle_span": cycle_span, | ||
} | ||
) | ||
|
||
tool_events = agent.tool_executor.execute(agent, tool_uses, tool_results, invocation_state) | ||
async for tool_event in tool_events: | ||
yield tool_event | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
"""Experimental tool features.""" | ||
|
||
from . import executors | ||
|
||
__all__ = ["executors"] |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
"""Experimental tool executors.""" | ||
|
||
from . import executor | ||
from .executor import Executor | ||
|
||
__all__ = ["executor", "Executor"] | ||
pgrayy marked this conversation as resolved.
Show resolved
Hide resolved
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you update the PR with:
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also add a section about the expected follow-ups that are decided on - if this is a stepping stone to the long-term solution
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added "Usage" and "Follow Up" sections to the overview.