|
15 | 15 | InputGuardrailTripwireTriggered, |
16 | 16 | ItemHelpers, |
17 | 17 | MessageOutputItem, |
| 18 | + Model, |
| 19 | + ModelProvider, |
18 | 20 | ModelResponse, |
19 | 21 | ModelSettings, |
20 | 22 | ModelTracing, |
| 23 | + OpenAIChatCompletionsModel, |
21 | 24 | OpenAIResponsesModel, |
22 | 25 | OutputGuardrailTripwireTriggered, |
23 | 26 | RunContextWrapper, |
@@ -1839,3 +1842,37 @@ async def test_exception_handling(client: Client): |
1839 | 1842 | await assert_status_retry_behavior(400, client, should_retry=False) |
1840 | 1843 | await assert_status_retry_behavior(403, client, should_retry=False) |
1841 | 1844 | await assert_status_retry_behavior(404, client, should_retry=False) |
| 1845 | + |
| 1846 | + |
| 1847 | +class CustomModelProvider(ModelProvider): |
| 1848 | + def get_model(self, model_name: Optional[str]) -> Model: |
| 1849 | + client = AsyncOpenAI(base_url="https://api.openai.com/v1") |
| 1850 | + return OpenAIChatCompletionsModel(model="gpt-4o", openai_client=client) |
| 1851 | + |
| 1852 | + |
| 1853 | +async def test_chat_completions_model(client: Client): |
| 1854 | + if not os.environ.get("OPENAI_API_KEY"): |
| 1855 | + pytest.skip("No openai API key") |
| 1856 | + |
| 1857 | + new_config = client.config() |
| 1858 | + new_config["plugins"] = [ |
| 1859 | + openai_agents.OpenAIAgentsPlugin( |
| 1860 | + model_params=ModelActivityParameters( |
| 1861 | + start_to_close_timeout=timedelta(seconds=30) |
| 1862 | + ), |
| 1863 | + model_provider=CustomModelProvider(), |
| 1864 | + ) |
| 1865 | + ] |
| 1866 | + client = Client(**new_config) |
| 1867 | + |
| 1868 | + async with new_worker( |
| 1869 | + client, |
| 1870 | + WorkflowToolWorkflow, |
| 1871 | + ) as worker: |
| 1872 | + workflow_handle = await client.start_workflow( |
| 1873 | + WorkflowToolWorkflow.run, |
| 1874 | + id=f"workflow-tool-{uuid.uuid4()}", |
| 1875 | + task_queue=worker.task_queue, |
| 1876 | + execution_timeout=timedelta(seconds=10), |
| 1877 | + ) |
| 1878 | + await workflow_handle.result() |
0 commit comments