|
17 | 17 | ItemHelpers, |
18 | 18 | MessageOutputItem, |
19 | 19 | Model, |
| 20 | + ModelProvider, |
20 | 21 | ModelResponse, |
21 | 22 | ModelSettings, |
22 | 23 | ModelTracing, |
| 24 | + OpenAIChatCompletionsModel, |
23 | 25 | OpenAIResponsesModel, |
24 | 26 | OutputGuardrailTripwireTriggered, |
25 | 27 | RunContextWrapper, |
@@ -1844,6 +1846,40 @@ async def test_exception_handling(client: Client): |
1844 | 1846 | await assert_status_retry_behavior(404, client, should_retry=False) |
1845 | 1847 |
|
1846 | 1848 |
|
| 1849 | +class CustomModelProvider(ModelProvider): |
| 1850 | + def get_model(self, model_name: Optional[str]) -> Model: |
| 1851 | + client = AsyncOpenAI(base_url="https://api.openai.com/v1") |
| 1852 | + return OpenAIChatCompletionsModel(model="gpt-4o", openai_client=client) |
| 1853 | + |
| 1854 | + |
| 1855 | +async def test_chat_completions_model(client: Client): |
| 1856 | + if not os.environ.get("OPENAI_API_KEY"): |
| 1857 | + pytest.skip("No openai API key") |
| 1858 | + |
| 1859 | + new_config = client.config() |
| 1860 | + new_config["plugins"] = [ |
| 1861 | + openai_agents.OpenAIAgentsPlugin( |
| 1862 | + model_params=ModelActivityParameters( |
| 1863 | + start_to_close_timeout=timedelta(seconds=30) |
| 1864 | + ), |
| 1865 | + model_provider=CustomModelProvider(), |
| 1866 | + ) |
| 1867 | + ] |
| 1868 | + client = Client(**new_config) |
| 1869 | + |
| 1870 | + async with new_worker( |
| 1871 | + client, |
| 1872 | + WorkflowToolWorkflow, |
| 1873 | + ) as worker: |
| 1874 | + workflow_handle = await client.start_workflow( |
| 1875 | + WorkflowToolWorkflow.run, |
| 1876 | + id=f"workflow-tool-{uuid.uuid4()}", |
| 1877 | + task_queue=worker.task_queue, |
| 1878 | + execution_timeout=timedelta(seconds=10), |
| 1879 | + ) |
| 1880 | + await workflow_handle.result() |
| 1881 | + |
| 1882 | + |
1847 | 1883 | class WaitModel(Model): |
1848 | 1884 | async def get_response( |
1849 | 1885 | self, |
|
0 commit comments