|
29 | 29 | handoff, |
30 | 30 | input_guardrail, |
31 | 31 | output_guardrail, |
32 | | - trace, |
| 32 | + trace, FileSearchTool, ImageGenerationTool, CodeInterpreterTool, |
33 | 33 | ) |
34 | 34 | from agents.extensions.handoff_prompt import RECOMMENDED_PROMPT_PREFIX |
| 35 | +from agents.extensions.models.litellm_provider import LitellmProvider |
35 | 36 | from agents.items import ( |
36 | 37 | HandoffOutputItem, |
37 | 38 | ToolCallItem, |
|
42 | 43 | ResponseFunctionToolCall, |
43 | 44 | ResponseFunctionWebSearch, |
44 | 45 | ResponseOutputMessage, |
45 | | - ResponseOutputText, |
| 46 | + ResponseOutputText, ResponseFileSearchToolCall, ResponseCodeInterpreterToolCall, |
46 | 47 | ) |
| 48 | +from openai.types.responses.response_file_search_tool_call import Result |
47 | 49 | from openai.types.responses.response_function_web_search import ActionSearch |
48 | 50 | from openai.types.responses.response_prompt_param import ResponsePromptParam |
| 51 | +from openai.types.responses.tool_param import ImageGeneration |
49 | 52 | from pydantic import ConfigDict, Field, TypeAdapter |
50 | 53 |
|
51 | 54 | from temporalio import activity, workflow |
@@ -1777,3 +1780,288 @@ async def test_response_serialization(): |
1777 | 1780 | response_id="", |
1778 | 1781 | ) |
1779 | 1782 | encoded = await pydantic_data_converter.encode([model_response]) |
| 1783 | + |
async def test_lite_llm(client: Client):
    """End-to-end check that the OpenAI Agents plugin works with a LiteLLM provider.

    Skipped unless an OpenAI API key is present, since LitellmProvider makes
    real model calls (there is no local-model variant of this test).
    """
    if not os.environ.get("OPENAI_API_KEY"):
        pytest.skip("No openai API key")
    # Rebuild the client with the plugin configured to route model calls
    # through LiteLLM.
    new_config = client.config()
    new_config["plugins"] = [
        openai_agents.OpenAIAgentsPlugin(
            model_params=ModelActivityParameters(
                start_to_close_timeout=timedelta(seconds=30)
            ),
            model_provider=LitellmProvider(),
        )
    ]
    client = Client(**new_config)

    async with new_worker(
        client,
        HelloWorldAgent,
    ) as worker:
        workflow_handle = await client.start_workflow(
            HelloWorldAgent.run,
            "Tell me about recursion in programming",
            id=f"lite-llm-{uuid.uuid4()}",
            task_queue=worker.task_queue,
            # Keep the execution timeout at least as long as the model
            # activity's 30s start_to_close timeout (and consistent with the
            # sibling tests); a 10s execution timeout could terminate the
            # workflow while the model call is still running, flaking the test.
            execution_timeout=timedelta(seconds=30),
        )
        await workflow_handle.result()
| 1810 | + |
| 1811 | + |
class FileSearchToolModel(StaticTestModel):
    """Canned model output simulating a hosted file-search tool invocation.

    The single turn returns a completed ``file_search_call`` (with two fake
    search results) followed by a final assistant message, so a workflow under
    test can observe both the tool-call item and the final output ("Patroclus")
    without contacting a real model.
    """

    responses = [
        ModelResponse(
            output=[
                # Simulated file-search call, marked completed so the SDK
                # treats it as already resolved.
                ResponseFileSearchToolCall(
                    queries=["side character in the Iliad"],
                    type="file_search_call",
                    id="id",
                    status="completed",
                    results=[
                        Result(text="Some scene"),
                        Result(text="Other scene"),
                    ]
                ),
                # Final assistant message; tests assert on this exact text.
                ResponseOutputMessage(
                    id="",
                    content=[
                        ResponseOutputText(
                            text="Patroclus",
                            annotations=[],
                            type="output_text",
                        )
                    ],
                    role="assistant",
                    status="completed",
                    type="message",
                )
            ],
            usage=Usage(),
            response_id=None,
        ),
    ]
