Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ jobs:
with:
submodules: recursive
- uses: dtolnay/rust-toolchain@stable
with:
components: "clippy"
- uses: Swatinem/rust-cache@v2
with:
workspaces: temporalio/bridge -> target
Expand Down
98 changes: 50 additions & 48 deletions temporalio/contrib/openai_agents/_openai_runner.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import dataclasses
import json
import typing
from typing import Any, Optional, Union

Expand All @@ -17,14 +16,62 @@
TResponseInputItem,
)
from agents.run import DEFAULT_AGENT_RUNNER, DEFAULT_MAX_TURNS, AgentRunner
from pydantic_core import to_json

from temporalio import workflow
from temporalio.contrib.openai_agents._model_parameters import ModelActivityParameters
from temporalio.contrib.openai_agents._temporal_model_stub import _TemporalModelStub
from temporalio.contrib.openai_agents.workflow import AgentsWorkflowError


# Recursively replace models in all agents
def _convert_agent(
    model_params: ModelActivityParameters,
    agent: Agent[Any],
    seen: Optional[dict[int, Agent]],
) -> Agent[Any]:
    """Return a copy of ``agent`` whose model (and the models of every agent
    reachable through its handoffs) is replaced by a ``_TemporalModelStub``.

    Args:
        model_params: Activity parameters forwarded to each model stub.
        agent: The agent to convert. The original object is not mutated.
        seen: Maps ``id(original agent) -> converted agent`` for agents already
            visited; pass ``None`` at the top-level call.

    Returns:
        The converted agent (a ``dataclasses.replace`` copy of the input).

    Raises:
        TypeError: If a handoff entry is neither an ``Agent`` nor a ``Handoff``.
    """
    if seen is None:
        seen = dict()

    # Short circuit if this model was already seen to prevent looping from circular handoffs
    if id(agent) in seen:
        return seen[id(agent)]

    # This agent has already been processed in some other run
    if isinstance(agent.model, _TemporalModelStub):
        return agent

    # Save the new version of the agent *before* recursing so that handoff
    # cycles resolve to this copy instead of recursing forever.
    new_agent = dataclasses.replace(agent)
    seen[id(agent)] = new_agent

    name = _model_name(agent)

    new_handoffs: list[Union[Agent, Handoff]] = []
    for handoff in agent.handoffs:
        if isinstance(handoff, Agent):
            new_handoffs.append(_convert_agent(model_params, handoff, seen))
        elif isinstance(handoff, Handoff):
            # Bind the current handoff's invoke as a default argument: a plain
            # closure would late-bind and every on_invoke would call the LAST
            # handoff's original_invoke when an agent has multiple handoffs.
            async def on_invoke(
                context: RunContextWrapper[Any],
                args: str,
                _original_invoke=handoff.on_invoke_handoff,
            ) -> Agent:
                handoff_agent = await _original_invoke(context, args)
                return _convert_agent(model_params, handoff_agent, seen)

            new_handoffs.append(
                dataclasses.replace(handoff, on_invoke_handoff=on_invoke)
            )
        else:
            raise TypeError(f"Unknown handoff type: {type(handoff)}")

    new_agent.model = _TemporalModelStub(
        model_name=name,
        model_params=model_params,
        agent=agent,
    )
    new_agent.handoffs = new_handoffs
    return new_agent


