Skip to content

Commit c9eb040

Browse files
committed
fix: enable FunctionTool serialization for Temporal worker nodes
- RunAgent*Params objects must be serializable for over-the-wire transmission to Temporal workers/backend. Previous implementation failed when users specified FunctionTool with callable on_invoke_tool params. - This commit adds cloudpickle-based serialization support to resolve serialization errors - During testing, also had to pin OpenAI to v1.99.9 to avoid LiteLLM incompatibility issue ([#13711](BerriAI/litellm#13711))
1 parent 4e9bb87 commit c9eb040

File tree

6 files changed

+364
-9
lines changed

6 files changed

+364
-9
lines changed

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,8 @@ dependencies = [
3939
"pytest-asyncio>=1.0.0",
4040
"scale-gp-beta==0.1.0a20",
4141
"ipykernel>=6.29.5",
42-
"openai>=1.99.9",
42+
"openai==1.99.9", # anything higher than 1.99.9 breaks litellm - https://github.com/BerriAI/litellm/issues/13711
43+
"cloudpickle>=3.1.1",
4344
]
4445
requires-python = ">= 3.12,<4"
4546
classifiers = [

requirements-dev.lock

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,8 @@ click==8.2.1
5353
# via litellm
5454
# via typer
5555
# via uvicorn
56+
cloudpickle==3.1.1
57+
# via agentex-sdk
5658
colorama==0.4.6
5759
# via griffe
5860
colorlog==6.7.0
@@ -179,9 +181,10 @@ oauthlib==3.3.1
179181
# via kubernetes
180182
# via requests-oauthlib
181183
openai==1.99.9
184+
# via agentex-sdk
182185
# via litellm
183186
# via openai-agents
184-
openai-agents==0.2.6
187+
openai-agents==0.2.7
185188
# via agentex-sdk
186189
packaging==23.2
187190
# via huggingface-hub

requirements.lock

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,8 @@ click==8.2.1
5151
# via litellm
5252
# via typer
5353
# via uvicorn
54+
cloudpickle==3.1.1
55+
# via agentex-sdk
5456
colorama==0.4.6
5557
# via griffe
5658
comm==0.2.3
@@ -162,9 +164,10 @@ oauthlib==3.3.1
162164
# via kubernetes
163165
# via requests-oauthlib
164166
openai==1.99.9
167+
# via agentex-sdk
165168
# via litellm
166169
# via openai-agents
167-
openai-agents==0.2.6
170+
openai-agents==0.2.7
168171
# via agentex-sdk
169172
packaging==25.0
170173
# via huggingface-hub

src/agentex/lib/core/temporal/activities/adk/providers/openai_activities.py

Lines changed: 90 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
11
# Standard library imports
2+
import base64
23
from collections.abc import Callable
34
from contextlib import AsyncExitStack, asynccontextmanager
45
from enum import Enum
5-
from typing import Any, Literal
6+
from typing import Any, Literal, Optional, override
67

8+
from pydantic import Field, PrivateAttr
9+
10+
import cloudpickle
711
from agents import RunContextWrapper, RunResult, RunResultStreaming
812
from agents.mcp import MCPServerStdio, MCPServerStdioParams
913
from agents.model_settings import ModelSettings as OAIModelSettings
@@ -41,12 +45,92 @@ class FunctionTool(BaseModelWithTraceParams):
4145
name: str
4246
description: str
4347
params_json_schema: dict[str, Any]
44-
on_invoke_tool: Callable[[RunContextWrapper, str], Any]
48+
4549
strict_json_schema: bool = True
4650
is_enabled: bool = True
4751

52+
_on_invoke_tool: Callable[[RunContextWrapper, str], Any] = PrivateAttr()
53+
on_invoke_tool_serialized: str = Field(
54+
default="",
55+
description=(
56+
"Normally will be set automatically during initialization and"
57+
" doesn't need to be passed. "
58+
"Instead, pass `on_invoke_tool` to the constructor. "
59+
"See the __init__ method for details."
60+
),
61+
)
62+
63+
def __init__(
64+
self,
65+
*,
66+
on_invoke_tool: Optional[Callable[[RunContextWrapper, str], Any]] = None,
67+
**data,
68+
):
69+
"""
70+
Initialize a FunctionTool with hacks to support serialization of the
71+
on_invoke_tool callable arg. This is required to facilitate over-the-wire
72+
communication of this object to/from temporal services/workers.
73+
74+
Args:
75+
on_invoke_tool: The callable to invoke when the tool is called.
76+
**data: Additional data to initialize the FunctionTool.
77+
"""
78+
super().__init__(**data)
79+
if not on_invoke_tool:
80+
if not self.on_invoke_tool_serialized:
81+
raise ValueError(
82+
"One of `on_invoke_tool` or `on_invoke_tool_serialized` should be set"
83+
)
84+
else:
85+
on_invoke_tool = self._deserialize_callable(
86+
self.on_invoke_tool_serialized
87+
)
88+
else:
89+
self.on_invoke_tool_serialized = self._serialize_callable(on_invoke_tool)
90+
91+
self._on_invoke_tool = on_invoke_tool
92+
93+
@classmethod
94+
def _deserialize_callable(
95+
cls, serialized: str
96+
) -> Callable[[RunContextWrapper, str], Any]:
97+
encoded = serialized.encode()
98+
serialized_bytes = base64.b64decode(encoded)
99+
return cloudpickle.loads(serialized_bytes)
100+
101+
@classmethod
102+
def _serialize_callable(cls, func: Callable) -> str:
103+
serialized_bytes = cloudpickle.dumps(func)
104+
encoded = base64.b64encode(serialized_bytes)
105+
return encoded.decode()
106+
107+
@property
108+
def on_invoke_tool(self) -> Callable[[RunContextWrapper, str], Any]:
109+
if self._on_invoke_tool is None and self.on_invoke_tool_serialized:
110+
self._on_invoke_tool = self._deserialize_callable(
111+
self.on_invoke_tool_serialized
112+
)
113+
return self._on_invoke_tool
114+
115+
@on_invoke_tool.setter
116+
def on_invoke_tool(self, value: Callable[[RunContextWrapper, str], Any]):
117+
self.on_invoke_tool_serialized = self._serialize_callable(value)
118+
self._on_invoke_tool = value
119+
48120
def to_oai_function_tool(self) -> OAIFunctionTool:
49-
return OAIFunctionTool(**self.model_dump(exclude=["trace_id", "parent_span_id"]))
121+
"""Convert to OpenAI function tool, excluding serialization fields."""
122+
# Create a dictionary with only the fields OAIFunctionTool expects
123+
data = self.model_dump(
124+
exclude={
125+
"trace_id",
126+
"parent_span_id",
127+
"_on_invoke_tool",
128+
"on_invoke_tool_serialized",
129+
}
130+
)
131+
# Add the callable for OAI tool since properties are not serialized
132+
data["on_invoke_tool"] = self.on_invoke_tool
133+
return OAIFunctionTool(**data)
50134

51135

52136
class ModelSettings(BaseModelWithTraceParams):
@@ -68,7 +152,9 @@ class ModelSettings(BaseModelWithTraceParams):
68152
extra_args: dict[str, Any] | None = None
69153

70154
def to_oai_model_settings(self) -> OAIModelSettings:
71-
return OAIModelSettings(**self.model_dump(exclude=["trace_id", "parent_span_id"]))
155+
return OAIModelSettings(
156+
**self.model_dump(exclude=["trace_id", "parent_span_id"])
157+
)
72158

73159

74160
class RunAgentParams(BaseModelWithTraceParams):

0 commit comments

Comments
 (0)