| 1844 | + |
@workflow.defn
class FileSearchToolWorkflow:
    """Workflow whose agent is equipped with a hosted file-search tool."""

    @workflow.run
    async def run(self, question: str) -> str:
        librarian = Agent[str](
            name="File Search Workflow",
            instructions="You are a librarian. You should use your tools to source all your information.",
            tools=[
                FileSearchTool(
                    max_num_results=3,
                    vector_store_ids=["vs_687fd7f5e69c8191a2740f06bc9a159d"],
                    include_search_results=True,
                )
            ],
        )
        run_result = await Runner.run(starting_agent=librarian, input=question)

        # Confirm that at least one file search was actually performed.
        performed_search = any(
            isinstance(item, ToolCallItem)
            and isinstance(item.raw_item, ResponseFileSearchToolCall)
            for item in run_result.new_items
        )
        assert performed_search
        return run_result.final_output
| 1867 | + |
@pytest.mark.parametrize("use_local_model", [True, False])
async def test_file_search_tool(client: Client, use_local_model):
    """Run FileSearchToolWorkflow against a canned model or the live API."""
    if not use_local_model and not os.environ.get("OPENAI_API_KEY"):
        pytest.skip("No openai API key")

    # Use the static FileSearchToolModel locally; otherwise hit the real API.
    provider = TestModelProvider(FileSearchToolModel()) if use_local_model else None
    config = client.config()
    config["plugins"] = [
        openai_agents.OpenAIAgentsPlugin(
            model_params=ModelActivityParameters(
                start_to_close_timeout=timedelta(seconds=30)
            ),
            model_provider=provider,
        )
    ]
    client = Client(**config)

    async with new_worker(client, FileSearchToolWorkflow) as worker:
        handle = await client.start_workflow(
            FileSearchToolWorkflow.run,
            "Tell me about a side character in the Iliad.",
            id=f"file-search-tool-{uuid.uuid4()}",
            task_queue=worker.task_queue,
            execution_timeout=timedelta(seconds=30),
        )
        output = await handle.result()
        if use_local_model:
            # The canned model's final message is fixed, so the result is known.
            assert output == "Patroclus"
| 1900 | + |
| 1901 | + |
class ImageGenerationModel(StaticTestModel):
    """Canned model output simulating a hosted image-generation invocation.

    The original body was copy-pasted from FileSearchToolModel and emitted a
    ``file_search_call`` (with Iliad queries and a "Patroclus" answer), which
    does not exercise the image-generation code path at all. It now returns a
    completed ``image_generation_call`` item plus a final assistant message.
    """

    # Local import keeps this fix self-contained; ImageGenerationCall is the
    # Responses output-item type for hosted image generation.
    from openai.types.responses.response_output_item import ImageGenerationCall

    responses = [
        ModelResponse(
            output=[
                # Simulated image-generation call; result is the (omitted)
                # base64 image payload.
                ImageGenerationCall(
                    id="id",
                    result=None,
                    status="completed",
                    type="image_generation_call",
                ),
                ResponseOutputMessage(
                    id="",
                    content=[
                        ResponseOutputText(
                            text="Here is your image.",
                            annotations=[],
                            type="output_text",
                        )
                    ],
                    role="assistant",
                    status="completed",
                    type="message",
                )
            ],
            usage=Usage(),
            response_id=None,
        ),
    ]
| 1934 | + |
@workflow.defn
class ImageGenerationWorkflow:
    """Workflow whose agent is equipped with the hosted image-generation tool."""

    @workflow.run
    async def run(self, question: str) -> str:
        artist = Agent[str](
            name="Image Generation Workflow",
            instructions="You are a helpful agent.",
            tools=[
                ImageGenerationTool(
                    tool_config={"type": "image_generation", "quality": "low"},
                )
            ],
        )
        run_result = await Runner.run(starting_agent=artist, input=question)
        return run_result.final_output
