diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index a9637f8ad..6289dbcd0 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -55,6 +55,8 @@ jobs: with: submodules: recursive - uses: dtolnay/rust-toolchain@stable + with: + components: "clippy" - uses: Swatinem/rust-cache@v2 with: workspaces: temporalio/bridge -> target diff --git a/temporalio/contrib/openai_agents/_openai_runner.py b/temporalio/contrib/openai_agents/_openai_runner.py index de97987d3..0b120a2a6 100644 --- a/temporalio/contrib/openai_agents/_openai_runner.py +++ b/temporalio/contrib/openai_agents/_openai_runner.py @@ -1,5 +1,4 @@ import dataclasses -import json import typing from typing import Any, Optional, Union @@ -17,7 +16,6 @@ TResponseInputItem, ) from agents.run import DEFAULT_AGENT_RUNNER, DEFAULT_MAX_TURNS, AgentRunner -from pydantic_core import to_json from temporalio import workflow from temporalio.contrib.openai_agents._model_parameters import ModelActivityParameters @@ -25,6 +23,55 @@ from temporalio.contrib.openai_agents.workflow import AgentsWorkflowError +# Recursively replace models in all agents +def _convert_agent( + model_params: ModelActivityParameters, + agent: Agent[Any], + seen: Optional[dict[int, Agent]], +) -> Agent[Any]: + if seen is None: + seen = dict() + + # Short circuit if this model was already seen to prevent looping from circular handoffs + if id(agent) in seen: + return seen[id(agent)] + + # This agent has already been processed in some other run + if isinstance(agent.model, _TemporalModelStub): + return agent + + # Save the new version of the agent so that we can replace loops + new_agent = dataclasses.replace(agent) + seen[id(agent)] = new_agent + + name = _model_name(agent) + + new_handoffs: list[Union[Agent, Handoff]] = [] + for handoff in agent.handoffs: + if isinstance(handoff, Agent): + new_handoffs.append(_convert_agent(model_params, handoff, seen)) + elif isinstance(handoff, Handoff): + original_invoke = handoff.on_invoke_handoff + + async def on_invoke(context: RunContextWrapper[Any], args: str) -> Agent: + handoff_agent = await original_invoke(context, args) + return _convert_agent(model_params, handoff_agent, seen) + + new_handoffs.append( + dataclasses.replace(handoff, on_invoke_handoff=on_invoke) + ) + else: + raise TypeError(f"Unknown handoff type: {type(handoff)}") + + new_agent.model = _TemporalModelStub( + model_name=name, + model_params=model_params, + agent=agent, + ) + new_agent.handoffs = new_handoffs + return new_agent + + class TemporalOpenAIRunner(AgentRunner): """Temporal Runner for OpenAI agents. @@ -101,54 +148,9 @@ async def run( ), ) - # Recursively replace models in all agents - def convert_agent(agent: Agent[Any], seen: Optional[set[int]]) -> Agent[Any]: - if seen is None: - seen = set() - - # Short circuit if this model was already seen to prevent looping from circular handoffs - if id(agent) in seen: - return agent - seen.add(id(agent)) - - # This agent has already been processed in some other run - if isinstance(agent.model, _TemporalModelStub): - return agent - - name = _model_name(agent) - - new_handoffs: list[Union[Agent, Handoff]] = [] - for handoff in agent.handoffs: - if isinstance(handoff, Agent): - new_handoffs.append(convert_agent(handoff, seen)) - elif isinstance(handoff, Handoff): - original_invoke = handoff.on_invoke_handoff - - async def on_invoke( - context: RunContextWrapper[Any], args: str - ) -> Agent: - handoff_agent = await original_invoke(context, args) - return convert_agent(handoff_agent, seen) - - new_handoffs.append( - dataclasses.replace(handoff, on_invoke_handoff=on_invoke) - ) - else: - raise ValueError(f"Unknown handoff type: {type(handoff)}") - - return dataclasses.replace( - agent, - model=_TemporalModelStub( - model_name=name, - model_params=self.model_params, - agent=agent, - ), - handoffs=new_handoffs, - ) - try: return await self._runner.run( - starting_agent=convert_agent(starting_agent, None), + starting_agent=_convert_agent(self.model_params, starting_agent, None), input=input, context=context, max_turns=max_turns, diff --git a/tests/contrib/openai_agents/test_openai.py b/tests/contrib/openai_agents/test_openai.py index 565792632..8242fa9e9 100644 --- a/tests/contrib/openai_agents/test_openai.py +++ b/tests/contrib/openai_agents/test_openai.py @@ -96,9 +96,13 @@ TestModelProvider, ) from temporalio.contrib.openai_agents._model_parameters import ModelSummaryProvider -from temporalio.contrib.openai_agents._temporal_model_stub import _extract_summary +from temporalio.contrib.openai_agents._openai_runner import _convert_agent +from temporalio.contrib.openai_agents._temporal_model_stub import ( + _extract_summary, + _TemporalModelStub, +) from temporalio.contrib.pydantic import pydantic_data_converter -from temporalio.exceptions import ApplicationError, CancelledError +from temporalio.exceptions import ApplicationError, CancelledError, TemporalError from temporalio.testing import WorkflowEnvironment from temporalio.workflow import ActivityConfig from tests.contrib.openai_agents.research_agents.research_manager import ( @@ -897,7 +901,10 @@ async def update_seat( async def on_seat_booking_handoff( context: RunContextWrapper[AirlineAgentContext], ) -> None: - flight_number = f"FLT-{workflow.random().randint(100, 999)}" + try: + flight_number = f"FLT-{workflow.random().randint(100, 999)}" + except TemporalError: + flight_number = "FLT-100" context.context.flight_number = flight_number @@ -975,6 +982,8 @@ class CustomerServiceModel(StaticTestModel): ResponseBuilders.output_message( "Your seat has been updated to a window seat. If there's anything else you need, feel free to let me know!" ), + ResponseBuilders.tool_call("{}", "transfer_to_triage_agent"), + ResponseBuilders.output_message("You're welcome!"), ] @@ -988,10 +997,7 @@ def __init__(self, input_items: list[TResponseInputItem] = []): @workflow.run async def run(self, input_items: list[TResponseInputItem] = []): - await workflow.wait_condition( - lambda: workflow.info().is_continue_as_new_suggested() - and workflow.all_handlers_finished() - ) + await workflow.wait_condition(lambda: False) workflow.continue_as_new(self.input_items) @workflow.query @@ -1062,7 +1068,13 @@ async def test_customer_service_workflow(client: Client, use_local_model: bool): ] client = Client(**new_config) - questions = ["Hello", "Book me a flight to PDX", "11111", "Any window seat"] + questions = [ + "Hello", + "Book me a flight to PDX", + "11111", + "Any window seat", + "Take me back to the triage agent to say goodbye", + ] async with new_worker( client, @@ -1101,7 +1113,7 @@ async def test_customer_service_workflow(client: Client, use_local_model: bool): if e.HasField("activity_task_completed_event_attributes"): events.append(e) - assert len(events) == 6 + assert len(events) == 8 assert ( "Hi there! How can I assist you today?" in events[0] @@ -1138,6 +1150,18 @@ async def test_customer_service_workflow(client: Client, use_local_model: bool): .activity_task_completed_event_attributes.result.payloads[0] .data.decode() ) + assert ( + "transfer_to_triage_agent" + in events[6] + .activity_task_completed_event_attributes.result.payloads[0] + .data.decode() + ) + assert ( + "You're welcome!" + in events[7] + .activity_task_completed_event_attributes.result.payloads[0] + .data.decode() + ) class InputGuardrailModel(OpenAIResponsesModel): @@ -2571,3 +2595,17 @@ def override_get_activities() -> Sequence[Callable]: err.value.cause.message == "MCP Stateful Server Worker failed to schedule activity." ) + + +async def test_model_conversion_loops(): + agent = init_agents() + converted = _convert_agent(ModelActivityParameters(), agent, None) + seat_booking_handoff = converted.handoffs[1] + assert isinstance(seat_booking_handoff, Handoff) + context: RunContextWrapper[AirlineAgentContext] = RunContextWrapper( + context=AirlineAgentContext() # type: ignore + ) + seat_booking_agent = await seat_booking_handoff.on_invoke_handoff(context, "") + triage_agent = seat_booking_agent.handoffs[0] + assert isinstance(triage_agent, Agent) + assert isinstance(triage_agent.model, _TemporalModelStub)