diff --git a/tests/contrib/openai_agents/test_openai.py b/tests/contrib/openai_agents/test_openai.py index 7c3df0897..52e626c6c 100644 --- a/tests/contrib/openai_agents/test_openai.py +++ b/tests/contrib/openai_agents/test_openai.py @@ -8,6 +8,7 @@ from typing import Any, AsyncIterator, Optional, Union, no_type_check import nexusrpc +import pydantic import pytest from agents import ( Agent, @@ -2184,3 +2185,59 @@ def provide( async for e in workflow_handle.fetch_history_events(): if e.HasField("activity_task_scheduled_event_attributes"): assert e.user_metadata.summary.data == b'"My summary"' + + +class OutputType(pydantic.BaseModel): + answer: str + model_config = ConfigDict(extra="forbid") # Forbid additional properties + + +@workflow.defn +class OutputTypeWorkflow: + @workflow.run + async def run(self) -> OutputType: + agent: Agent = Agent( + name="Assistant", + instructions="You are a helpful assistant, adhere to the json schema output", + output_type=OutputType, + ) + result = await Runner.run( + starting_agent=agent, + input="Hello!", + ) + return result.final_output + + +class OutputTypeModel(StaticTestModel): + responses = [ + ResponseBuilders.output_message( + '{"answer": "My answer"}', + ), + ] + + +async def test_output_type(client: Client): + new_config = client.config() + new_config["plugins"] = [ + openai_agents.OpenAIAgentsPlugin( + model_params=ModelActivityParameters( + start_to_close_timeout=timedelta(seconds=120), + ), + model_provider=TestModelProvider(OutputTypeModel()), + ) + ] + client = Client(**new_config) + + async with new_worker( + client, + OutputTypeWorkflow, + ) as worker: + workflow_handle = await client.start_workflow( + OutputTypeWorkflow.run, + id=f"output-type-{uuid.uuid4()}", + task_queue=worker.task_queue, + execution_timeout=timedelta(seconds=10), + ) + result = await workflow_handle.result() + assert isinstance(result, OutputType) + assert result.answer == "My answer"