|
28 | 28 | handoff, |
29 | 29 | input_guardrail, |
30 | 30 | output_guardrail, |
31 | | - trace, |
| 31 | + trace, ModelProvider, Model, OpenAIChatCompletionsModel, |
32 | 32 | ) |
33 | 33 | from agents.extensions.handoff_prompt import RECOMMENDED_PROMPT_PREFIX |
34 | 34 | from agents.items import ( |
@@ -1778,3 +1778,32 @@ async def test_workflow_method_tools(client: Client): |
1778 | 1778 | execution_timeout=timedelta(seconds=10), |
1779 | 1779 | ) |
1780 | 1780 | await workflow_handle.result() |
| 1781 | + |
| 1782 | +class CustomModelProvider(ModelProvider): |
| 1783 | + def get_model(self, model_name: str) -> Model: |
| 1784 | + client = AsyncOpenAI(base_url="https://api.openai.com/v1") |
| 1785 | + return OpenAIChatCompletionsModel(model="gpt-4o", openai_client=client) |
| 1786 | + |
| 1787 | +async def test_chat_completions_model(client: Client): |
| 1788 | + if not os.environ.get("OPENAI_API_KEY"): |
| 1789 | + pytest.skip("No openai API key") |
| 1790 | + |
| 1791 | + new_config = client.config() |
| 1792 | + new_config["data_converter"] = pydantic_data_converter |
| 1793 | + client = Client(**new_config) |
| 1794 | + |
| 1795 | + with set_open_ai_agent_temporal_overrides(): |
| 1796 | + model_activity = ModelActivity(model_provider=CustomModelProvider()) |
| 1797 | + async with new_worker( |
| 1798 | + client, |
| 1799 | + WorkflowToolWorkflow, |
| 1800 | + activities=[model_activity.invoke_model_activity], |
| 1801 | + interceptors=[OpenAIAgentsTracingInterceptor()], |
| 1802 | + ) as worker: |
| 1803 | + workflow_handle = await client.start_workflow( |
| 1804 | + WorkflowToolWorkflow.run, |
| 1805 | + id=f"workflow-tool-{uuid.uuid4()}", |
| 1806 | + task_queue=worker.task_queue, |
| 1807 | + execution_timeout=timedelta(seconds=10), |
| 1808 | + ) |
| 1809 | + await workflow_handle.result() |
0 commit comments