| 1953 | + |
@pytest.mark.parametrize("use_local_model", [True, False])
async def test_image_generation_tool(client: Client, use_local_model):
    """Run ImageGenerationWorkflow against a canned model or the live API."""
    if not use_local_model and not os.environ.get("OPENAI_API_KEY"):
        pytest.skip("No openai API key")

    # Use the static ImageGenerationModel locally; otherwise hit the real API.
    provider = TestModelProvider(ImageGenerationModel()) if use_local_model else None
    config = client.config()
    config["plugins"] = [
        openai_agents.OpenAIAgentsPlugin(
            model_params=ModelActivityParameters(
                start_to_close_timeout=timedelta(seconds=30)
            ),
            model_provider=provider,
        )
    ]
    client = Client(**config)

    async with new_worker(client, ImageGenerationWorkflow) as worker:
        handle = await client.start_workflow(
            ImageGenerationWorkflow.run,
            "Create an image of a frog eating a pizza, comic book style.",
            id=f"image-generation-tool-{uuid.uuid4()}",
            task_queue=worker.task_queue,
            execution_timeout=timedelta(seconds=30),
        )
        # Only completion is checked; the final output is not asserted on.
        await handle.result()
| 1984 | + |
| 1985 | + |
class CodeInterpreterModel(StaticTestModel):
    """Canned model output simulating a hosted code-interpreter invocation.

    The single turn returns a completed ``code_interpreter_call`` followed by
    a final assistant message, so a workflow under test can observe both the
    tool-call item and the final output ("Over 9000") without a real model.
    """

    responses = [
        ModelResponse(
            output=[
                # Simulated interpreter call, marked completed so the SDK
                # treats it as already resolved.
                ResponseCodeInterpreterToolCall(
                    container_id="",
                    code="some code",
                    type="code_interpreter_call",
                    id="id",
                    status="completed",
                ),
                # Final assistant message; tests assert on this exact text.
                ResponseOutputMessage(
                    id="",
                    content=[
                        ResponseOutputText(
                            text="Over 9000",
                            annotations=[],
                            type="output_text",
                        )
                    ],
                    role="assistant",
                    status="completed",
                    type="message",
                )
            ],
            usage=Usage(),
            response_id=None,
        ),
    ]
| 2015 | + |
@workflow.defn
class CodeInterpreterWorkflow:
    """Workflow whose agent is equipped with the hosted code-interpreter tool."""

    @workflow.run
    async def run(self, question: str) -> str:
        analyst = Agent[str](
            name="Code Interpreter Workflow",
            instructions="You are a helpful agent.",
            tools=[
                CodeInterpreterTool(
                    tool_config={"type": "code_interpreter", "container": {"type": "auto"}},
                )
            ],
        )
        run_result = await Runner.run(starting_agent=analyst, input=question)

        # Confirm the agent actually invoked the code interpreter.
        ran_interpreter = any(
            isinstance(item, ToolCallItem)
            and isinstance(item.raw_item, ResponseCodeInterpreterToolCall)
            for item in run_result.new_items
        )
        assert ran_interpreter
        return run_result.final_output
| 2035 | + |
@pytest.mark.parametrize("use_local_model", [True, False])
async def test_code_interpreter_tool(client: Client, use_local_model):
    """Run CodeInterpreterWorkflow against a canned model or the live API."""
    if not use_local_model and not os.environ.get("OPENAI_API_KEY"):
        pytest.skip("No openai API key")

    new_config = client.config()
    new_config["plugins"] = [
        openai_agents.OpenAIAgentsPlugin(
            model_params=ModelActivityParameters(
                start_to_close_timeout=timedelta(seconds=30)
            ),
            model_provider=TestModelProvider(CodeInterpreterModel())
            if use_local_model
            else None,
        )
    ]
    client = Client(**new_config)

    async with new_worker(
        client,
        CodeInterpreterWorkflow,
    ) as worker:
        workflow_handle = await client.start_workflow(
            CodeInterpreterWorkflow.run,
            # Fixed missing space ("of273" -> "of 273") so the live-API run
            # sends a well-formed question; the local model ignores the input.
            "What is the square root of 273 * 312821 plus 1782?",
            id=f"code-interpreter-tool-{uuid.uuid4()}",
            task_queue=worker.task_queue,
            execution_timeout=timedelta(seconds=30),
        )
        result = await workflow_handle.result()
        if use_local_model:
            # Final message text from CodeInterpreterModel's canned response.
            assert result == "Over 9000"
0 commit comments