|
8 | 8 | from typing import Any, AsyncIterator, Optional, Union, no_type_check |
9 | 9 |
|
10 | 10 | import nexusrpc |
| 11 | +import pydantic |
11 | 12 | import pytest |
12 | 13 | from agents import ( |
13 | 14 | Agent, |
@@ -2184,3 +2185,59 @@ def provide( |
2184 | 2185 | async for e in workflow_handle.fetch_history_events(): |
2185 | 2186 | if e.HasField("activity_task_scheduled_event_attributes"): |
2186 | 2187 | assert e.user_metadata.summary.data == b'"My summary"' |
| 2188 | + |
| 2189 | + |
| 2190 | +class OutputType(pydantic.BaseModel): |
| 2191 | + answer: str |
| 2192 | + model_config = ConfigDict(extra="forbid") # Forbid additional properties |
| 2193 | + |
| 2194 | + |
| 2195 | +@workflow.defn |
| 2196 | +class OutputTypeWorkflow: |
| 2197 | + @workflow.run |
| 2198 | + async def run(self) -> OutputType: |
| 2199 | + agent: Agent = Agent( |
| 2200 | + name="Assistant", |
| 2201 | + instructions="You are a helpful assistant, adhere to the json schema output", |
| 2202 | + output_type=OutputType, |
| 2203 | + ) |
| 2204 | + result = await Runner.run( |
| 2205 | + starting_agent=agent, |
| 2206 | + input="Hello!", |
| 2207 | + ) |
| 2208 | + return result.final_output |
| 2209 | + |
| 2210 | + |
| 2211 | +class OutputTypeModel(StaticTestModel): |
| 2212 | + responses = [ |
| 2213 | + ResponseBuilders.output_message( |
| 2214 | + '{"answer": "My answer"}', |
| 2215 | + ), |
| 2216 | + ] |
| 2217 | + |
| 2218 | + |
| 2219 | +async def test_output_type(client: Client): |
| 2220 | + new_config = client.config() |
| 2221 | + new_config["plugins"] = [ |
| 2222 | + openai_agents.OpenAIAgentsPlugin( |
| 2223 | + model_params=ModelActivityParameters( |
| 2224 | + start_to_close_timeout=timedelta(seconds=120), |
| 2225 | + ), |
| 2226 | + model_provider=TestModelProvider(OutputTypeModel()), |
| 2227 | + ) |
| 2228 | + ] |
| 2229 | + client = Client(**new_config) |
| 2230 | + |
| 2231 | + async with new_worker( |
| 2232 | + client, |
| 2233 | + OutputTypeWorkflow, |
| 2234 | + ) as worker: |
| 2235 | + workflow_handle = await client.start_workflow( |
| 2236 | + OutputTypeWorkflow.run, |
| 2237 | + id=f"output-type-{uuid.uuid4()}", |
| 2238 | + task_queue=worker.task_queue, |
| 2239 | + execution_timeout=timedelta(seconds=10), |
| 2240 | + ) |
| 2241 | + result = await workflow_handle.result() |
| 2242 | + assert isinstance(result, OutputType) |
| 2243 | + assert result.answer == "My answer" |
0 commit comments