class TemporalOpenAIRunner(AgentRunner):
"""Temporal Runner for OpenAI agents.

Expand Down Expand Up @@ -101,54 +148,9 @@ async def run(
),
)

# Recursively replace models in all agents
def convert_agent(agent: Agent[Any], seen: Optional[set[int]]) -> Agent[Any]:
if seen is None:
seen = set()

# Short circuit if this model was already seen to prevent looping from circular handoffs
if id(agent) in seen:
return agent
seen.add(id(agent))

# This agent has already been processed in some other run
if isinstance(agent.model, _TemporalModelStub):
return agent

name = _model_name(agent)

new_handoffs: list[Union[Agent, Handoff]] = []
for handoff in agent.handoffs:
if isinstance(handoff, Agent):
new_handoffs.append(convert_agent(handoff, seen))
elif isinstance(handoff, Handoff):
original_invoke = handoff.on_invoke_handoff

async def on_invoke(
context: RunContextWrapper[Any], args: str
) -> Agent:
handoff_agent = await original_invoke(context, args)
return convert_agent(handoff_agent, seen)

new_handoffs.append(
dataclasses.replace(handoff, on_invoke_handoff=on_invoke)
)
else:
raise ValueError(f"Unknown handoff type: {type(handoff)}")

return dataclasses.replace(
agent,
model=_TemporalModelStub(
model_name=name,
model_params=self.model_params,
agent=agent,
),
handoffs=new_handoffs,
)

try:
return await self._runner.run(
starting_agent=convert_agent(starting_agent, None),
starting_agent=_convert_agent(self.model_params, starting_agent, None),
input=input,
context=context,
max_turns=max_turns,
Expand Down
56 changes: 47 additions & 9 deletions tests/contrib/openai_agents/test_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,9 +96,13 @@
TestModelProvider,
)
from temporalio.contrib.openai_agents._model_parameters import ModelSummaryProvider
from temporalio.contrib.openai_agents._temporal_model_stub import _extract_summary
from temporalio.contrib.openai_agents._openai_runner import _convert_agent
from temporalio.contrib.openai_agents._temporal_model_stub import (
_extract_summary,
_TemporalModelStub,
)
from temporalio.contrib.pydantic import pydantic_data_converter
from temporalio.exceptions import ApplicationError, CancelledError
from temporalio.exceptions import ApplicationError, CancelledError, TemporalError
from temporalio.testing import WorkflowEnvironment
from temporalio.workflow import ActivityConfig
from tests.contrib.openai_agents.research_agents.research_manager import (
Expand Down Expand Up @@ -897,7 +901,10 @@ async def update_seat(
async def on_seat_booking_handoff(
context: RunContextWrapper[AirlineAgentContext],
) -> None:
flight_number = f"FLT-{workflow.random().randint(100, 999)}"
try:
flight_number = f"FLT-{workflow.random().randint(100, 999)}"
except TemporalError:
flight_number = "FLT-100"
context.context.flight_number = flight_number


Expand Down Expand Up @@ -975,6 +982,8 @@ class CustomerServiceModel(StaticTestModel):
ResponseBuilders.output_message(
"Your seat has been updated to a window seat. If there's anything else you need, feel free to let me know!"
),
ResponseBuilders.tool_call("{}", "transfer_to_triage_agent"),
ResponseBuilders.output_message("You're welcome!"),
]


Expand All @@ -988,10 +997,7 @@ def __init__(self, input_items: list[TResponseInputItem] = []):

@workflow.run
async def run(self, input_items: list[TResponseInputItem] = []):
await workflow.wait_condition(
lambda: workflow.info().is_continue_as_new_suggested()
and workflow.all_handlers_finished()
)
await workflow.wait_condition(lambda: False)
workflow.continue_as_new(self.input_items)

@workflow.query
Expand Down Expand Up @@ -1062,7 +1068,13 @@ async def test_customer_service_workflow(client: Client, use_local_model: bool):
]
client = Client(**new_config)

questions = ["Hello", "Book me a flight to PDX", "11111", "Any window seat"]
questions = [
"Hello",
"Book me a flight to PDX",
"11111",
"Any window seat",
"Take me back to the triage agent to say goodbye",
]

async with new_worker(
client,
Expand Down Expand Up @@ -1101,7 +1113,7 @@ async def test_customer_service_workflow(client: Client, use_local_model: bool):
if e.HasField("activity_task_completed_event_attributes"):
events.append(e)

assert len(events) == 6
assert len(events) == 8
assert (
"Hi there! How can I assist you today?"
in events[0]
Expand Down Expand Up @@ -1138,6 +1150,18 @@ async def test_customer_service_workflow(client: Client, use_local_model: bool):
.activity_task_completed_event_attributes.result.payloads[0]
.data.decode()
)
assert (
"transfer_to_triage_agent"
in events[6]
.activity_task_completed_event_attributes.result.payloads[0]
.data.decode()
)
assert (
"You're welcome!"
in events[7]
.activity_task_completed_event_attributes.result.payloads[0]
.data.decode()
)


class InputGuardrailModel(OpenAIResponsesModel):
Expand Down Expand Up @@ -2571,3 +2595,17 @@ def override_get_activities() -> Sequence[Callable]:
err.value.cause.message
== "MCP Stateful Server Worker failed to schedule activity."
)


async def test_model_conversion_loops():
    """Converting an agent graph with circular handoffs must terminate and
    leave every reachable agent backed by a ``_TemporalModelStub``."""
    root = init_agents()
    converted_root = _convert_agent(ModelActivityParameters(), root, None)
    booking_handoff = converted_root.handoffs[1]
    assert isinstance(booking_handoff, Handoff)
    run_ctx: RunContextWrapper[AirlineAgentContext] = RunContextWrapper(
        context=AirlineAgentContext()  # type: ignore
    )
    booking_agent = await booking_handoff.on_invoke_handoff(run_ctx, "")
    back_to_triage = booking_agent.handoffs[0]
    assert isinstance(back_to_triage, Agent)
    assert isinstance(back_to_triage.model, _TemporalModelStub)