@@ -61,7 +61,7 @@ def my_tool(param1: str, param2: int = 42) -> dict:
61
61
from pydantic import BaseModel , Field , create_model
62
62
from typing_extensions import override
63
63
64
- from ..types .tools import AgentTool , JSONSchema , ToolGenerator , ToolSpec , ToolUse
64
+ from ..types .tools import AgentTool , JSONSchema , ToolContext , ToolGenerator , ToolSpec , ToolUse
65
65
66
66
logger = logging .getLogger (__name__ )
67
67
@@ -84,16 +84,18 @@ class FunctionToolMetadata:
84
84
validate tool usage.
85
85
"""
86
86
87
- def __init__ (self , func : Callable [..., Any ]) -> None :
87
+ def __init__ (self , func : Callable [..., Any ], context_param : str | None = None ) -> None :
88
88
"""Initialize with the function to process.
89
89
90
90
Args:
91
91
func: The function to extract metadata from.
92
92
Can be a standalone function or a class method.
93
+ context_param: Name of the context parameter to inject, if any.
93
94
"""
94
95
self .func = func
95
96
self .signature = inspect .signature (func )
96
97
self .type_hints = get_type_hints (func )
98
+ self ._context_param = context_param
97
99
98
100
# Parse the docstring with docstring_parser
99
101
doc_str = inspect .getdoc (func ) or ""
@@ -113,16 +115,16 @@ def _create_input_model(self) -> Type[BaseModel]:
113
115
This method analyzes the function's signature, type hints, and docstring to create a Pydantic model that can
114
116
validate input data before passing it to the function.
115
117
116
- Special parameters like 'self', 'cls', and 'agent' are excluded from the model.
118
+ Special parameters that can be automatically injected are excluded from the model.
117
119
118
120
Returns:
119
121
A Pydantic BaseModel class customized for the function's parameters.
120
122
"""
121
123
field_definitions : dict [str , Any ] = {}
122
124
123
125
for name , param in self .signature .parameters .items ():
124
- # Skip special parameters
125
- if name in ( " self" , "cls" , "agent" ):
126
+ # Skip parameters that will be automatically injected
127
+ if self . _is_special_parameter ( name ):
126
128
continue
127
129
128
130
# Get parameter type and default
@@ -252,6 +254,49 @@ def validate_input(self, input_data: dict[str, Any]) -> dict[str, Any]:
252
254
error_msg = str (e )
253
255
raise ValueError (f"Validation failed for input parameters: { error_msg } " ) from e
254
256
257
+ def inject_special_parameters (
258
+ self , validated_input : dict [str , Any ], tool_use : ToolUse , invocation_state : dict [str , Any ]
259
+ ) -> None :
260
+ """Inject special framework-provided parameters into the validated input.
261
+
262
+ This method automatically provides framework-level context to tools that request it
263
+ through their function signature.
264
+
265
+ Args:
266
+ validated_input: The validated input parameters (modified in place).
267
+ tool_use: The tool use request containing tool invocation details.
268
+ invocation_state: Context for the tool invocation, including agent state.
269
+ """
270
+ if self ._context_param and self ._context_param in self .signature .parameters :
271
+ tool_context = ToolContext (tool_use = tool_use , agent = invocation_state ["agent" ])
272
+ validated_input [self ._context_param ] = tool_context
273
+
274
+ # Inject agent if requested (backward compatibility)
275
+ if "agent" in self .signature .parameters and "agent" in invocation_state :
276
+ validated_input ["agent" ] = invocation_state ["agent" ]
277
+
278
+ def _is_special_parameter (self , param_name : str ) -> bool :
279
+ """Check if a parameter should be automatically injected by the framework or is a standard Python method param.
280
+
281
+ Special parameters include:
282
+ - Standard Python method parameters: self, cls
283
+ - Framework-provided context parameters: agent, and configurable context parameter (defaults to tool_context)
284
+
285
+ Args:
286
+ param_name: The name of the parameter to check.
287
+
288
+ Returns:
289
+ True if the parameter should be excluded from input validation and
290
+ handled specially during tool execution.
291
+ """
292
+ special_params = {"self" , "cls" , "agent" }
293
+
294
+ # Add context parameter if configured
295
+ if self ._context_param :
296
+ special_params .add (self ._context_param )
297
+
298
+ return param_name in special_params
299
+
255
300
256
301
P = ParamSpec ("P" ) # Captures all parameters
257
302
R = TypeVar ("R" ) # Return type
@@ -402,9 +447,8 @@ async def stream(self, tool_use: ToolUse, invocation_state: dict[str, Any], **kw
402
447
# Validate input against the Pydantic model
403
448
validated_input = self ._metadata .validate_input (tool_input )
404
449
405
- # Pass along the agent if provided and expected by the function
406
- if "agent" in invocation_state and "agent" in self ._metadata .signature .parameters :
407
- validated_input ["agent" ] = invocation_state .get ("agent" )
450
+ # Inject special framework-provided parameters
451
+ self ._metadata .inject_special_parameters (validated_input , tool_use , invocation_state )
408
452
409
453
# "Too few arguments" expected, hence the type ignore
410
454
if inspect .iscoroutinefunction (self ._tool_func ):
@@ -474,6 +518,7 @@ def tool(
474
518
description : Optional [str ] = None ,
475
519
inputSchema : Optional [JSONSchema ] = None ,
476
520
name : Optional [str ] = None ,
521
+ context : bool | str = False ,
477
522
) -> Callable [[Callable [P , R ]], DecoratedFunctionTool [P , R ]]: ...
478
523
# Suppressing the type error because we want callers to be able to use both `tool` and `tool()` at the
479
524
# call site, but the actual implementation handles that and it's not representable via the type-system
@@ -482,6 +527,7 @@ def tool( # type: ignore
482
527
description : Optional [str ] = None ,
483
528
inputSchema : Optional [JSONSchema ] = None ,
484
529
name : Optional [str ] = None ,
530
+ context : bool | str = False ,
485
531
) -> Union [DecoratedFunctionTool [P , R ], Callable [[Callable [P , R ]], DecoratedFunctionTool [P , R ]]]:
486
532
"""Decorator that transforms a Python function into a Strands tool.
487
533
@@ -507,6 +553,9 @@ def tool( # type: ignore
507
553
description: Optional custom description to override the function's docstring.
508
554
inputSchema: Optional custom JSON schema to override the automatically generated schema.
509
555
name: Optional custom name to override the function's name.
556
+ context: When provided, places an object in the designated parameter. If True, the param name
557
+ defaults to 'tool_context', or if an override is needed, set context equal to a string to designate
558
+ the param name.
510
559
511
560
Returns:
512
561
An AgentTool that also mimics the original function when invoked
@@ -536,15 +585,24 @@ def my_tool(name: str, count: int = 1) -> str:
536
585
537
586
Example with parameters:
538
587
```python
539
- @tool(name="custom_tool", description="A tool with a custom name and description")
540
- def my_tool(name: str, count: int = 1) -> str:
541
- return f"Processed {name} {count} times"
588
+ @tool(name="custom_tool", description="A tool with a custom name and description", context=True)
589
+ def my_tool(name: str, count: int = 1, tool_context: ToolContext) -> str:
590
+ tool_id = tool_context["tool_use"]["toolUseId"]
591
+ return f"Processed {name} {count} times with tool ID {tool_id}"
542
592
```
543
593
"""
544
594
545
595
def decorator (f : T ) -> "DecoratedFunctionTool[P, R]" :
596
+ # Resolve context parameter name
597
+ if isinstance (context , bool ):
598
+ context_param = "tool_context" if context else None
599
+ else :
600
+ context_param = context .strip ()
601
+ if not context_param :
602
+ raise ValueError ("Context parameter name cannot be empty" )
603
+
546
604
# Create function tool metadata
547
- tool_meta = FunctionToolMetadata (f )
605
+ tool_meta = FunctionToolMetadata (f , context_param )
548
606
tool_spec = tool_meta .extract_metadata ()
549
607
if name is not None :
550
608
tool_spec ["name" ] = name
0 commit comments