Merged
Changes from 1 commit
pyproject.toml (3 changes: 2 additions & 1 deletion)
@@ -39,7 +39,8 @@ dependencies = [
"pytest-asyncio>=1.0.0",
"scale-gp-beta==0.1.0a20",
"ipykernel>=6.29.5",
"openai>=1.99.9",
"openai==1.99.9", # anything higher than 1.99.9 breaks litellm - https://github.com/BerriAI/litellm/issues/13711
"cloudpickle>=3.1.1",
]
requires-python = ">= 3.12,<4"
classifiers = [
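A hedged smoke test for the pin above, assuming both packages are installed from these locks (the assertion value simply mirrors the pin; nothing below appears in the PR itself):

```python
# Sanity check: litellm should import cleanly alongside openai 1.99.9,
# since newer openai releases break it per the linked litellm issue.
import openai
import litellm  # noqa: F401 -- a successful import is the test

assert openai.__version__ == "1.99.9", openai.__version__
```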
requirements-dev.lock (5 changes: 4 additions & 1 deletion)
@@ -53,6 +53,8 @@ click==8.2.1
     # via litellm
     # via typer
     # via uvicorn
+cloudpickle==3.1.1
+    # via agentex-sdk
 colorama==0.4.6
     # via griffe
 colorlog==6.7.0
@@ -179,9 +181,10 @@ oauthlib==3.3.1
     # via kubernetes
     # via requests-oauthlib
 openai==1.99.9
+    # via agentex-sdk
     # via litellm
     # via openai-agents
-openai-agents==0.2.6
+openai-agents==0.2.7
     # via agentex-sdk
 packaging==23.2
     # via huggingface-hub
requirements.lock (5 changes: 4 additions & 1 deletion)
@@ -51,6 +51,8 @@ click==8.2.1
     # via litellm
     # via typer
     # via uvicorn
+cloudpickle==3.1.1
+    # via agentex-sdk
 colorama==0.4.6
     # via griffe
 comm==0.2.3
@@ -162,9 +164,10 @@ oauthlib==3.3.1
     # via kubernetes
     # via requests-oauthlib
 openai==1.99.9
+    # via agentex-sdk
     # via litellm
     # via openai-agents
-openai-agents==0.2.6
+openai-agents==0.2.7
     # via agentex-sdk
 packaging==25.0
     # via huggingface-hub
@@ -1,9 +1,13 @@
 # Standard library imports
+import base64
 from collections.abc import Callable
 from contextlib import AsyncExitStack, asynccontextmanager
 from enum import Enum
-from typing import Any, Literal
+from typing import Any, Literal, Optional, override
 
+from pydantic import Field, PrivateAttr
+
+import cloudpickle
 from agents import RunContextWrapper, RunResult, RunResultStreaming
 from agents.mcp import MCPServerStdio, MCPServerStdioParams
 from agents.model_settings import ModelSettings as OAIModelSettings
@@ -41,12 +45,92 @@ class FunctionTool(BaseModelWithTraceParams):
     name: str
     description: str
     params_json_schema: dict[str, Any]
-    on_invoke_tool: Callable[[RunContextWrapper, str], Any]
 
     strict_json_schema: bool = True
     is_enabled: bool = True
 
+    _on_invoke_tool: Callable[[RunContextWrapper, str], Any] = PrivateAttr()
+    on_invoke_tool_serialized: str = Field(
+        default="",
+        description=(
+            "Normally will be set automatically during initialization and"
+            " doesn't need to be passed. "
+            "Instead, pass `on_invoke_tool` to the constructor. "
+            "See the __init__ method for details."
+        ),
+    )

def __init__(
self,
*,
on_invoke_tool: Optional[Callable[[RunContextWrapper, str], Any]] = None,
**data,
):
"""
Initialize a FunctionTool with hacks to support serialization of the
on_invoke_tool callable arg. This is required to facilitate over-the-wire
communication of this object to/from temporal services/workers.

Args:
on_invoke_tool: The callable to invoke when the tool is called.
**data: Additional data to initialize the FunctionTool.
"""
super().__init__(**data)
if not on_invoke_tool:
if not self.on_invoke_tool_serialized:
raise ValueError(
"One of `on_invoke_tool` or `on_invoke_tool_serialized` should be set"
)
else:
on_invoke_tool = self._deserialize_callable(
self.on_invoke_tool_serialized
)
else:
self.on_invoke_tool_serialized = self._serialize_callable(on_invoke_tool)

self._on_invoke_tool = on_invoke_tool

+    @classmethod
+    def _deserialize_callable(
+        cls, serialized: str
+    ) -> Callable[[RunContextWrapper, str], Any]:
+        encoded = serialized.encode()
+        serialized_bytes = base64.b64decode(encoded)
+        return cloudpickle.loads(serialized_bytes)
+
+    @classmethod
+    def _serialize_callable(cls, func: Callable) -> str:
+        serialized_bytes = cloudpickle.dumps(func)
+        encoded = base64.b64encode(serialized_bytes)
+        return encoded.decode()
+
+    @property
+    def on_invoke_tool(self) -> Callable[[RunContextWrapper, str], Any]:
+        if self._on_invoke_tool is None and self.on_invoke_tool_serialized:
+            self._on_invoke_tool = self._deserialize_callable(
+                self.on_invoke_tool_serialized
+            )
+        return self._on_invoke_tool
+
+    @on_invoke_tool.setter
+    def on_invoke_tool(self, value: Callable[[RunContextWrapper, str], Any]):
+        self.on_invoke_tool_serialized = self._serialize_callable(value)
+        self._on_invoke_tool = value
     def to_oai_function_tool(self) -> OAIFunctionTool:
-        return OAIFunctionTool(**self.model_dump(exclude=["trace_id", "parent_span_id"]))
+        """Convert to OpenAI function tool, excluding serialization fields."""
+        # Create a dictionary with only the fields OAIFunctionTool expects
+        data = self.model_dump(
+            exclude={
+                "trace_id",
+                "parent_span_id",
+                "_on_invoke_tool",
+                "on_invoke_tool_serialized",
+            }
+        )
+        # Add the callable for OAI tool since properties are not serialized
+        data["on_invoke_tool"] = self.on_invoke_tool
+        return OAIFunctionTool(**data)
 
 
 class ModelSettings(BaseModelWithTraceParams):
@@ -68,7 +152,9 @@ class ModelSettings(BaseModelWithTraceParams):
     extra_args: dict[str, Any] | None = None
 
     def to_oai_model_settings(self) -> OAIModelSettings:
-        return OAIModelSettings(**self.model_dump(exclude=["trace_id", "parent_span_id"]))
+        return OAIModelSettings(
+            **self.model_dump(exclude=["trace_id", "parent_span_id"])
+        )
 
 
 class RunAgentParams(BaseModelWithTraceParams):
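A sketch of the round trip this change enables, for readers of the diff. Everything below is illustrative: the import path is hypothetical (the module's filename is not visible in this rendering), and the payload shape simply follows the `model_dump`/`__init__` logic shown above.

```python
# Hypothetical usage: ship a FunctionTool over the wire (e.g. through a
# Temporal workflow) and rebuild it on a worker.
from agentex_tools import FunctionTool  # hypothetical import path


async def echo_tool(ctx, args_json: str) -> str:
    # Toy tool body; cloudpickle serializes this function by value.
    return args_json


tool = FunctionTool(
    name="echo",
    description="Echoes its JSON arguments.",
    params_json_schema={"type": "object", "properties": {}},
    on_invoke_tool=echo_tool,  # __init__ pickles this into on_invoke_tool_serialized
)

# Only plain fields travel over the wire; the callable rides along as the
# base64-encoded cloudpickle string produced in __init__.
payload = tool.model_dump()
assert payload["on_invoke_tool_serialized"]
assert "on_invoke_tool" not in payload  # the property is not a pydantic field

# Receiving side: no callable is passed, so __init__ falls back to
# _deserialize_callable on the serialized field.
rebuilt = FunctionTool(**payload)

# Convert to the OpenAI Agents SDK type when actually running the agent.
oai_tool = rebuilt.to_oai_function_tool()
```

The setter keeps the two representations in sync: assigning a new callable to `on_invoke_tool` re-serializes it, so a tool mutated on one side still round-trips correctly.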