Skip to content

Commit b9842f4

Browse files
committed
WIP - additional tool types and LiteLLM
1 parent 9e7dc7a commit b9842f4

File tree

5 files changed

+1269
-24
lines changed

5 files changed

+1269
-24
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ dev = [
5757
"pytest-cov>=6.1.1",
5858
"httpx>=0.28.1",
5959
"pytest-pretty>=1.3.0",
60+
"openai-agents[litellm] >= 0.2.3,<0.3"
6061
]
6162

6263
[tool.poe.tasks]

temporalio/contrib/openai_agents/_invoke_model_activity.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
Tool,
2222
TResponseInputItem,
2323
UserError,
24-
WebSearchTool,
24+
WebSearchTool, ImageGenerationTool, CodeInterpreterTool,
2525
)
2626
from agents.models.multi_provider import MultiProvider
2727
from typing_extensions import Required, TypedDict
@@ -51,7 +51,7 @@ class FunctionToolInput:
5151
strict_json_schema: bool = True
5252

5353

54-
ToolInput = Union[FunctionToolInput, FileSearchTool, WebSearchTool]
54+
ToolInput = Union[FunctionToolInput, FileSearchTool, WebSearchTool, ImageGenerationTool, CodeInterpreterTool]
5555

5656

5757
@dataclass
@@ -143,10 +143,8 @@ async def empty_on_invoke_handoff(
143143
input_input = json.loads(input_json)
144144

145145
def make_tool(tool: ToolInput) -> Tool:
146-
if isinstance(tool, FileSearchTool):
147-
return cast(FileSearchTool, tool)
148-
elif isinstance(tool, WebSearchTool):
149-
return cast(WebSearchTool, tool)
146+
if isinstance(tool, (FileSearchTool, WebSearchTool, ImageGenerationTool, CodeInterpreterTool)):
147+
return cast(Tool, tool)
150148
elif isinstance(tool, FunctionToolInput):
151149
t = cast(FunctionToolInput, tool)
152150
return FunctionTool(

temporalio/contrib/openai_agents/_temporal_model_stub.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
ModelTracing,
2424
Tool,
2525
TResponseInputItem,
26-
WebSearchTool,
26+
WebSearchTool, ImageGenerationTool, CodeInterpreterTool,
2727
)
2828
from agents.items import TResponseStreamEvent
2929
from openai.types.responses.response_prompt_param import ResponsePromptParam
@@ -87,12 +87,8 @@ def get_summary(
8787
return ""
8888

8989
def make_tool_info(tool: Tool) -> ToolInput:
90-
if isinstance(tool, (FileSearchTool, WebSearchTool)):
90+
if isinstance(tool, (FileSearchTool, WebSearchTool, ImageGenerationTool, CodeInterpreterTool)):
9191
return tool
92-
elif isinstance(tool, ComputerTool):
93-
raise NotImplementedError(
94-
"Computer search preview is not supported in Temporal model"
95-
)
9692
elif isinstance(tool, FunctionTool):
9793
return FunctionToolInput(
9894
name=tool.name,
@@ -101,7 +97,7 @@ def make_tool_info(tool: Tool) -> ToolInput:
10197
strict_json_schema=tool.strict_json_schema,
10298
)
10399
else:
104-
raise ValueError(f"Unknown tool type: {tool.name}")
100+
raise ValueError(f"Unsupported tool type: {tool.name}")
105101

106102
tool_infos = [make_tool_info(x) for x in tools]
107103
handoff_infos = [

tests/contrib/openai_agents/test_openai.py

Lines changed: 290 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,10 @@
2929
handoff,
3030
input_guardrail,
3131
output_guardrail,
32-
trace,
32+
trace, FileSearchTool, ImageGenerationTool, CodeInterpreterTool,
3333
)
3434
from agents.extensions.handoff_prompt import RECOMMENDED_PROMPT_PREFIX
35+
from agents.extensions.models.litellm_provider import LitellmProvider
3536
from agents.items import (
3637
HandoffOutputItem,
3738
ToolCallItem,
@@ -42,10 +43,12 @@
4243
ResponseFunctionToolCall,
4344
ResponseFunctionWebSearch,
4445
ResponseOutputMessage,
45-
ResponseOutputText,
46+
ResponseOutputText, ResponseFileSearchToolCall, ResponseCodeInterpreterToolCall,
4647
)
48+
from openai.types.responses.response_file_search_tool_call import Result
4749
from openai.types.responses.response_function_web_search import ActionSearch
4850
from openai.types.responses.response_prompt_param import ResponsePromptParam
51+
from openai.types.responses.tool_param import ImageGeneration
4952
from pydantic import ConfigDict, Field, TypeAdapter
5053

5154
from temporalio import activity, workflow
@@ -1777,3 +1780,288 @@ async def test_response_serialization():
17771780
response_id="",
17781781
)
17791782
encoded = await pydantic_data_converter.encode([model_response])
1783+
1784+
async def test_lite_llm(client: Client):
1785+
if not os.environ.get("OPENAI_API_KEY"):
1786+
pytest.skip("No openai API key")
1787+
new_config = client.config()
1788+
new_config["plugins"] = [
1789+
openai_agents.OpenAIAgentsPlugin(
1790+
model_params=ModelActivityParameters(
1791+
start_to_close_timeout=timedelta(seconds=30)
1792+
),
1793+
model_provider=LitellmProvider(),
1794+
)
1795+
]
1796+
client = Client(**new_config)
1797+
1798+
async with new_worker(
1799+
client,
1800+
HelloWorldAgent,
1801+
) as worker:
1802+
workflow_handle = await client.start_workflow(
1803+
HelloWorldAgent.run,
1804+
"Tell me about recursion in programming",
1805+
id=f"lite-llm-{uuid.uuid4()}",
1806+
task_queue=worker.task_queue,
1807+
execution_timeout=timedelta(seconds=10),
1808+
)
1809+
await workflow_handle.result()
1810+
1811+
1812+
class FileSearchToolModel(StaticTestModel):
1813+
responses = [
1814+
ModelResponse(
1815+
output=[
1816+
ResponseFileSearchToolCall(
1817+
queries=["side character in the Iliad"],
1818+
type="file_search_call",
1819+
id="id",
1820+
status="completed",
1821+
results=[
1822+
Result(text="Some scene"),
1823+
Result(text="Other scene"),
1824+
]
1825+
),
1826+
ResponseOutputMessage(
1827+
id="",
1828+
content=[
1829+
ResponseOutputText(
1830+
text="Patroclus",
1831+
annotations=[],
1832+
type="output_text",
1833+
)
1834+
],
1835+
role="assistant",
1836+
status="completed",
1837+
type="message",
1838+
)
1839+
],
1840+
usage=Usage(),
1841+
response_id=None,
1842+
),
1843+
]
1844+
1845+
@workflow.defn
1846+
class FileSearchToolWorkflow:
1847+
@workflow.run
1848+
async def run(self, question: str) -> str:
1849+
agent = Agent[str](
1850+
name="File Search Workflow",
1851+
instructions="You are a librarian. You should use your tools to source all your information.",
1852+
tools=[
1853+
FileSearchTool(
1854+
max_num_results=3,
1855+
vector_store_ids=["vs_687fd7f5e69c8191a2740f06bc9a159d"],
1856+
include_search_results=True,
1857+
)
1858+
],
1859+
)
1860+
result = await Runner.run(
1861+
starting_agent=agent, input=question
1862+
)
1863+
1864+
# A file search was performed
1865+
assert any(isinstance(item, ToolCallItem) and isinstance(item.raw_item, ResponseFileSearchToolCall) for item in result.new_items)
1866+
return result.final_output
1867+
1868+
@pytest.mark.parametrize("use_local_model", [True, False])
1869+
async def test_file_search_tool(client: Client, use_local_model):
1870+
if not use_local_model and not os.environ.get("OPENAI_API_KEY"):
1871+
pytest.skip("No openai API key")
1872+
1873+
new_config = client.config()
1874+
new_config["plugins"] = [
1875+
openai_agents.OpenAIAgentsPlugin(
1876+
model_params=ModelActivityParameters(
1877+
start_to_close_timeout=timedelta(seconds=30)
1878+
),
1879+
model_provider=TestModelProvider(FileSearchToolModel())
1880+
if use_local_model
1881+
else None,
1882+
)
1883+
]
1884+
client = Client(**new_config)
1885+
1886+
async with new_worker(
1887+
client,
1888+
FileSearchToolWorkflow,
1889+
) as worker:
1890+
workflow_handle = await client.start_workflow(
1891+
FileSearchToolWorkflow.run,
1892+
"Tell me about a side character in the Iliad.",
1893+
id=f"file-search-tool-{uuid.uuid4()}",
1894+
task_queue=worker.task_queue,
1895+
execution_timeout=timedelta(seconds=30),
1896+
)
1897+
result = await workflow_handle.result()
1898+
if use_local_model:
1899+
assert result == "Patroclus"
1900+
1901+
1902+
class ImageGenerationModel(StaticTestModel):
1903+
responses = [
1904+
ModelResponse(
1905+
output=[
1906+
ResponseFileSearchToolCall(
1907+
queries=["side character in the Iliad"],
1908+
type="file_search_call",
1909+
id="id",
1910+
status="completed",
1911+
results=[
1912+
Result(text="Some scene"),
1913+
Result(text="Other scene"),
1914+
]
1915+
),
1916+
ResponseOutputMessage(
1917+
id="",
1918+
content=[
1919+
ResponseOutputText(
1920+
text="Patroclus",
1921+
annotations=[],
1922+
type="output_text",
1923+
)
1924+
],
1925+
role="assistant",
1926+
status="completed",
1927+
type="message",
1928+
)
1929+
],
1930+
usage=Usage(),
1931+
response_id=None,
1932+
),
1933+
]
1934+
1935+
@workflow.defn
1936+
class ImageGenerationWorkflow:
1937+
@workflow.run
1938+
async def run(self, question: str) -> str:
1939+
agent = Agent[str](
1940+
name="Image Generation Workflow",
1941+
instructions="You are a helpful agent.",
1942+
tools=[
1943+
ImageGenerationTool(
1944+
tool_config={"type": "image_generation", "quality": "low"},
1945+
)
1946+
],
1947+
)
1948+
result = await Runner.run(
1949+
starting_agent=agent, input=question
1950+
)
1951+
1952+
return result.final_output
1953+
1954+
@pytest.mark.parametrize("use_local_model", [True, False])
1955+
async def test_image_generation_tool(client: Client, use_local_model):
1956+
if not use_local_model and not os.environ.get("OPENAI_API_KEY"):
1957+
pytest.skip("No openai API key")
1958+
1959+
new_config = client.config()
1960+
new_config["plugins"] = [
1961+
openai_agents.OpenAIAgentsPlugin(
1962+
model_params=ModelActivityParameters(
1963+
start_to_close_timeout=timedelta(seconds=30)
1964+
),
1965+
model_provider=TestModelProvider(ImageGenerationModel())
1966+
if use_local_model
1967+
else None,
1968+
)
1969+
]
1970+
client = Client(**new_config)
1971+
1972+
async with new_worker(
1973+
client,
1974+
ImageGenerationWorkflow,
1975+
) as worker:
1976+
workflow_handle = await client.start_workflow(
1977+
ImageGenerationWorkflow.run,
1978+
"Create an image of a frog eating a pizza, comic book style.",
1979+
id=f"image-generation-tool-{uuid.uuid4()}",
1980+
task_queue=worker.task_queue,
1981+
execution_timeout=timedelta(seconds=30),
1982+
)
1983+
result = await workflow_handle.result()
1984+
1985+
1986+
class CodeInterpreterModel(StaticTestModel):
1987+
responses = [
1988+
ModelResponse(
1989+
output=[
1990+
ResponseCodeInterpreterToolCall(
1991+
container_id="",
1992+
code="some code",
1993+
type="code_interpreter_call",
1994+
id="id",
1995+
status="completed",
1996+
),
1997+
ResponseOutputMessage(
1998+
id="",
1999+
content=[
2000+
ResponseOutputText(
2001+
text="Over 9000",
2002+
annotations=[],
2003+
type="output_text",
2004+
)
2005+
],
2006+
role="assistant",
2007+
status="completed",
2008+
type="message",
2009+
)
2010+
],
2011+
usage=Usage(),
2012+
response_id=None,
2013+
),
2014+
]
2015+
2016+
@workflow.defn
2017+
class CodeInterpreterWorkflow:
2018+
@workflow.run
2019+
async def run(self, question: str) -> str:
2020+
agent = Agent[str](
2021+
name="Code Interpreter Workflow",
2022+
instructions="You are a helpful agent.",
2023+
tools=[
2024+
CodeInterpreterTool(
2025+
tool_config={"type": "code_interpreter", "container": {"type": "auto"}},
2026+
)
2027+
],
2028+
)
2029+
result = await Runner.run(
2030+
starting_agent=agent, input=question
2031+
)
2032+
2033+
assert any(isinstance(item, ToolCallItem) and isinstance(item.raw_item, ResponseCodeInterpreterToolCall) for item in result.new_items)
2034+
return result.final_output
2035+
2036+
@pytest.mark.parametrize("use_local_model", [True, False])
2037+
async def test_code_interpreter_tool(client: Client, use_local_model):
2038+
if not use_local_model and not os.environ.get("OPENAI_API_KEY"):
2039+
pytest.skip("No openai API key")
2040+
2041+
new_config = client.config()
2042+
new_config["plugins"] = [
2043+
openai_agents.OpenAIAgentsPlugin(
2044+
model_params=ModelActivityParameters(
2045+
start_to_close_timeout=timedelta(seconds=30)
2046+
),
2047+
model_provider=TestModelProvider(CodeInterpreterModel())
2048+
if use_local_model
2049+
else None,
2050+
)
2051+
]
2052+
client = Client(**new_config)
2053+
2054+
async with new_worker(
2055+
client,
2056+
CodeInterpreterWorkflow,
2057+
) as worker:
2058+
workflow_handle = await client.start_workflow(
2059+
CodeInterpreterWorkflow.run,
2060+
"What is the square root of 273 * 312821 plus 1782?",
2061+
id=f"code-interpreter-tool-{uuid.uuid4()}",
2062+
task_queue=worker.task_queue,
2063+
execution_timeout=timedelta(seconds=30),
2064+
)
2065+
result = await workflow_handle.result()
2066+
if use_local_model:
2067+
assert result == "Over 9000"

0 commit comments

Comments
 (0)