Skip to content

Commit 1a19834

Browse files
committed
Add test for chat completions model
1 parent 126bcd8 commit 1a19834

File tree

1 file changed

+30
-1
lines changed

1 file changed

+30
-1
lines changed

tests/contrib/openai_agents/test_openai.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
handoff,
2929
input_guardrail,
3030
output_guardrail,
31-
trace,
31+
trace, ModelProvider, Model, OpenAIChatCompletionsModel,
3232
)
3333
from agents.extensions.handoff_prompt import RECOMMENDED_PROMPT_PREFIX
3434
from agents.items import (
@@ -1778,3 +1778,32 @@ async def test_workflow_method_tools(client: Client):
17781778
execution_timeout=timedelta(seconds=10),
17791779
)
17801780
await workflow_handle.result()
1781+
1782+
class CustomModelProvider(ModelProvider):
1783+
def get_model(self, model_name: str) -> Model:
1784+
client = AsyncOpenAI(base_url="https://api.openai.com/v1")
1785+
return OpenAIChatCompletionsModel(model="gpt-4o", openai_client=client)
1786+
1787+
async def test_chat_completions_model(client: Client):
1788+
if not os.environ.get("OPENAI_API_KEY"):
1789+
pytest.skip("No openai API key")
1790+
1791+
new_config = client.config()
1792+
new_config["data_converter"] = pydantic_data_converter
1793+
client = Client(**new_config)
1794+
1795+
with set_open_ai_agent_temporal_overrides():
1796+
model_activity = ModelActivity(model_provider=CustomModelProvider())
1797+
async with new_worker(
1798+
client,
1799+
WorkflowToolWorkflow,
1800+
activities=[model_activity.invoke_model_activity],
1801+
interceptors=[OpenAIAgentsTracingInterceptor()],
1802+
) as worker:
1803+
workflow_handle = await client.start_workflow(
1804+
WorkflowToolWorkflow.run,
1805+
id=f"workflow-tool-{uuid.uuid4()}",
1806+
task_queue=worker.task_queue,
1807+
execution_timeout=timedelta(seconds=10),
1808+
)
1809+
await workflow_handle.result()

0 commit comments

Comments
 (0)