Skip to content

Commit df1aa00

Browse files
committed
Add test for chat completions model (#986)
* Add test for chat completions model * lint
1 parent 1c5c105 commit df1aa00

File tree

1 file changed

+37
-0
lines changed

1 file changed

+37
-0
lines changed

tests/contrib/openai_agents/test_openai.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,12 @@
1515
InputGuardrailTripwireTriggered,
1616
ItemHelpers,
1717
MessageOutputItem,
18+
Model,
19+
ModelProvider,
1820
ModelResponse,
1921
ModelSettings,
2022
ModelTracing,
23+
OpenAIChatCompletionsModel,
2124
OpenAIResponsesModel,
2225
OutputGuardrailTripwireTriggered,
2326
RunContextWrapper,
@@ -1839,3 +1842,37 @@ async def test_exception_handling(client: Client):
18391842
await assert_status_retry_behavior(400, client, should_retry=False)
18401843
await assert_status_retry_behavior(403, client, should_retry=False)
18411844
await assert_status_retry_behavior(404, client, should_retry=False)
1845+
1846+
1847+
class CustomModelProvider(ModelProvider):
1848+
def get_model(self, model_name: Optional[str]) -> Model:
1849+
client = AsyncOpenAI(base_url="https://api.openai.com/v1")
1850+
return OpenAIChatCompletionsModel(model="gpt-4o", openai_client=client)
1851+
1852+
1853+
async def test_chat_completions_model(client: Client):
1854+
if not os.environ.get("OPENAI_API_KEY"):
1855+
pytest.skip("No openai API key")
1856+
1857+
new_config = client.config()
1858+
new_config["plugins"] = [
1859+
openai_agents.OpenAIAgentsPlugin(
1860+
model_params=ModelActivityParameters(
1861+
start_to_close_timeout=timedelta(seconds=30)
1862+
),
1863+
model_provider=CustomModelProvider(),
1864+
)
1865+
]
1866+
client = Client(**new_config)
1867+
1868+
async with new_worker(
1869+
client,
1870+
WorkflowToolWorkflow,
1871+
) as worker:
1872+
workflow_handle = await client.start_workflow(
1873+
WorkflowToolWorkflow.run,
1874+
id=f"workflow-tool-{uuid.uuid4()}",
1875+
task_queue=worker.task_queue,
1876+
execution_timeout=timedelta(seconds=10),
1877+
)
1878+
await workflow_handle.result()

0 commit comments

Comments
 (0)