11# Standard library imports
2+ import base64
23from collections .abc import Callable
34from contextlib import AsyncExitStack , asynccontextmanager
45from enum import Enum
5- from typing import Any , Literal
6+ from typing import Any , Literal , Optional , override
67
8+ from pydantic import Field , PrivateAttr
9+
10+ import cloudpickle
711from agents import RunContextWrapper , RunResult , RunResultStreaming
812from agents .mcp import MCPServerStdio , MCPServerStdioParams
913from agents .model_settings import ModelSettings as OAIModelSettings
@@ -41,12 +45,92 @@ class FunctionTool(BaseModelWithTraceParams):
4145 name : str
4246 description : str
4347 params_json_schema : dict [str , Any ]
44- on_invoke_tool : Callable [[ RunContextWrapper , str ], Any ]
48+
4549 strict_json_schema : bool = True
4650 is_enabled : bool = True
4751
52+ _on_invoke_tool : Callable [[RunContextWrapper , str ], Any ] = PrivateAttr ()
53+ on_invoke_tool_serialized : str = Field (
54+ default = "" ,
55+ description = (
56+ "Normally will be set automatically during initialization and"
57+ " doesn't need to be passed. "
58+ "Instead, pass `on_invoke_tool` to the constructor. "
59+ "See the __init__ method for details."
60+ ),
61+ )
62+
63+ def __init__ (
64+ self ,
65+ * ,
66+ on_invoke_tool : Optional [Callable [[RunContextWrapper , str ], Any ]] = None ,
67+ ** data ,
68+ ):
69+ """
70+ Initialize a FunctionTool with hacks to support serialization of the
71+ on_invoke_tool callable arg. This is required to facilitate over-the-wire
72+ communication of this object to/from temporal services/workers.
73+
74+ Args:
75+ on_invoke_tool: The callable to invoke when the tool is called.
76+ **data: Additional data to initialize the FunctionTool.
77+ """
78+ super ().__init__ (** data )
79+ if not on_invoke_tool :
80+ if not self .on_invoke_tool_serialized :
81+ raise ValueError (
82+ "One of `on_invoke_tool` or `on_invoke_tool_serialized` should be set"
83+ )
84+ else :
85+ on_invoke_tool = self ._deserialize_callable (
86+ self .on_invoke_tool_serialized
87+ )
88+ else :
89+ self .on_invoke_tool_serialized = self ._serialize_callable (on_invoke_tool )
90+
91+ self ._on_invoke_tool = on_invoke_tool
92+
93+ @classmethod
94+ def _deserialize_callable (
95+ cls , serialized : str
96+ ) -> Callable [[RunContextWrapper , str ], Any ]:
97+ encoded = serialized .encode ()
98+ serialized_bytes = base64 .b64decode (encoded )
99+ return cloudpickle .loads (serialized_bytes )
100+
101+ @classmethod
102+ def _serialize_callable (cls , func : Callable ) -> str :
103+ serialized_bytes = cloudpickle .dumps (func )
104+ encoded = base64 .b64encode (serialized_bytes )
105+ return encoded .decode ()
106+
107+ @property
108+ def on_invoke_tool (self ) -> Callable [[RunContextWrapper , str ], Any ]:
109+ if self ._on_invoke_tool is None and self .on_invoke_tool_serialized :
110+ self ._on_invoke_tool = self ._deserialize_callable (
111+ self .on_invoke_tool_serialized
112+ )
113+ return self ._on_invoke_tool
114+
115+ @on_invoke_tool .setter
116+ def on_invoke_tool (self , value : Callable [[RunContextWrapper , str ], Any ]):
117+ self .on_invoke_tool_serialized = self ._serialize_callable (value )
118+ self ._on_invoke_tool = value
119+
48120 def to_oai_function_tool (self ) -> OAIFunctionTool :
49- return OAIFunctionTool (** self .model_dump (exclude = ["trace_id" , "parent_span_id" ]))
121+ """Convert to OpenAI function tool, excluding serialization fields."""
122+ # Create a dictionary with only the fields OAIFunctionTool expects
123+ data = self .model_dump (
124+ exclude = {
125+ "trace_id" ,
126+ "parent_span_id" ,
127+ "_on_invoke_tool" ,
128+ "on_invoke_tool_serialized" ,
129+ }
130+ )
131+ # Add the callable for OAI tool since properties are not serialized
132+ data ["on_invoke_tool" ] = self .on_invoke_tool
133+ return OAIFunctionTool (** data )
50134
51135
52136class ModelSettings (BaseModelWithTraceParams ):
@@ -68,7 +152,9 @@ class ModelSettings(BaseModelWithTraceParams):
68152 extra_args : dict [str , Any ] | None = None
69153
70154 def to_oai_model_settings (self ) -> OAIModelSettings :
71- return OAIModelSettings (** self .model_dump (exclude = ["trace_id" , "parent_span_id" ]))
155+ return OAIModelSettings (
156+ ** self .model_dump (exclude = ["trace_id" , "parent_span_id" ])
157+ )
72158
73159
74160class RunAgentParams (BaseModelWithTraceParams ):
0 commit comments