Skip to content

Commit 606f657

Browse files
feat: expose tool_use and agent through ToolContext to decorated tools (#557)
1 parent 1c7257b commit 606f657

File tree

5 files changed

+312
-15
lines changed

5 files changed

+312
-15
lines changed

src/strands/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,5 +3,6 @@
33
from . import agent, models, telemetry, types
44
from .agent.agent import Agent
55
from .tools.decorator import tool
6+
from .types.tools import ToolContext
67

7-
__all__ = ["Agent", "agent", "models", "tool", "types", "telemetry"]
8+
__all__ = ["Agent", "agent", "models", "tool", "types", "telemetry", "ToolContext"]

src/strands/tools/decorator.py

Lines changed: 70 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def my_tool(param1: str, param2: int = 42) -> dict:
6161
from pydantic import BaseModel, Field, create_model
6262
from typing_extensions import override
6363

64-
from ..types.tools import AgentTool, JSONSchema, ToolGenerator, ToolSpec, ToolUse
64+
from ..types.tools import AgentTool, JSONSchema, ToolContext, ToolGenerator, ToolSpec, ToolUse
6565

6666
logger = logging.getLogger(__name__)
6767

@@ -84,16 +84,18 @@ class FunctionToolMetadata:
8484
validate tool usage.
8585
"""
8686

87-
def __init__(self, func: Callable[..., Any]) -> None:
87+
def __init__(self, func: Callable[..., Any], context_param: str | None = None) -> None:
8888
"""Initialize with the function to process.
8989
9090
Args:
9191
func: The function to extract metadata from.
9292
Can be a standalone function or a class method.
93+
context_param: Name of the context parameter to inject, if any.
9394
"""
9495
self.func = func
9596
self.signature = inspect.signature(func)
9697
self.type_hints = get_type_hints(func)
98+
self._context_param = context_param
9799

98100
# Parse the docstring with docstring_parser
99101
doc_str = inspect.getdoc(func) or ""
@@ -113,16 +115,16 @@ def _create_input_model(self) -> Type[BaseModel]:
113115
This method analyzes the function's signature, type hints, and docstring to create a Pydantic model that can
114116
validate input data before passing it to the function.
115117
116-
Special parameters like 'self', 'cls', and 'agent' are excluded from the model.
118+
Special parameters that can be automatically injected are excluded from the model.
117119
118120
Returns:
119121
A Pydantic BaseModel class customized for the function's parameters.
120122
"""
121123
field_definitions: dict[str, Any] = {}
122124

123125
for name, param in self.signature.parameters.items():
124-
# Skip special parameters
125-
if name in ("self", "cls", "agent"):
126+
# Skip parameters that will be automatically injected
127+
if self._is_special_parameter(name):
126128
continue
127129

128130
# Get parameter type and default
@@ -252,6 +254,49 @@ def validate_input(self, input_data: dict[str, Any]) -> dict[str, Any]:
252254
error_msg = str(e)
253255
raise ValueError(f"Validation failed for input parameters: {error_msg}") from e
254256

257+
def inject_special_parameters(
258+
self, validated_input: dict[str, Any], tool_use: ToolUse, invocation_state: dict[str, Any]
259+
) -> None:
260+
"""Inject special framework-provided parameters into the validated input.
261+
262+
This method automatically provides framework-level context to tools that request it
263+
through their function signature.
264+
265+
Args:
266+
validated_input: The validated input parameters (modified in place).
267+
tool_use: The tool use request containing tool invocation details.
268+
invocation_state: Context for the tool invocation, including agent state.
269+
"""
270+
if self._context_param and self._context_param in self.signature.parameters:
271+
tool_context = ToolContext(tool_use=tool_use, agent=invocation_state["agent"])
272+
validated_input[self._context_param] = tool_context
273+
274+
# Inject agent if requested (backward compatibility)
275+
if "agent" in self.signature.parameters and "agent" in invocation_state:
276+
validated_input["agent"] = invocation_state["agent"]
277+
278+
def _is_special_parameter(self, param_name: str) -> bool:
279+
"""Check if a parameter should be automatically injected by the framework or is a standard Python method param.
280+
281+
Special parameters include:
282+
- Standard Python method parameters: self, cls
283+
- Framework-provided context parameters: agent, and configurable context parameter (defaults to tool_context)
284+
285+
Args:
286+
param_name: The name of the parameter to check.
287+
288+
Returns:
289+
True if the parameter should be excluded from input validation and
290+
handled specially during tool execution.
291+
"""
292+
special_params = {"self", "cls", "agent"}
293+
294+
# Add context parameter if configured
295+
if self._context_param:
296+
special_params.add(self._context_param)
297+
298+
return param_name in special_params
299+
255300

256301
P = ParamSpec("P") # Captures all parameters
257302
R = TypeVar("R") # Return type
@@ -402,9 +447,8 @@ async def stream(self, tool_use: ToolUse, invocation_state: dict[str, Any], **kw
402447
# Validate input against the Pydantic model
403448
validated_input = self._metadata.validate_input(tool_input)
404449

405-
# Pass along the agent if provided and expected by the function
406-
if "agent" in invocation_state and "agent" in self._metadata.signature.parameters:
407-
validated_input["agent"] = invocation_state.get("agent")
450+
# Inject special framework-provided parameters
451+
self._metadata.inject_special_parameters(validated_input, tool_use, invocation_state)
408452

409453
# "Too few arguments" expected, hence the type ignore
410454
if inspect.iscoroutinefunction(self._tool_func):
@@ -474,6 +518,7 @@ def tool(
474518
description: Optional[str] = None,
475519
inputSchema: Optional[JSONSchema] = None,
476520
name: Optional[str] = None,
521+
context: bool | str = False,
477522
) -> Callable[[Callable[P, R]], DecoratedFunctionTool[P, R]]: ...
478523
# Suppressing the type error because we want callers to be able to use both `tool` and `tool()` at the
479524
# call site, but the actual implementation handles that and it's not representable via the type-system
@@ -482,6 +527,7 @@ def tool( # type: ignore
482527
description: Optional[str] = None,
483528
inputSchema: Optional[JSONSchema] = None,
484529
name: Optional[str] = None,
530+
context: bool | str = False,
485531
) -> Union[DecoratedFunctionTool[P, R], Callable[[Callable[P, R]], DecoratedFunctionTool[P, R]]]:
486532
"""Decorator that transforms a Python function into a Strands tool.
487533
@@ -507,6 +553,9 @@ def tool( # type: ignore
507553
description: Optional custom description to override the function's docstring.
508554
inputSchema: Optional custom JSON schema to override the automatically generated schema.
509555
name: Optional custom name to override the function's name.
556+
context: When provided, places an object in the designated parameter. If True, the param name
557+
defaults to 'tool_context', or if an override is needed, set context equal to a string to designate
558+
the param name.
510559
511560
Returns:
512561
An AgentTool that also mimics the original function when invoked
@@ -536,15 +585,24 @@ def my_tool(name: str, count: int = 1) -> str:
536585
537586
Example with parameters:
538587
```python
539-
@tool(name="custom_tool", description="A tool with a custom name and description")
540-
def my_tool(name: str, count: int = 1) -> str:
541-
return f"Processed {name} {count} times"
588+
@tool(name="custom_tool", description="A tool with a custom name and description", context=True)
589+
def my_tool(name: str, count: int = 1, tool_context: ToolContext) -> str:
590+
tool_id = tool_context["tool_use"]["toolUseId"]
591+
return f"Processed {name} {count} times with tool ID {tool_id}"
542592
```
543593
"""
544594

545595
def decorator(f: T) -> "DecoratedFunctionTool[P, R]":
596+
# Resolve context parameter name
597+
if isinstance(context, bool):
598+
context_param = "tool_context" if context else None
599+
else:
600+
context_param = context.strip()
601+
if not context_param:
602+
raise ValueError("Context parameter name cannot be empty")
603+
546604
# Create function tool metadata
547-
tool_meta = FunctionToolMetadata(f)
605+
tool_meta = FunctionToolMetadata(f, context_param)
548606
tool_spec = tool_meta.extract_metadata()
549607
if name is not None:
550608
tool_spec["name"] = name

src/strands/types/tools.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,16 @@
66
"""
77

88
from abc import ABC, abstractmethod
9-
from typing import Any, AsyncGenerator, Awaitable, Callable, Literal, Protocol, Union
9+
from dataclasses import dataclass
10+
from typing import TYPE_CHECKING, Any, AsyncGenerator, Awaitable, Callable, Literal, Protocol, Union
1011

1112
from typing_extensions import TypedDict
1213

1314
from .media import DocumentContent, ImageContent
1415

16+
if TYPE_CHECKING:
17+
from .. import Agent
18+
1519
JSONSchema = dict
1620
"""Type alias for JSON Schema dictionaries."""
1721

@@ -117,6 +121,27 @@ class ToolChoiceTool(TypedDict):
117121
name: str
118122

119123

124+
@dataclass
125+
class ToolContext:
126+
"""Context object containing framework-provided data for decorated tools.
127+
128+
This object provides access to framework-level information that may be useful
129+
for tool implementations.
130+
131+
Attributes:
132+
tool_use: The complete ToolUse object containing tool invocation details.
133+
agent: The Agent instance executing this tool, providing access to conversation history,
134+
model configuration, and other agent state.
135+
136+
Note:
137+
This class is intended to be instantiated by the SDK. Direct construction by users
138+
is not supported and may break in future versions as new fields are added.
139+
"""
140+
141+
tool_use: ToolUse
142+
agent: "Agent"
143+
144+
120145
ToolChoice = Union[
121146
dict[Literal["auto"], ToolChoiceAuto],
122147
dict[Literal["any"], ToolChoiceAny],

tests/strands/tools/test_decorator.py

Lines changed: 158 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@
88
import pytest
99

1010
import strands
11-
from strands.types.tools import ToolUse
11+
from strands import Agent
12+
from strands.types.tools import AgentTool, ToolContext, ToolUse
1213

1314

1415
@pytest.fixture(scope="module")
@@ -1036,3 +1037,159 @@ def complex_schema_tool(union_param: Union[List[int], Dict[str, Any], str, None]
10361037
result = (await alist(stream))[-1]
10371038
assert result["status"] == "success"
10381039
assert "NoneType: None" in result["content"][0]["text"]
1040+
1041+
1042+
async def _run_context_injection_test(context_tool: AgentTool):
1043+
"""Common test logic for context injection tests."""
1044+
tool: AgentTool = context_tool
1045+
generator = tool.stream(
1046+
tool_use={
1047+
"toolUseId": "test-id",
1048+
"name": "context_tool",
1049+
"input": {
1050+
"message": "some_message" # note that we do not include agent nor tool context
1051+
},
1052+
},
1053+
invocation_state={
1054+
"agent": Agent(name="test_agent"),
1055+
},
1056+
)
1057+
tool_results = [value async for value in generator]
1058+
1059+
assert len(tool_results) == 1
1060+
tool_result = tool_results[0]
1061+
1062+
assert tool_result == {
1063+
"status": "success",
1064+
"content": [
1065+
{"text": "Tool 'context_tool' (ID: test-id)"},
1066+
{"text": "injected agent 'test_agent' processed: some_message"},
1067+
{"text": "context agent 'test_agent'"}
1068+
],
1069+
"toolUseId": "test-id",
1070+
}
1071+
1072+
1073+
@pytest.mark.asyncio
1074+
async def test_tool_context_injection_default():
1075+
"""Test that ToolContext is properly injected with default parameter name (tool_context)."""
1076+
1077+
@strands.tool(context=True)
1078+
def context_tool(message: str, agent: Agent, tool_context: ToolContext) -> dict:
1079+
"""Tool that uses ToolContext to access tool_use_id."""
1080+
tool_use_id = tool_context.tool_use["toolUseId"]
1081+
tool_name = tool_context.tool_use["name"]
1082+
agent_from_tool_context = tool_context.agent
1083+
1084+
return {
1085+
"status": "success",
1086+
"content": [
1087+
{"text": f"Tool '{tool_name}' (ID: {tool_use_id})"},
1088+
{"text": f"injected agent '{agent.name}' processed: {message}"},
1089+
{"text": f"context agent '{agent_from_tool_context.name}'"},
1090+
],
1091+
}
1092+
1093+
await _run_context_injection_test(context_tool)
1094+
1095+
1096+
@pytest.mark.asyncio
1097+
async def test_tool_context_injection_custom_name():
1098+
"""Test that ToolContext is properly injected with custom parameter name."""
1099+
1100+
@strands.tool(context="custom_context_name")
1101+
def context_tool(message: str, agent: Agent, custom_context_name: ToolContext) -> dict:
1102+
"""Tool that uses ToolContext to access tool_use_id."""
1103+
tool_use_id = custom_context_name.tool_use["toolUseId"]
1104+
tool_name = custom_context_name.tool_use["name"]
1105+
agent_from_tool_context = custom_context_name.agent
1106+
1107+
return {
1108+
"status": "success",
1109+
"content": [
1110+
{"text": f"Tool '{tool_name}' (ID: {tool_use_id})"},
1111+
{"text": f"injected agent '{agent.name}' processed: {message}"},
1112+
{"text": f"context agent '{agent_from_tool_context.name}'"},
1113+
],
1114+
}
1115+
1116+
await _run_context_injection_test(context_tool)
1117+
1118+
1119+
@pytest.mark.asyncio
1120+
async def test_tool_context_injection_disabled_missing_parameter():
1121+
"""Test that when context=False, missing tool_context parameter causes validation error."""
1122+
1123+
@strands.tool(context=False)
1124+
def context_tool(message: str, agent: Agent, tool_context: str) -> dict:
1125+
"""Tool that expects tool_context as a regular string parameter."""
1126+
return {
1127+
"status": "success",
1128+
"content": [
1129+
{"text": f"Message: {message}"},
1130+
{"text": f"Agent: {agent.name}"},
1131+
{"text": f"Tool context string: {tool_context}"},
1132+
],
1133+
}
1134+
1135+
# Verify that missing tool_context parameter causes validation error
1136+
tool: AgentTool = context_tool
1137+
generator = tool.stream(
1138+
tool_use={
1139+
"toolUseId": "test-id",
1140+
"name": "context_tool",
1141+
"input": {
1142+
"message": "some_message"
1143+
# Missing tool_context parameter - should cause validation error instead of being auto injected
1144+
},
1145+
},
1146+
invocation_state={
1147+
"agent": Agent(name="test_agent"),
1148+
},
1149+
)
1150+
tool_results = [value async for value in generator]
1151+
1152+
assert len(tool_results) == 1
1153+
tool_result = tool_results[0]
1154+
1155+
# Should get a validation error because tool_context is required but not provided
1156+
assert tool_result["status"] == "error"
1157+
assert "tool_context" in tool_result["content"][0]["text"].lower()
1158+
assert "validation" in tool_result["content"][0]["text"].lower()
1159+
1160+
1161+
@pytest.mark.asyncio
1162+
async def test_tool_context_injection_disabled_string_parameter():
1163+
"""Test that when context=False, tool_context can be passed as a string parameter."""
1164+
1165+
@strands.tool(context=False)
1166+
def context_tool(message: str, agent: Agent, tool_context: str) -> str:
1167+
"""Tool that expects tool_context as a regular string parameter."""
1168+
return "success"
1169+
1170+
# Verify that providing tool_context as a string works correctly
1171+
tool: AgentTool = context_tool
1172+
generator = tool.stream(
1173+
tool_use={
1174+
"toolUseId": "test-id-2",
1175+
"name": "context_tool",
1176+
"input": {
1177+
"message": "some_message",
1178+
"tool_context": "my_custom_context_string"
1179+
},
1180+
},
1181+
invocation_state={
1182+
"agent": Agent(name="test_agent"),
1183+
},
1184+
)
1185+
tool_results = [value async for value in generator]
1186+
1187+
assert len(tool_results) == 1
1188+
tool_result = tool_results[0]
1189+
1190+
# Should succeed with the string parameter
1191+
assert tool_result == {
1192+
"status": "success",
1193+
"content": [{"text": "success"}],
1194+
"toolUseId": "test-id-2",
1195+
}

0 commit comments

Comments
 (0)