diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index 5150060c6..adc554bf4 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -20,7 +20,7 @@ from pydantic import BaseModel from .. import _identifier -from ..event_loop.event_loop import event_loop_cycle, run_tool +from ..event_loop.event_loop import event_loop_cycle from ..handlers.callback_handler import PrintingCallbackHandler, null_callback_handler from ..hooks import ( AfterInvocationEvent, @@ -35,6 +35,8 @@ from ..session.session_manager import SessionManager from ..telemetry.metrics import EventLoopMetrics from ..telemetry.tracer import get_tracer, serialize +from ..tools.executors import ConcurrentToolExecutor +from ..tools.executors._executor import ToolExecutor from ..tools.registry import ToolRegistry from ..tools.watcher import ToolWatcher from ..types.content import ContentBlock, Message, Messages @@ -136,13 +138,14 @@ def caller( "name": normalized_name, "input": kwargs.copy(), } + tool_results: list[ToolResult] = [] + invocation_state = kwargs async def acall() -> ToolResult: - # Pass kwargs as invocation_state - async for event in run_tool(self._agent, tool_use, kwargs): + async for event in ToolExecutor._stream(self._agent, tool_use, tool_results, invocation_state): _ = event - return cast(ToolResult, event) + return tool_results[0] def tcall() -> ToolResult: return asyncio.run(acall()) @@ -208,6 +211,7 @@ def __init__( state: Optional[Union[AgentState, dict]] = None, hooks: Optional[list[HookProvider]] = None, session_manager: Optional[SessionManager] = None, + tool_executor: Optional[ToolExecutor] = None, ): """Initialize the Agent with the specified configuration. @@ -250,6 +254,7 @@ def __init__( Defaults to None. session_manager: Manager for handling agent sessions including conversation history and state. If provided, enables session-based persistence and state management. 
+ tool_executor: Definition of tool execution strategy (e.g., sequential, concurrent, etc.). Raises: ValueError: If agent id contains path separators. @@ -324,6 +329,8 @@ def __init__( if self._session_manager: self.hooks.add_hook(self._session_manager) + self.tool_executor = tool_executor or ConcurrentToolExecutor() + if hooks: for hook in hooks: self.hooks.add_hook(hook) diff --git a/src/strands/event_loop/event_loop.py b/src/strands/event_loop/event_loop.py index b36f73155..524ecc3e8 100644 --- a/src/strands/event_loop/event_loop.py +++ b/src/strands/event_loop/event_loop.py @@ -11,22 +11,20 @@ import logging import time import uuid -from typing import TYPE_CHECKING, Any, AsyncGenerator, cast +from typing import TYPE_CHECKING, Any, AsyncGenerator from opentelemetry import trace as trace_api from ..experimental.hooks import ( AfterModelInvocationEvent, - AfterToolInvocationEvent, BeforeModelInvocationEvent, - BeforeToolInvocationEvent, ) from ..hooks import ( MessageAddedEvent, ) from ..telemetry.metrics import Trace from ..telemetry.tracer import get_tracer -from ..tools.executor import run_tools, validate_and_prepare_tools +from ..tools._validator import validate_and_prepare_tools from ..types.content import Message from ..types.exceptions import ( ContextWindowOverflowException, @@ -35,7 +33,7 @@ ModelThrottledException, ) from ..types.streaming import Metrics, StopReason -from ..types.tools import ToolChoice, ToolChoiceAuto, ToolConfig, ToolGenerator, ToolResult, ToolUse +from ..types.tools import ToolResult, ToolUse from ._recover_message_on_max_tokens_reached import recover_message_on_max_tokens_reached from .streaming import stream_messages @@ -212,7 +210,7 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) -> if stop_reason == "max_tokens": """ Handle max_tokens limit reached by the model. 
- + When the model reaches its maximum token limit, this represents a potentially unrecoverable state where the model's response was truncated. By default, Strands fails hard with an MaxTokensReachedException to maintain consistency with other failure types. @@ -306,122 +304,6 @@ async def recurse_event_loop(agent: "Agent", invocation_state: dict[str, Any]) - recursive_trace.end() -async def run_tool(agent: "Agent", tool_use: ToolUse, invocation_state: dict[str, Any]) -> ToolGenerator: - """Process a tool invocation. - - Looks up the tool in the registry and streams it with the provided parameters. - - Args: - agent: The agent for which the tool is being executed. - tool_use: The tool object to process, containing name and parameters. - invocation_state: Context for the tool invocation, including agent state. - - Yields: - Tool events with the last being the tool result. - """ - logger.debug("tool_use=<%s> | streaming", tool_use) - tool_name = tool_use["name"] - - # Get the tool info - tool_info = agent.tool_registry.dynamic_tools.get(tool_name) - tool_func = tool_info if tool_info is not None else agent.tool_registry.registry.get(tool_name) - - # Add standard arguments to invocation_state for Python tools - invocation_state.update( - { - "model": agent.model, - "system_prompt": agent.system_prompt, - "messages": agent.messages, - "tool_config": ToolConfig( # for backwards compatability - tools=[{"toolSpec": tool_spec} for tool_spec in agent.tool_registry.get_all_tool_specs()], - toolChoice=cast(ToolChoice, {"auto": ToolChoiceAuto()}), - ), - } - ) - - before_event = agent.hooks.invoke_callbacks( - BeforeToolInvocationEvent( - agent=agent, - selected_tool=tool_func, - tool_use=tool_use, - invocation_state=invocation_state, - ) - ) - - try: - selected_tool = before_event.selected_tool - tool_use = before_event.tool_use - invocation_state = before_event.invocation_state # Get potentially modified invocation_state from hook - - # Check if tool exists - if not 
selected_tool: - if tool_func == selected_tool: - logger.error( - "tool_name=<%s>, available_tools=<%s> | tool not found in registry", - tool_name, - list(agent.tool_registry.registry.keys()), - ) - else: - logger.debug( - "tool_name=<%s>, tool_use_id=<%s> | a hook resulted in a non-existing tool call", - tool_name, - str(tool_use.get("toolUseId")), - ) - - result: ToolResult = { - "toolUseId": str(tool_use.get("toolUseId")), - "status": "error", - "content": [{"text": f"Unknown tool: {tool_name}"}], - } - # for every Before event call, we need to have an AfterEvent call - after_event = agent.hooks.invoke_callbacks( - AfterToolInvocationEvent( - agent=agent, - selected_tool=selected_tool, - tool_use=tool_use, - invocation_state=invocation_state, # Keep as invocation_state for backward compatibility with hooks - result=result, - ) - ) - yield after_event.result - return - - async for event in selected_tool.stream(tool_use, invocation_state): - yield event - - result = event - - after_event = agent.hooks.invoke_callbacks( - AfterToolInvocationEvent( - agent=agent, - selected_tool=selected_tool, - tool_use=tool_use, - invocation_state=invocation_state, # Keep as invocation_state for backward compatibility with hooks - result=result, - ) - ) - yield after_event.result - - except Exception as e: - logger.exception("tool_name=<%s> | failed to process tool", tool_name) - error_result: ToolResult = { - "toolUseId": str(tool_use.get("toolUseId")), - "status": "error", - "content": [{"text": f"Error: {str(e)}"}], - } - after_event = agent.hooks.invoke_callbacks( - AfterToolInvocationEvent( - agent=agent, - selected_tool=selected_tool, - tool_use=tool_use, - invocation_state=invocation_state, # Keep as invocation_state for backward compatibility with hooks - result=error_result, - exception=e, - ) - ) - yield after_event.result - - async def _handle_tool_execution( stop_reason: StopReason, message: Message, @@ -431,18 +313,12 @@ async def _handle_tool_execution( 
cycle_start_time: float, invocation_state: dict[str, Any], ) -> AsyncGenerator[dict[str, Any], None]: - tool_uses: list[ToolUse] = [] - tool_results: list[ToolResult] = [] - invalid_tool_use_ids: list[str] = [] - - """ - Handles the execution of tools requested by the model during an event loop cycle. + """Handles the execution of tools requested by the model during an event loop cycle. Args: stop_reason: The reason the model stopped generating. message: The message from the model that may contain tool use requests. - event_loop_metrics: Metrics tracking object for the event loop. - event_loop_parent_span: Span for the parent of this event loop. + agent: Agent for which tools are being executed. cycle_trace: Trace object for the current event loop cycle. cycle_span: Span object for tracing the cycle (type may vary). cycle_start_time: Start time of the current cycle. @@ -456,23 +332,18 @@ async def _handle_tool_execution( - The updated event loop metrics, - The updated request state. """ - validate_and_prepare_tools(message, tool_uses, tool_results, invalid_tool_use_ids) + tool_uses: list[ToolUse] = [] + tool_results: list[ToolResult] = [] + invalid_tool_use_ids: list[str] = [] + validate_and_prepare_tools(message, tool_uses, tool_results, invalid_tool_use_ids) + tool_uses = [tool_use for tool_use in tool_uses if tool_use.get("toolUseId") not in invalid_tool_use_ids] if not tool_uses: yield {"stop": (stop_reason, message, agent.event_loop_metrics, invocation_state["request_state"])} return - def tool_handler(tool_use: ToolUse) -> ToolGenerator: - return run_tool(agent, tool_use, invocation_state) - - tool_events = run_tools( - handler=tool_handler, - tool_uses=tool_uses, - event_loop_metrics=agent.event_loop_metrics, - invalid_tool_use_ids=invalid_tool_use_ids, - tool_results=tool_results, - cycle_trace=cycle_trace, - parent_span=cycle_span, + tool_events = agent.tool_executor._execute( + agent, tool_uses, tool_results, cycle_trace, cycle_span, invocation_state ) 
async for tool_event in tool_events: yield tool_event diff --git a/src/strands/tools/_validator.py b/src/strands/tools/_validator.py new file mode 100644 index 000000000..77aa57e87 --- /dev/null +++ b/src/strands/tools/_validator.py @@ -0,0 +1,45 @@ +"""Tool validation utilities.""" + +from ..tools.tools import InvalidToolUseNameException, validate_tool_use +from ..types.content import Message +from ..types.tools import ToolResult, ToolUse + + +def validate_and_prepare_tools( + message: Message, + tool_uses: list[ToolUse], + tool_results: list[ToolResult], + invalid_tool_use_ids: list[str], +) -> None: + """Validate tool uses and prepare them for execution. + + Args: + message: Current message. + tool_uses: List to populate with tool uses. + tool_results: List to populate with tool results for invalid tools. + invalid_tool_use_ids: List to populate with invalid tool use IDs. + """ + # Extract tool uses from message + for content in message["content"]: + if isinstance(content, dict) and "toolUse" in content: + tool_uses.append(content["toolUse"]) + + # Validate tool uses + # Avoid modifying original `tool_uses` variable during iteration + tool_uses_copy = tool_uses.copy() + for tool in tool_uses_copy: + try: + validate_tool_use(tool) + except InvalidToolUseNameException as e: + # Replace the invalid toolUse name and return invalid name error as ToolResult to the LLM as context + tool_uses.remove(tool) + tool["name"] = "INVALID_TOOL_NAME" + invalid_tool_use_ids.append(tool["toolUseId"]) + tool_uses.append(tool) + tool_results.append( + { + "toolUseId": tool["toolUseId"], + "status": "error", + "content": [{"text": f"Error: {str(e)}"}], + } + ) diff --git a/src/strands/tools/executor.py b/src/strands/tools/executor.py deleted file mode 100644 index d90f9a5aa..000000000 --- a/src/strands/tools/executor.py +++ /dev/null @@ -1,137 +0,0 @@ -"""Tool execution functionality for the event loop.""" - -import asyncio -import logging -import time -from typing import Any, 
Optional, cast - -from opentelemetry import trace as trace_api - -from ..telemetry.metrics import EventLoopMetrics, Trace -from ..telemetry.tracer import get_tracer -from ..tools.tools import InvalidToolUseNameException, validate_tool_use -from ..types.content import Message -from ..types.tools import RunToolHandler, ToolGenerator, ToolResult, ToolUse - -logger = logging.getLogger(__name__) - - -async def run_tools( - handler: RunToolHandler, - tool_uses: list[ToolUse], - event_loop_metrics: EventLoopMetrics, - invalid_tool_use_ids: list[str], - tool_results: list[ToolResult], - cycle_trace: Trace, - parent_span: Optional[trace_api.Span] = None, -) -> ToolGenerator: - """Execute tools concurrently. - - Args: - handler: Tool handler processing function. - tool_uses: List of tool uses to execute. - event_loop_metrics: Metrics collection object. - invalid_tool_use_ids: List of invalid tool use IDs. - tool_results: List to populate with tool results. - cycle_trace: Parent trace for the current cycle. - parent_span: Parent span for the current cycle. - - Yields: - Events of the tool stream. Tool results are appended to `tool_results`. 
- """ - - async def work( - tool_use: ToolUse, - worker_id: int, - worker_queue: asyncio.Queue, - worker_event: asyncio.Event, - stop_event: object, - ) -> ToolResult: - tracer = get_tracer() - tool_call_span = tracer.start_tool_call_span(tool_use, parent_span) - - tool_name = tool_use["name"] - tool_trace = Trace(f"Tool: {tool_name}", parent_id=cycle_trace.id, raw_name=tool_name) - tool_start_time = time.time() - with trace_api.use_span(tool_call_span): - try: - async for event in handler(tool_use): - worker_queue.put_nowait((worker_id, event)) - await worker_event.wait() - worker_event.clear() - - result = cast(ToolResult, event) - finally: - worker_queue.put_nowait((worker_id, stop_event)) - - tool_success = result.get("status") == "success" - tool_duration = time.time() - tool_start_time - message = Message(role="user", content=[{"toolResult": result}]) - event_loop_metrics.add_tool_usage(tool_use, tool_duration, tool_trace, tool_success, message) - cycle_trace.add_child(tool_trace) - - tracer.end_tool_call_span(tool_call_span, result) - - return result - - tool_uses = [tool_use for tool_use in tool_uses if tool_use.get("toolUseId") not in invalid_tool_use_ids] - worker_queue: asyncio.Queue[tuple[int, Any]] = asyncio.Queue() - worker_events = [asyncio.Event() for _ in tool_uses] - stop_event = object() - - workers = [ - asyncio.create_task(work(tool_use, worker_id, worker_queue, worker_events[worker_id], stop_event)) - for worker_id, tool_use in enumerate(tool_uses) - ] - - worker_count = len(workers) - while worker_count: - worker_id, event = await worker_queue.get() - if event is stop_event: - worker_count -= 1 - continue - - yield event - worker_events[worker_id].set() - - tool_results.extend([worker.result() for worker in workers]) - - -def validate_and_prepare_tools( - message: Message, - tool_uses: list[ToolUse], - tool_results: list[ToolResult], - invalid_tool_use_ids: list[str], -) -> None: - """Validate tool uses and prepare them for execution. 
- - Args: - message: Current message. - tool_uses: List to populate with tool uses. - tool_results: List to populate with tool results for invalid tools. - invalid_tool_use_ids: List to populate with invalid tool use IDs. - """ - # Extract tool uses from message - for content in message["content"]: - if isinstance(content, dict) and "toolUse" in content: - tool_uses.append(content["toolUse"]) - - # Validate tool uses - # Avoid modifying original `tool_uses` variable during iteration - tool_uses_copy = tool_uses.copy() - for tool in tool_uses_copy: - try: - validate_tool_use(tool) - except InvalidToolUseNameException as e: - # Replace the invalid toolUse name and return invalid name error as ToolResult to the LLM as context - tool_uses.remove(tool) - tool["name"] = "INVALID_TOOL_NAME" - invalid_tool_use_ids.append(tool["toolUseId"]) - tool_uses.append(tool) - tool_results.append( - { - "toolUseId": tool["toolUseId"], - "status": "error", - "content": [{"text": f"Error: {str(e)}"}], - } - ) diff --git a/src/strands/tools/executors/__init__.py b/src/strands/tools/executors/__init__.py new file mode 100644 index 000000000..c8be812e4 --- /dev/null +++ b/src/strands/tools/executors/__init__.py @@ -0,0 +1,16 @@ +"""Tool executors for the Strands SDK. + +This package provides different execution strategies for tools, allowing users to customize +how tools are executed (e.g., concurrent, sequential, with custom thread pools, etc.). +""" + +from . import concurrent, sequential +from .concurrent import ConcurrentToolExecutor +from .sequential import SequentialToolExecutor + +__all__ = [ + "ConcurrentToolExecutor", + "SequentialToolExecutor", + "concurrent", + "sequential", +] diff --git a/src/strands/tools/executors/_executor.py b/src/strands/tools/executors/_executor.py new file mode 100644 index 000000000..9999b77fc --- /dev/null +++ b/src/strands/tools/executors/_executor.py @@ -0,0 +1,227 @@ +"""Abstract base class for tool executors. 
+ +Tool executors are responsible for determining how tools are executed (e.g., concurrently, sequentially, with custom +thread pools, etc.). +""" + +import abc +import logging +import time +from typing import TYPE_CHECKING, Any, cast + +from opentelemetry import trace as trace_api + +from ...experimental.hooks import AfterToolInvocationEvent, BeforeToolInvocationEvent +from ...telemetry.metrics import Trace +from ...telemetry.tracer import get_tracer +from ...types.content import Message +from ...types.tools import ToolChoice, ToolChoiceAuto, ToolConfig, ToolGenerator, ToolResult, ToolUse + +if TYPE_CHECKING: # pragma: no cover + from ...agent import Agent + +logger = logging.getLogger(__name__) + + +class ToolExecutor(abc.ABC): + """Abstract base class for tool executors.""" + + @staticmethod + async def _stream( + agent: "Agent", + tool_use: ToolUse, + tool_results: list[ToolResult], + invocation_state: dict[str, Any], + **kwargs: Any, + ) -> ToolGenerator: + """Stream tool events. + + This method adds additional logic to the stream invocation including: + + - Tool lookup and validation + - Before/after hook execution + - Tracing and metrics collection + - Error handling and recovery + + Args: + agent: The agent for which the tool is being executed. + tool_use: Metadata and inputs for the tool to be executed. + tool_results: List of tool results from each tool execution. + invocation_state: Context for the tool invocation. + **kwargs: Additional keyword arguments for future extensibility. + + Yields: + Tool events with the last being the tool result. 
+ """ + logger.debug("tool_use=<%s> | streaming", tool_use) + tool_name = tool_use["name"] + + tool_info = agent.tool_registry.dynamic_tools.get(tool_name) + tool_func = tool_info if tool_info is not None else agent.tool_registry.registry.get(tool_name) + + invocation_state.update( + { + "model": agent.model, + "messages": agent.messages, + "system_prompt": agent.system_prompt, + "tool_config": ToolConfig( # for backwards compatibility + tools=[{"toolSpec": tool_spec} for tool_spec in agent.tool_registry.get_all_tool_specs()], + toolChoice=cast(ToolChoice, {"auto": ToolChoiceAuto()}), + ), + } + ) + + before_event = agent.hooks.invoke_callbacks( + BeforeToolInvocationEvent( + agent=agent, + selected_tool=tool_func, + tool_use=tool_use, + invocation_state=invocation_state, + ) + ) + + try: + selected_tool = before_event.selected_tool + tool_use = before_event.tool_use + invocation_state = before_event.invocation_state + + if not selected_tool: + if tool_func == selected_tool: + logger.error( + "tool_name=<%s>, available_tools=<%s> | tool not found in registry", + tool_name, + list(agent.tool_registry.registry.keys()), + ) + else: + logger.debug( + "tool_name=<%s>, tool_use_id=<%s> | a hook resulted in a non-existing tool call", + tool_name, + str(tool_use.get("toolUseId")), + ) + + result: ToolResult = { + "toolUseId": str(tool_use.get("toolUseId")), + "status": "error", + "content": [{"text": f"Unknown tool: {tool_name}"}], + } + after_event = agent.hooks.invoke_callbacks( + AfterToolInvocationEvent( + agent=agent, + selected_tool=selected_tool, + tool_use=tool_use, + invocation_state=invocation_state, + result=result, + ) + ) + yield after_event.result + tool_results.append(after_event.result) + return + + async for event in selected_tool.stream(tool_use, invocation_state, **kwargs): + yield event + + result = cast(ToolResult, event) + + after_event = agent.hooks.invoke_callbacks( + AfterToolInvocationEvent( + agent=agent, + selected_tool=selected_tool, + 
tool_use=tool_use, + invocation_state=invocation_state, + result=result, + ) + ) + yield after_event.result + tool_results.append(after_event.result) + + except Exception as e: + logger.exception("tool_name=<%s> | failed to process tool", tool_name) + error_result: ToolResult = { + "toolUseId": str(tool_use.get("toolUseId")), + "status": "error", + "content": [{"text": f"Error: {str(e)}"}], + } + after_event = agent.hooks.invoke_callbacks( + AfterToolInvocationEvent( + agent=agent, + selected_tool=selected_tool, + tool_use=tool_use, + invocation_state=invocation_state, + result=error_result, + exception=e, + ) + ) + yield after_event.result + tool_results.append(after_event.result) + + @staticmethod + async def _stream_with_trace( + agent: "Agent", + tool_use: ToolUse, + tool_results: list[ToolResult], + cycle_trace: Trace, + cycle_span: Any, + invocation_state: dict[str, Any], + **kwargs: Any, + ) -> ToolGenerator: + """Execute tool with tracing and metrics collection. + + Args: + agent: The agent for which the tool is being executed. + tool_use: Metadata and inputs for the tool to be executed. + tool_results: List of tool results from each tool execution. + cycle_trace: Trace object for the current event loop cycle. + cycle_span: Span object for tracing the cycle. + invocation_state: Context for the tool invocation. + **kwargs: Additional keyword arguments for future extensibility. + + Yields: + Tool events with the last being the tool result. 
+ """ + tool_name = tool_use["name"] + + tracer = get_tracer() + + tool_call_span = tracer.start_tool_call_span(tool_use, cycle_span) + tool_trace = Trace(f"Tool: {tool_name}", parent_id=cycle_trace.id, raw_name=tool_name) + tool_start_time = time.time() + + with trace_api.use_span(tool_call_span): + async for event in ToolExecutor._stream(agent, tool_use, tool_results, invocation_state, **kwargs): + yield event + + result = cast(ToolResult, event) + + tool_success = result.get("status") == "success" + tool_duration = time.time() - tool_start_time + message = Message(role="user", content=[{"toolResult": result}]) + agent.event_loop_metrics.add_tool_usage(tool_use, tool_duration, tool_trace, tool_success, message) + cycle_trace.add_child(tool_trace) + + tracer.end_tool_call_span(tool_call_span, result) + + @abc.abstractmethod + # pragma: no cover + def _execute( + self, + agent: "Agent", + tool_uses: list[ToolUse], + tool_results: list[ToolResult], + cycle_trace: Trace, + cycle_span: Any, + invocation_state: dict[str, Any], + ) -> ToolGenerator: + """Execute the given tools according to this executor's strategy. + + Args: + agent: The agent for which tools are being executed. + tool_uses: Metadata and inputs for the tools to be executed. + tool_results: List of tool results from each tool execution. + cycle_trace: Trace object for the current event loop cycle. + cycle_span: Span object for tracing the cycle. + invocation_state: Context for the tool invocation. + + Yields: + Events from the tool execution stream. 
+ """ + pass diff --git a/src/strands/tools/executors/concurrent.py b/src/strands/tools/executors/concurrent.py new file mode 100644 index 000000000..7d5dd7fe7 --- /dev/null +++ b/src/strands/tools/executors/concurrent.py @@ -0,0 +1,113 @@ +"""Concurrent tool executor implementation.""" + +import asyncio +from typing import TYPE_CHECKING, Any + +from typing_extensions import override + +from ...telemetry.metrics import Trace +from ...types.tools import ToolGenerator, ToolResult, ToolUse +from ._executor import ToolExecutor + +if TYPE_CHECKING: # pragma: no cover + from ...agent import Agent + + +class ConcurrentToolExecutor(ToolExecutor): + """Concurrent tool executor.""" + + @override + async def _execute( + self, + agent: "Agent", + tool_uses: list[ToolUse], + tool_results: list[ToolResult], + cycle_trace: Trace, + cycle_span: Any, + invocation_state: dict[str, Any], + ) -> ToolGenerator: + """Execute tools concurrently. + + Args: + agent: The agent for which tools are being executed. + tool_uses: Metadata and inputs for the tools to be executed. + tool_results: List of tool results from each tool execution. + cycle_trace: Trace object for the current event loop cycle. + cycle_span: Span object for tracing the cycle. + invocation_state: Context for the tool invocation. + + Yields: + Events from the tool execution stream. 
+ """ + task_queue: asyncio.Queue[tuple[int, Any]] = asyncio.Queue() + task_events = [asyncio.Event() for _ in tool_uses] + stop_event = object() + + tasks = [ + asyncio.create_task( + self._task( + agent, + tool_use, + tool_results, + cycle_trace, + cycle_span, + invocation_state, + task_id, + task_queue, + task_events[task_id], + stop_event, + ) + ) + for task_id, tool_use in enumerate(tool_uses) + ] + + task_count = len(tasks) + while task_count: + task_id, event = await task_queue.get() + if event is stop_event: + task_count -= 1 + continue + + yield event + task_events[task_id].set() + + asyncio.gather(*tasks) + + async def _task( + self, + agent: "Agent", + tool_use: ToolUse, + tool_results: list[ToolResult], + cycle_trace: Trace, + cycle_span: Any, + invocation_state: dict[str, Any], + task_id: int, + task_queue: asyncio.Queue, + task_event: asyncio.Event, + stop_event: object, + ) -> None: + """Execute a single tool and put results in the task queue. + + Args: + agent: The agent executing the tool. + tool_use: Tool use metadata and inputs. + tool_results: List of tool results from each tool execution. + cycle_trace: Trace object for the current event loop cycle. + cycle_span: Span object for tracing the cycle. + invocation_state: Context for tool execution. + task_id: Unique identifier for this task. + task_queue: Queue to put tool events into. + task_event: Event to signal when task can continue. + stop_event: Sentinel object to signal task completion. 
+ """ + try: + events = ToolExecutor._stream_with_trace( + agent, tool_use, tool_results, cycle_trace, cycle_span, invocation_state + ) + async for event in events: + task_queue.put_nowait((task_id, event)) + await task_event.wait() + task_event.clear() + + finally: + task_queue.put_nowait((task_id, stop_event)) diff --git a/src/strands/tools/executors/sequential.py b/src/strands/tools/executors/sequential.py new file mode 100644 index 000000000..55b26f6d3 --- /dev/null +++ b/src/strands/tools/executors/sequential.py @@ -0,0 +1,46 @@ +"""Sequential tool executor implementation.""" + +from typing import TYPE_CHECKING, Any + +from typing_extensions import override + +from ...telemetry.metrics import Trace +from ...types.tools import ToolGenerator, ToolResult, ToolUse +from ._executor import ToolExecutor + +if TYPE_CHECKING: # pragma: no cover + from ...agent import Agent + + +class SequentialToolExecutor(ToolExecutor): + """Sequential tool executor.""" + + @override + async def _execute( + self, + agent: "Agent", + tool_uses: list[ToolUse], + tool_results: list[ToolResult], + cycle_trace: Trace, + cycle_span: Any, + invocation_state: dict[str, Any], + ) -> ToolGenerator: + """Execute tools sequentially. + + Args: + agent: The agent for which tools are being executed. + tool_uses: Metadata and inputs for the tools to be executed. + tool_results: List of tool results from each tool execution. + cycle_trace: Trace object for the current event loop cycle. + cycle_span: Span object for tracing the cycle. + invocation_state: Context for the tool invocation. + + Yields: + Events from the tool execution stream. 
+ """ + for tool_use in tool_uses: + events = ToolExecutor._stream_with_trace( + agent, tool_use, tool_results, cycle_trace, cycle_span, invocation_state + ) + async for event in events: + yield event diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index 7e769c6d7..279e2a06e 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -73,12 +73,6 @@ def mock_event_loop_cycle(): yield mock -@pytest.fixture -def mock_run_tool(): - with unittest.mock.patch("strands.agent.agent.run_tool") as mock: - yield mock - - @pytest.fixture def tool_registry(): return strands.tools.registry.ToolRegistry() @@ -888,9 +882,7 @@ def test_agent_init_with_no_model_or_model_id(): assert agent.model.get_config().get("model_id") == DEFAULT_BEDROCK_MODEL_ID -def test_agent_tool_no_parameter_conflict(agent, tool_registry, mock_randint, mock_run_tool, agenerator): - mock_run_tool.return_value = agenerator([{}]) - +def test_agent_tool_no_parameter_conflict(agent, tool_registry, mock_randint, agenerator): @strands.tools.tool(name="system_prompter") def function(system_prompt: str) -> str: return system_prompt @@ -899,22 +891,12 @@ def function(system_prompt: str) -> str: mock_randint.return_value = 1 - agent.tool.system_prompter(system_prompt="tool prompt") - - mock_run_tool.assert_called_with( - agent, - { - "toolUseId": "tooluse_system_prompter_1", - "name": "system_prompter", - "input": {"system_prompt": "tool prompt"}, - }, - {"system_prompt": "tool prompt"}, - ) - + tru_result = agent.tool.system_prompter(system_prompt="tool prompt") + exp_result = {"toolUseId": "tooluse_system_prompter_1", "status": "success", "content": [{"text": "tool prompt"}]} + assert tru_result == exp_result -def test_agent_tool_with_name_normalization(agent, tool_registry, mock_randint, mock_run_tool, agenerator): - mock_run_tool.return_value = agenerator([{}]) +def test_agent_tool_with_name_normalization(agent, tool_registry, mock_randint, 
agenerator): tool_name = "system-prompter" @strands.tools.tool(name=tool_name) @@ -925,19 +907,9 @@ def function(system_prompt: str) -> str: mock_randint.return_value = 1 - agent.tool.system_prompter(system_prompt="tool prompt") - - # Verify the correct tool was invoked - assert mock_run_tool.call_count == 1 - tru_tool_use = mock_run_tool.call_args.args[1] - exp_tool_use = { - # Note that the tool-use uses the "python safe" name - "toolUseId": "tooluse_system_prompter_1", - # But the name of the tool is the one in the registry - "name": tool_name, - "input": {"system_prompt": "tool prompt"}, - } - assert tru_tool_use == exp_tool_use + tru_result = agent.tool.system_prompter(system_prompt="tool prompt") + exp_result = {"toolUseId": "tooluse_system_prompter_1", "status": "success", "content": [{"text": "tool prompt"}]} + assert tru_result == exp_result def test_agent_tool_with_no_normalized_match(agent, tool_registry, mock_randint): diff --git a/tests/strands/event_loop/test_event_loop.py b/tests/strands/event_loop/test_event_loop.py index 191ab51ba..c76514ac8 100644 --- a/tests/strands/event_loop/test_event_loop.py +++ b/tests/strands/event_loop/test_event_loop.py @@ -1,23 +1,20 @@ import concurrent import unittest.mock -from unittest.mock import ANY, MagicMock, call, patch +from unittest.mock import MagicMock, call, patch import pytest import strands import strands.telemetry -from strands.event_loop.event_loop import run_tool from strands.experimental.hooks import ( AfterModelInvocationEvent, AfterToolInvocationEvent, BeforeModelInvocationEvent, BeforeToolInvocationEvent, ) -from strands.hooks import ( - HookProvider, - HookRegistry, -) +from strands.hooks import HookRegistry from strands.telemetry.metrics import EventLoopMetrics +from strands.tools.executors import SequentialToolExecutor from strands.tools.registry import ToolRegistry from strands.types.exceptions import ( ContextWindowOverflowException, @@ -131,7 +128,12 @@ def hook_provider(hook_registry): 
@pytest.fixture -def agent(model, system_prompt, messages, tool_registry, thread_pool, hook_registry): +def tool_executor(): + return SequentialToolExecutor() + + +@pytest.fixture +def agent(model, system_prompt, messages, tool_registry, thread_pool, hook_registry, tool_executor): mock = unittest.mock.Mock(name="agent") mock.config.cache_points = [] mock.model = model @@ -141,6 +143,7 @@ def agent(model, system_prompt, messages, tool_registry, thread_pool, hook_regis mock.thread_pool = thread_pool mock.event_loop_metrics = EventLoopMetrics() mock.hooks = hook_registry + mock.tool_executor = tool_executor return mock @@ -812,260 +815,6 @@ async def test_prepare_next_cycle_in_tool_execution(agent, model, tool_stream, a ) -@pytest.mark.asyncio -async def test_run_tool(agent, tool, alist): - process = run_tool( - agent, - tool_use={"toolUseId": "tool_use_id", "name": tool.tool_name, "input": {"random_string": "a_string"}}, - invocation_state={}, - ) - - tru_result = (await alist(process))[-1] - exp_result = {"toolUseId": "tool_use_id", "status": "success", "content": [{"text": "a_string"}]} - - assert tru_result == exp_result - - -@pytest.mark.asyncio -async def test_run_tool_missing_tool(agent, alist): - process = run_tool( - agent, - tool_use={"toolUseId": "missing", "name": "missing", "input": {}}, - invocation_state={}, - ) - - tru_events = await alist(process) - exp_events = [ - { - "toolUseId": "missing", - "status": "error", - "content": [{"text": "Unknown tool: missing"}], - }, - ] - - assert tru_events == exp_events - - -@pytest.mark.asyncio -async def test_run_tool_hooks(agent, hook_provider, tool_times_2, alist): - """Test that the correct hooks are emitted.""" - - process = run_tool( - agent=agent, - tool_use={"toolUseId": "test", "name": tool_times_2.tool_name, "input": {"x": 5}}, - invocation_state={}, - ) - await alist(process) - - assert len(hook_provider.events_received) == 2 - - assert hook_provider.events_received[0] == BeforeToolInvocationEvent( - 
agent=agent, - selected_tool=tool_times_2, - tool_use={"input": {"x": 5}, "name": "multiply_by_2", "toolUseId": "test"}, - invocation_state=ANY, - ) - - assert hook_provider.events_received[1] == AfterToolInvocationEvent( - agent=agent, - selected_tool=tool_times_2, - exception=None, - tool_use={"toolUseId": "test", "name": tool_times_2.tool_name, "input": {"x": 5}}, - result={"toolUseId": "test", "status": "success", "content": [{"text": "10"}]}, - invocation_state=ANY, - ) - - -@pytest.mark.asyncio -async def test_run_tool_hooks_on_missing_tool(agent, hook_provider, alist): - """Test that AfterToolInvocation hook is invoked even when tool throws exception.""" - process = run_tool( - agent=agent, - tool_use={"toolUseId": "test", "name": "missing_tool", "input": {"x": 5}}, - invocation_state={}, - ) - await alist(process) - - assert len(hook_provider.events_received) == 2 - - assert hook_provider.events_received[0] == BeforeToolInvocationEvent( - agent=agent, - selected_tool=None, - tool_use={"input": {"x": 5}, "name": "missing_tool", "toolUseId": "test"}, - invocation_state=ANY, - ) - - assert hook_provider.events_received[1] == AfterToolInvocationEvent( - agent=agent, - selected_tool=None, - tool_use={"input": {"x": 5}, "name": "missing_tool", "toolUseId": "test"}, - invocation_state=ANY, - result={"content": [{"text": "Unknown tool: missing_tool"}], "status": "error", "toolUseId": "test"}, - exception=None, - ) - - -@pytest.mark.asyncio -async def test_run_tool_hook_after_tool_invocation_on_exception(agent, tool_registry, hook_provider, alist): - """Test that AfterToolInvocation hook is invoked even when tool throws exception.""" - error = ValueError("Tool failed") - - failing_tool = MagicMock() - failing_tool.tool_name = "failing_tool" - - failing_tool.stream.side_effect = error - - tool_registry.register_tool(failing_tool) - - process = run_tool( - agent=agent, - tool_use={"toolUseId": "test", "name": "failing_tool", "input": {"x": 5}}, - invocation_state={}, 
- ) - await alist(process) - - assert hook_provider.events_received[1] == AfterToolInvocationEvent( - agent=agent, - selected_tool=failing_tool, - tool_use={"input": {"x": 5}, "name": "failing_tool", "toolUseId": "test"}, - invocation_state=ANY, - result={"content": [{"text": "Error: Tool failed"}], "status": "error", "toolUseId": "test"}, - exception=error, - ) - - -@pytest.mark.asyncio -async def test_run_tool_hook_before_tool_invocation_updates(agent, tool_times_5, hook_registry, hook_provider, alist): - """Test that modifying properties on BeforeToolInvocation takes effect.""" - - updated_tool_use = {"toolUseId": "modified", "name": "replacement_tool", "input": {"x": 3}} - - def modify_hook(event: BeforeToolInvocationEvent): - # Modify selected_tool to use replacement_tool - event.selected_tool = tool_times_5 - # Modify tool_use to change toolUseId - event.tool_use = updated_tool_use - - hook_registry.add_callback(BeforeToolInvocationEvent, modify_hook) - - process = run_tool( - agent=agent, - tool_use={"toolUseId": "original", "name": "original_tool", "input": {"x": 1}}, - invocation_state={}, - ) - result = (await alist(process))[-1] - - # Should use replacement_tool (5 * 3 = 15) instead of original_tool (1 * 2 = 2) - assert result == {"toolUseId": "modified", "status": "success", "content": [{"text": "15"}]} - - assert hook_provider.events_received[1] == AfterToolInvocationEvent( - agent=agent, - selected_tool=tool_times_5, - tool_use=updated_tool_use, - invocation_state=ANY, - result={"content": [{"text": "15"}], "status": "success", "toolUseId": "modified"}, - exception=None, - ) - - -@pytest.mark.asyncio -async def test_run_tool_hook_after_tool_invocation_updates(agent, tool_times_2, hook_registry, alist): - """Test that modifying properties on AfterToolInvocation takes effect.""" - - updated_result = {"toolUseId": "modified", "status": "success", "content": [{"text": "modified_result"}]} - - def modify_hook(event: AfterToolInvocationEvent): - # Modify 
result to change the output - event.result = updated_result - - hook_registry.add_callback(AfterToolInvocationEvent, modify_hook) - - process = run_tool( - agent=agent, - tool_use={"toolUseId": "test", "name": tool_times_2.tool_name, "input": {"x": 5}}, - invocation_state={}, - ) - - result = (await alist(process))[-1] - assert result == updated_result - - -@pytest.mark.asyncio -async def test_run_tool_hook_after_tool_invocation_updates_with_missing_tool(agent, hook_registry, alist): - """Test that modifying properties on AfterToolInvocation takes effect.""" - - updated_result = {"toolUseId": "modified", "status": "success", "content": [{"text": "modified_result"}]} - - def modify_hook(event: AfterToolInvocationEvent): - # Modify result to change the output - event.result = updated_result - - hook_registry.add_callback(AfterToolInvocationEvent, modify_hook) - - process = run_tool( - agent=agent, - tool_use={"toolUseId": "test", "name": "missing_tool", "input": {"x": 5}}, - invocation_state={}, - ) - - result = (await alist(process))[-1] - assert result == updated_result - - -@pytest.mark.asyncio -async def test_run_tool_hook_update_result_with_missing_tool(agent, tool_registry, hook_registry, alist): - """Test that modifying properties on AfterToolInvocation takes effect.""" - - @strands.tool - def test_quota(): - return "9" - - tool_registry.register_tool(test_quota) - - class ExampleProvider(HookProvider): - def register_hooks(self, registry: "HookRegistry") -> None: - registry.add_callback(BeforeToolInvocationEvent, self.before_tool_call) - registry.add_callback(AfterToolInvocationEvent, self.after_tool_call) - - def before_tool_call(self, event: BeforeToolInvocationEvent): - if event.tool_use.get("name") == "test_quota": - event.selected_tool = None - - def after_tool_call(self, event: AfterToolInvocationEvent): - if event.tool_use.get("name") == "test_quota": - event.result = { - "status": "error", - "toolUseId": "test", - "content": [{"text": "This tool has 
been used too many times!"}], - } - - hook_registry.add_hook(ExampleProvider()) - - with patch.object(strands.event_loop.event_loop, "logger") as mock_logger: - process = run_tool( - agent=agent, - tool_use={"toolUseId": "test", "name": "test_quota", "input": {"x": 5}}, - invocation_state={}, - ) - - result = (await alist(process))[-1] - - assert result == { - "status": "error", - "toolUseId": "test", - "content": [{"text": "This tool has been used too many times!"}], - } - - assert mock_logger.debug.call_args_list == [ - call("tool_use=<%s> | streaming", {"toolUseId": "test", "name": "test_quota", "input": {"x": 5}}), - call( - "tool_name=<%s>, tool_use_id=<%s> | a hook resulted in a non-existing tool call", - "test_quota", - "test", - ), - ] - - @pytest.mark.asyncio async def test_event_loop_cycle_exception_model_hooks(mock_time, agent, model, agenerator, alist, hook_provider): """Test that model hooks are correctly emitted even when throttled.""" diff --git a/tests/strands/tools/executors/conftest.py b/tests/strands/tools/executors/conftest.py new file mode 100644 index 000000000..1576b7578 --- /dev/null +++ b/tests/strands/tools/executors/conftest.py @@ -0,0 +1,116 @@ +import threading +import unittest.mock + +import pytest + +import strands +from strands.experimental.hooks import AfterToolInvocationEvent, BeforeToolInvocationEvent +from strands.hooks import HookRegistry +from strands.tools.registry import ToolRegistry + + +@pytest.fixture +def hook_events(): + return [] + + +@pytest.fixture +def tool_hook(hook_events): + def callback(event): + hook_events.append(event) + return event + + return callback + + +@pytest.fixture +def hook_registry(tool_hook): + registry = HookRegistry() + registry.add_callback(BeforeToolInvocationEvent, tool_hook) + registry.add_callback(AfterToolInvocationEvent, tool_hook) + return registry + + +@pytest.fixture +def tool_events(): + return [] + + +@pytest.fixture +def weather_tool(): + @strands.tool(name="weather_tool") + def 
func(): + return "sunny" + + return func + + +@pytest.fixture +def temperature_tool(): + @strands.tool(name="temperature_tool") + def func(): + return "75F" + + return func + + +@pytest.fixture +def exception_tool(): + @strands.tool(name="exception_tool") + def func(): + pass + + async def mock_stream(_tool_use, _invocation_state): + raise RuntimeError("Tool error") + yield # make generator + + func.stream = mock_stream + return func + + +@pytest.fixture +def thread_tool(tool_events): + @strands.tool(name="thread_tool") + def func(): + tool_events.append({"thread_name": threading.current_thread().name}) + return "threaded" + + return func + + +@pytest.fixture +def tool_registry(weather_tool, temperature_tool, exception_tool, thread_tool): + registry = ToolRegistry() + registry.register_tool(weather_tool) + registry.register_tool(temperature_tool) + registry.register_tool(exception_tool) + registry.register_tool(thread_tool) + return registry + + +@pytest.fixture +def agent(tool_registry, hook_registry): + mock_agent = unittest.mock.Mock() + mock_agent.tool_registry = tool_registry + mock_agent.hooks = hook_registry + return mock_agent + + +@pytest.fixture +def tool_results(): + return [] + + +@pytest.fixture +def cycle_trace(): + return unittest.mock.Mock() + + +@pytest.fixture +def cycle_span(): + return unittest.mock.Mock() + + +@pytest.fixture +def invocation_state(): + return {} diff --git a/tests/strands/tools/executors/test_concurrent.py b/tests/strands/tools/executors/test_concurrent.py new file mode 100644 index 000000000..7e0d6c2df --- /dev/null +++ b/tests/strands/tools/executors/test_concurrent.py @@ -0,0 +1,32 @@ +import pytest + +from strands.tools.executors import ConcurrentToolExecutor + + +@pytest.fixture +def executor(): + return ConcurrentToolExecutor() + + +@pytest.mark.asyncio +async def test_concurrent_executor_execute( + executor, agent, tool_results, cycle_trace, cycle_span, invocation_state, alist +): + tool_uses = [ + {"name": 
"weather_tool", "toolUseId": "1", "input": {}}, + {"name": "temperature_tool", "toolUseId": "2", "input": {}}, + ] + stream = executor._execute(agent, tool_uses, tool_results, cycle_trace, cycle_span, invocation_state) + + tru_events = sorted(await alist(stream), key=lambda event: event.get("toolUseId")) + exp_events = [ + {"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]}, + {"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]}, + {"toolUseId": "2", "status": "success", "content": [{"text": "75F"}]}, + {"toolUseId": "2", "status": "success", "content": [{"text": "75F"}]}, + ] + assert tru_events == exp_events + + tru_results = sorted(tool_results, key=lambda result: result.get("toolUseId")) + exp_results = [exp_events[1], exp_events[3]] + assert tru_results == exp_results diff --git a/tests/strands/tools/executors/test_executor.py b/tests/strands/tools/executors/test_executor.py new file mode 100644 index 000000000..edbad3939 --- /dev/null +++ b/tests/strands/tools/executors/test_executor.py @@ -0,0 +1,144 @@ +import unittest.mock + +import pytest + +import strands +from strands.experimental.hooks import AfterToolInvocationEvent, BeforeToolInvocationEvent +from strands.telemetry.metrics import Trace +from strands.tools.executors._executor import ToolExecutor + + +@pytest.fixture +def executor_cls(): + class ClsExecutor(ToolExecutor): + def _execute(self, _agent, _tool_uses, _tool_results, _invocation_state): + raise NotImplementedError + + return ClsExecutor + + +@pytest.fixture +def executor(executor_cls): + return executor_cls() + + +@pytest.fixture +def tracer(): + with unittest.mock.patch.object(strands.tools.executors._executor, "get_tracer") as mock_get_tracer: + yield mock_get_tracer.return_value + + +@pytest.mark.asyncio +async def test_executor_stream_yields_result( + executor, agent, tool_results, invocation_state, hook_events, weather_tool, alist +): + tool_use = {"name": "weather_tool", "toolUseId": "1", "input": 
{}} + stream = executor._stream(agent, tool_use, tool_results, invocation_state) + + tru_events = await alist(stream) + exp_events = [ + {"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]}, + {"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]}, + ] + assert tru_events == exp_events + + tru_results = tool_results + exp_results = [exp_events[-1]] + assert tru_results == exp_results + + tru_hook_events = hook_events + exp_hook_events = [ + BeforeToolInvocationEvent( + agent=agent, + selected_tool=weather_tool, + tool_use=tool_use, + invocation_state=invocation_state, + ), + AfterToolInvocationEvent( + agent=agent, + selected_tool=weather_tool, + tool_use=tool_use, + invocation_state=invocation_state, + result=exp_results[0], + ), + ] + assert tru_hook_events == exp_hook_events + + +@pytest.mark.asyncio +async def test_executor_stream_yields_tool_error( + executor, agent, tool_results, invocation_state, hook_events, exception_tool, alist +): + tool_use = {"name": "exception_tool", "toolUseId": "1", "input": {}} + stream = executor._stream(agent, tool_use, tool_results, invocation_state) + + tru_events = await alist(stream) + exp_events = [{"toolUseId": "1", "status": "error", "content": [{"text": "Error: Tool error"}]}] + assert tru_events == exp_events + + tru_results = tool_results + exp_results = [exp_events[-1]] + assert tru_results == exp_results + + tru_hook_after_event = hook_events[-1] + exp_hook_after_event = AfterToolInvocationEvent( + agent=agent, + selected_tool=exception_tool, + tool_use=tool_use, + invocation_state=invocation_state, + result=exp_results[0], + exception=unittest.mock.ANY, + ) + assert tru_hook_after_event == exp_hook_after_event + + +@pytest.mark.asyncio +async def test_executor_stream_yields_unknown_tool(executor, agent, tool_results, invocation_state, hook_events, alist): + tool_use = {"name": "unknown_tool", "toolUseId": "1", "input": {}} + stream = executor._stream(agent, tool_use, tool_results, 
invocation_state) + + tru_events = await alist(stream) + exp_events = [{"toolUseId": "1", "status": "error", "content": [{"text": "Unknown tool: unknown_tool"}]}] + assert tru_events == exp_events + + tru_results = tool_results + exp_results = [exp_events[-1]] + assert tru_results == exp_results + + tru_hook_after_event = hook_events[-1] + exp_hook_after_event = AfterToolInvocationEvent( + agent=agent, + selected_tool=None, + tool_use=tool_use, + invocation_state=invocation_state, + result=exp_results[0], + ) + assert tru_hook_after_event == exp_hook_after_event + + +@pytest.mark.asyncio +async def test_executor_stream_with_trace( + executor, tracer, agent, tool_results, cycle_trace, cycle_span, invocation_state, alist +): + tool_use = {"name": "weather_tool", "toolUseId": "1", "input": {}} + stream = executor._stream_with_trace(agent, tool_use, tool_results, cycle_trace, cycle_span, invocation_state) + + tru_events = await alist(stream) + exp_events = [ + {"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]}, + {"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]}, + ] + assert tru_events == exp_events + + tru_results = tool_results + exp_results = [exp_events[-1]] + assert tru_results == exp_results + + tracer.start_tool_call_span.assert_called_once_with(tool_use, cycle_span) + tracer.end_tool_call_span.assert_called_once_with( + tracer.start_tool_call_span.return_value, + {"content": [{"text": "sunny"}], "status": "success", "toolUseId": "1"}, + ) + + cycle_trace.add_child.assert_called_once() + assert isinstance(cycle_trace.add_child.call_args[0][0], Trace) diff --git a/tests/strands/tools/executors/test_sequential.py b/tests/strands/tools/executors/test_sequential.py new file mode 100644 index 000000000..d9b32c129 --- /dev/null +++ b/tests/strands/tools/executors/test_sequential.py @@ -0,0 +1,32 @@ +import pytest + +from strands.tools.executors import SequentialToolExecutor + + +@pytest.fixture +def executor(): + return 
SequentialToolExecutor() + + +@pytest.mark.asyncio +async def test_sequential_executor_execute( + executor, agent, tool_results, cycle_trace, cycle_span, invocation_state, alist +): + tool_uses = [ + {"name": "weather_tool", "toolUseId": "1", "input": {}}, + {"name": "temperature_tool", "toolUseId": "2", "input": {}}, + ] + stream = executor._execute(agent, tool_uses, tool_results, cycle_trace, cycle_span, invocation_state) + + tru_events = await alist(stream) + exp_events = [ + {"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]}, + {"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]}, + {"toolUseId": "2", "status": "success", "content": [{"text": "75F"}]}, + {"toolUseId": "2", "status": "success", "content": [{"text": "75F"}]}, + ] + assert tru_events == exp_events + + tru_results = tool_results + exp_results = [exp_events[1], exp_events[2]] + assert tru_results == exp_results diff --git a/tests/strands/tools/test_executor.py b/tests/strands/tools/test_executor.py deleted file mode 100644 index 04d4ea657..000000000 --- a/tests/strands/tools/test_executor.py +++ /dev/null @@ -1,440 +0,0 @@ -import unittest.mock -import uuid - -import pytest - -import strands -import strands.telemetry -from strands.types.content import Message - - -@pytest.fixture(autouse=True) -def moto_autouse(moto_env): - _ = moto_env - - -@pytest.fixture -def tool_handler(request): - async def handler(tool_use): - yield {"event": "abc"} - yield { - **params, - "toolUseId": tool_use["toolUseId"], - } - - params = { - "content": [{"text": "test result"}], - "status": "success", - } - if hasattr(request, "param"): - params.update(request.param) - - return handler - - -@pytest.fixture -def tool_use(): - return {"toolUseId": "t1", "name": "test_tool", "input": {"key": "value"}} - - -@pytest.fixture -def tool_uses(request, tool_use): - return request.param if hasattr(request, "param") else [tool_use] - - -@pytest.fixture -def mock_metrics_client(): - with 
unittest.mock.patch("strands.telemetry.MetricsClient") as mock_metrics_client: - yield mock_metrics_client - - -@pytest.fixture -def event_loop_metrics(): - return strands.telemetry.metrics.EventLoopMetrics() - - -@pytest.fixture -def invalid_tool_use_ids(request): - return request.param if hasattr(request, "param") else [] - - -@pytest.fixture -def cycle_trace(): - with unittest.mock.patch.object(uuid, "uuid4", return_value="trace1"): - return strands.telemetry.metrics.Trace(name="test trace", raw_name="raw_name") - - -@pytest.mark.asyncio -async def test_run_tools( - tool_handler, - tool_uses, - event_loop_metrics, - invalid_tool_use_ids, - cycle_trace, - alist, -): - tool_results = [] - - stream = strands.tools.executor.run_tools( - tool_handler, - tool_uses, - event_loop_metrics, - invalid_tool_use_ids, - tool_results, - cycle_trace, - ) - - tru_events = await alist(stream) - exp_events = [ - {"event": "abc"}, - { - "content": [ - { - "text": "test result", - }, - ], - "status": "success", - "toolUseId": "t1", - }, - ] - - tru_results = tool_results - exp_results = [exp_events[-1]] - - assert tru_events == exp_events and tru_results == exp_results - - -@pytest.mark.parametrize("invalid_tool_use_ids", [["t1"]], indirect=True) -@pytest.mark.asyncio -async def test_run_tools_invalid_tool( - tool_handler, - tool_uses, - event_loop_metrics, - invalid_tool_use_ids, - cycle_trace, - alist, -): - tool_results = [] - - stream = strands.tools.executor.run_tools( - tool_handler, - tool_uses, - event_loop_metrics, - invalid_tool_use_ids, - tool_results, - cycle_trace, - ) - await alist(stream) - - tru_results = tool_results - exp_results = [] - - assert tru_results == exp_results - - -@pytest.mark.parametrize("tool_handler", [{"status": "failed"}], indirect=True) -@pytest.mark.asyncio -async def test_run_tools_failed_tool( - tool_handler, - tool_uses, - event_loop_metrics, - invalid_tool_use_ids, - cycle_trace, - alist, -): - tool_results = [] - - stream = 
strands.tools.executor.run_tools( - tool_handler, - tool_uses, - event_loop_metrics, - invalid_tool_use_ids, - tool_results, - cycle_trace, - ) - await alist(stream) - - tru_results = tool_results - exp_results = [ - { - "content": [ - { - "text": "test result", - }, - ], - "status": "failed", - "toolUseId": "t1", - }, - ] - - assert tru_results == exp_results - - -@pytest.mark.parametrize( - ("tool_uses", "invalid_tool_use_ids"), - [ - ( - [ - { - "toolUseId": "t1", - "name": "test_tool_success", - "input": {"key": "value1"}, - }, - { - "toolUseId": "t2", - "name": "test_tool_invalid", - "input": {"key": "value2"}, - }, - ], - ["t2"], - ), - ], - indirect=True, -) -@pytest.mark.asyncio -async def test_run_tools_sequential( - tool_handler, - tool_uses, - event_loop_metrics, - invalid_tool_use_ids, - cycle_trace, - alist, -): - tool_results = [] - - stream = strands.tools.executor.run_tools( - tool_handler, - tool_uses, - event_loop_metrics, - invalid_tool_use_ids, - tool_results, - cycle_trace, - None, # tool_pool - ) - await alist(stream) - - tru_results = tool_results - exp_results = [ - { - "content": [ - { - "text": "test result", - }, - ], - "status": "success", - "toolUseId": "t1", - }, - ] - - assert tru_results == exp_results - - -def test_validate_and_prepare_tools(): - message: Message = { - "role": "assistant", - "content": [ - {"text": "value"}, - {"toolUse": {"toolUseId": "t1", "name": "test_tool", "input": {"key": "value"}}}, - {"toolUse": {"toolUseId": "t2-invalid"}}, - ], - } - - tool_uses = [] - tool_results = [] - invalid_tool_use_ids = [] - - strands.tools.executor.validate_and_prepare_tools(message, tool_uses, tool_results, invalid_tool_use_ids) - - tru_tool_uses, tru_tool_results, tru_invalid_tool_use_ids = tool_uses, tool_results, invalid_tool_use_ids - exp_tool_uses = [ - { - "input": { - "key": "value", - }, - "name": "test_tool", - "toolUseId": "t1", - }, - { - "name": "INVALID_TOOL_NAME", - "toolUseId": "t2-invalid", - }, - ] - 
exp_tool_results = [ - { - "content": [ - { - "text": "Error: tool name missing", - }, - ], - "status": "error", - "toolUseId": "t2-invalid", - }, - ] - exp_invalid_tool_use_ids = ["t2-invalid"] - - assert tru_tool_uses == exp_tool_uses - assert tru_tool_results == exp_tool_results - assert tru_invalid_tool_use_ids == exp_invalid_tool_use_ids - - -@unittest.mock.patch("strands.tools.executor.get_tracer") -@pytest.mark.asyncio -async def test_run_tools_creates_and_ends_span_on_success( - mock_get_tracer, - tool_handler, - tool_uses, - mock_metrics_client, - event_loop_metrics, - invalid_tool_use_ids, - cycle_trace, - alist, -): - """Test that run_tools creates and ends a span on successful execution.""" - # Setup mock tracer and span - mock_tracer = unittest.mock.MagicMock() - mock_span = unittest.mock.MagicMock() - mock_tracer.start_tool_call_span.return_value = mock_span - mock_get_tracer.return_value = mock_tracer - - # Setup mock parent span - parent_span = unittest.mock.MagicMock() - - tool_results = [] - - # Run the tool - stream = strands.tools.executor.run_tools( - tool_handler, - tool_uses, - event_loop_metrics, - invalid_tool_use_ids, - tool_results, - cycle_trace, - parent_span, - ) - await alist(stream) - - # Verify span was created with the parent span - mock_tracer.start_tool_call_span.assert_called_once_with(tool_uses[0], parent_span) - - # Verify span was ended with the tool result - mock_tracer.end_tool_call_span.assert_called_once() - args, _ = mock_tracer.end_tool_call_span.call_args - assert args[0] == mock_span - assert args[1]["status"] == "success" - assert args[1]["content"][0]["text"] == "test result" - - -@unittest.mock.patch("strands.tools.executor.get_tracer") -@pytest.mark.parametrize("tool_handler", [{"status": "failed"}], indirect=True) -@pytest.mark.asyncio -async def test_run_tools_creates_and_ends_span_on_failure( - mock_get_tracer, - tool_handler, - tool_uses, - event_loop_metrics, - invalid_tool_use_ids, - cycle_trace, - alist, 
-): - """Test that run_tools creates and ends a span on tool failure.""" - # Setup mock tracer and span - mock_tracer = unittest.mock.MagicMock() - mock_span = unittest.mock.MagicMock() - mock_tracer.start_tool_call_span.return_value = mock_span - mock_get_tracer.return_value = mock_tracer - - # Setup mock parent span - parent_span = unittest.mock.MagicMock() - - tool_results = [] - - # Run the tool - stream = strands.tools.executor.run_tools( - tool_handler, - tool_uses, - event_loop_metrics, - invalid_tool_use_ids, - tool_results, - cycle_trace, - parent_span, - ) - await alist(stream) - - # Verify span was created with the parent span - mock_tracer.start_tool_call_span.assert_called_once_with(tool_uses[0], parent_span) - - # Verify span was ended with the tool result - mock_tracer.end_tool_call_span.assert_called_once() - args, _ = mock_tracer.end_tool_call_span.call_args - assert args[0] == mock_span - assert args[1]["status"] == "failed" - - -@unittest.mock.patch("strands.tools.executor.get_tracer") -@pytest.mark.parametrize( - ("tool_uses", "invalid_tool_use_ids"), - [ - ( - [ - { - "toolUseId": "t1", - "name": "test_tool_success", - "input": {"key": "value1"}, - }, - { - "toolUseId": "t2", - "name": "test_tool_also_success", - "input": {"key": "value2"}, - }, - ], - [], - ), - ], - indirect=True, -) -@pytest.mark.asyncio -async def test_run_tools_concurrent_execution_with_spans( - mock_get_tracer, - tool_handler, - tool_uses, - event_loop_metrics, - invalid_tool_use_ids, - cycle_trace, - alist, -): - # Setup mock tracer and spans - mock_tracer = unittest.mock.MagicMock() - mock_span1 = unittest.mock.MagicMock() - mock_span2 = unittest.mock.MagicMock() - mock_tracer.start_tool_call_span.side_effect = [mock_span1, mock_span2] - mock_get_tracer.return_value = mock_tracer - - # Setup mock parent span - parent_span = unittest.mock.MagicMock() - - tool_results = [] - - # Run the tools - stream = strands.tools.executor.run_tools( - tool_handler, - tool_uses, - 
event_loop_metrics, - invalid_tool_use_ids, - tool_results, - cycle_trace, - parent_span, - ) - await alist(stream) - - # Verify spans were created for both tools - assert mock_tracer.start_tool_call_span.call_count == 2 - mock_tracer.start_tool_call_span.assert_has_calls( - [ - unittest.mock.call(tool_uses[0], parent_span), - unittest.mock.call(tool_uses[1], parent_span), - ], - any_order=True, - ) - - # Verify spans were ended for both tools - assert mock_tracer.end_tool_call_span.call_count == 2 diff --git a/tests/strands/tools/test_validator.py b/tests/strands/tools/test_validator.py new file mode 100644 index 000000000..46e5e15f3 --- /dev/null +++ b/tests/strands/tools/test_validator.py @@ -0,0 +1,50 @@ +from strands.tools import _validator +from strands.types.content import Message + + +def test_validate_and_prepare_tools(): + message: Message = { + "role": "assistant", + "content": [ + {"text": "value"}, + {"toolUse": {"toolUseId": "t1", "name": "test_tool", "input": {"key": "value"}}}, + {"toolUse": {"toolUseId": "t2-invalid"}}, + ], + } + + tool_uses = [] + tool_results = [] + invalid_tool_use_ids = [] + + _validator.validate_and_prepare_tools(message, tool_uses, tool_results, invalid_tool_use_ids) + + tru_tool_uses, tru_tool_results, tru_invalid_tool_use_ids = tool_uses, tool_results, invalid_tool_use_ids + exp_tool_uses = [ + { + "input": { + "key": "value", + }, + "name": "test_tool", + "toolUseId": "t1", + }, + { + "name": "INVALID_TOOL_NAME", + "toolUseId": "t2-invalid", + }, + ] + exp_tool_results = [ + { + "content": [ + { + "text": "Error: tool name missing", + }, + ], + "status": "error", + "toolUseId": "t2-invalid", + }, + ] + exp_invalid_tool_use_ids = ["t2-invalid"] + + assert tru_tool_uses == exp_tool_uses + assert tru_tool_results == exp_tool_results + assert tru_invalid_tool_use_ids == exp_invalid_tool_use_ids diff --git a/tests_integ/tools/executors/test_concurrent.py b/tests_integ/tools/executors/test_concurrent.py new file mode 100644 
index 000000000..27dd468e0 --- /dev/null +++ b/tests_integ/tools/executors/test_concurrent.py @@ -0,0 +1,61 @@ +import asyncio + +import pytest + +import strands +from strands import Agent +from strands.tools.executors import ConcurrentToolExecutor + + +@pytest.fixture +def tool_executor(): + return ConcurrentToolExecutor() + + +@pytest.fixture +def tool_events(): + return [] + + +@pytest.fixture +def time_tool(tool_events): + @strands.tool(name="time_tool") + async def func(): + tool_events.append({"name": "time_tool", "event": "start"}) + await asyncio.sleep(2) + tool_events.append({"name": "time_tool", "event": "end"}) + return "12:00" + + return func + + +@pytest.fixture +def weather_tool(tool_events): + @strands.tool(name="weather_tool") + async def func(): + tool_events.append({"name": "weather_tool", "event": "start"}) + await asyncio.sleep(1) + tool_events.append({"name": "weather_tool", "event": "end"}) + + return "sunny" + + return func + + +@pytest.fixture +def agent(tool_executor, time_tool, weather_tool): + return Agent(tools=[time_tool, weather_tool], tool_executor=tool_executor) + + +@pytest.mark.asyncio +async def test_agent_invoke_async_tool_executor(agent, tool_events): + await agent.invoke_async("What is the time and weather in New York?") + + tru_events = tool_events + exp_events = [ + {"name": "time_tool", "event": "start"}, + {"name": "weather_tool", "event": "start"}, + {"name": "weather_tool", "event": "end"}, + {"name": "time_tool", "event": "end"}, + ] + assert tru_events == exp_events diff --git a/tests_integ/tools/executors/test_sequential.py b/tests_integ/tools/executors/test_sequential.py new file mode 100644 index 000000000..82fc51a59 --- /dev/null +++ b/tests_integ/tools/executors/test_sequential.py @@ -0,0 +1,61 @@ +import asyncio + +import pytest + +import strands +from strands import Agent +from strands.tools.executors import SequentialToolExecutor + + +@pytest.fixture +def tool_executor(): + return SequentialToolExecutor() + + 
+@pytest.fixture +def tool_events(): + return [] + + +@pytest.fixture +def time_tool(tool_events): + @strands.tool(name="time_tool") + async def func(): + tool_events.append({"name": "time_tool", "event": "start"}) + await asyncio.sleep(2) + tool_events.append({"name": "time_tool", "event": "end"}) + return "12:00" + + return func + + +@pytest.fixture +def weather_tool(tool_events): + @strands.tool(name="weather_tool") + async def func(): + tool_events.append({"name": "weather_tool", "event": "start"}) + await asyncio.sleep(1) + tool_events.append({"name": "weather_tool", "event": "end"}) + + return "sunny" + + return func + + +@pytest.fixture +def agent(tool_executor, time_tool, weather_tool): + return Agent(tools=[time_tool, weather_tool], tool_executor=tool_executor) + + +@pytest.mark.asyncio +async def test_agent_invoke_async_tool_executor(agent, tool_events): + await agent.invoke_async("What is the time and weather in New York?") + + tru_events = tool_events + exp_events = [ + {"name": "time_tool", "event": "start"}, + {"name": "time_tool", "event": "end"}, + {"name": "weather_tool", "event": "start"}, + {"name": "weather_tool", "event": "end"}, + ] + assert tru_events == exp_events