Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ jobs:
with:
submodules: recursive
- uses: dtolnay/rust-toolchain@stable
with:
components: "clippy"
- uses: Swatinem/rust-cache@v2
with:
workspaces: temporalio/bridge -> target
Expand Down
98 changes: 50 additions & 48 deletions temporalio/contrib/openai_agents/_openai_runner.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import dataclasses
import json
import typing
from typing import Any, Optional, Union

Expand All @@ -17,14 +16,62 @@
TResponseInputItem,
)
from agents.run import DEFAULT_AGENT_RUNNER, DEFAULT_MAX_TURNS, AgentRunner
from pydantic_core import to_json

from temporalio import workflow
from temporalio.contrib.openai_agents._model_parameters import ModelActivityParameters
from temporalio.contrib.openai_agents._temporal_model_stub import _TemporalModelStub
from temporalio.contrib.openai_agents.workflow import AgentsWorkflowError


# Recursively replace models in all agents
def _convert_agent(
model_params: ModelActivityParameters,
agent: Agent[Any],
seen: Optional[dict[int, Agent]],
) -> Agent[Any]:
if seen is None:
seen = dict()

# Short circuit if this model was already seen to prevent looping from circular handoffs
if id(agent) in seen:
return seen[id(agent)]

# This agent has already been processed in some other run
if isinstance(agent.model, _TemporalModelStub):
return agent

# Save the new version of the agent so that we can replace loops
new_agent = dataclasses.replace(agent)
seen[id(agent)] = new_agent

name = _model_name(agent)

new_handoffs: list[Union[Agent, Handoff]] = []
for handoff in agent.handoffs:
if isinstance(handoff, Agent):
new_handoffs.append(_convert_agent(model_params, handoff, seen))
elif isinstance(handoff, Handoff):
original_invoke = handoff.on_invoke_handoff

async def on_invoke(context: RunContextWrapper[Any], args: str) -> Agent:
handoff_agent = await original_invoke(context, args)
return _convert_agent(model_params, handoff_agent, seen)

new_handoffs.append(
dataclasses.replace(handoff, on_invoke_handoff=on_invoke)
)
else:
raise ValueError(f"Unknown handoff type: {type(handoff)}")

new_agent.model = _TemporalModelStub(
model_name=name,
model_params=model_params,
agent=agent,
)
new_agent.handoffs = new_handoffs
return new_agent


class TemporalOpenAIRunner(AgentRunner):
"""Temporal Runner for OpenAI agents.

Expand Down Expand Up @@ -101,54 +148,9 @@ async def run(
),
)

# Recursively replace models in all agents
def convert_agent(agent: Agent[Any], seen: Optional[set[int]]) -> Agent[Any]:
if seen is None:
seen = set()

# Short circuit if this model was already seen to prevent looping from circular handoffs
if id(agent) in seen:
return agent
seen.add(id(agent))

# This agent has already been processed in some other run
if isinstance(agent.model, _TemporalModelStub):
return agent

name = _model_name(agent)

new_handoffs: list[Union[Agent, Handoff]] = []
for handoff in agent.handoffs:
if isinstance(handoff, Agent):
new_handoffs.append(convert_agent(handoff, seen))
elif isinstance(handoff, Handoff):
original_invoke = handoff.on_invoke_handoff

async def on_invoke(
context: RunContextWrapper[Any], args: str
) -> Agent:
handoff_agent = await original_invoke(context, args)
return convert_agent(handoff_agent, seen)

new_handoffs.append(
dataclasses.replace(handoff, on_invoke_handoff=on_invoke)
)
else:
raise ValueError(f"Unknown handoff type: {type(handoff)}")

return dataclasses.replace(
agent,
model=_TemporalModelStub(
model_name=name,
model_params=self.model_params,
agent=agent,
),
handoffs=new_handoffs,
)

try:
return await self._runner.run(
starting_agent=convert_agent(starting_agent, None),
starting_agent=_convert_agent(self.model_params, starting_agent, None),
input=input,
context=context,
max_turns=max_turns,
Expand Down
54 changes: 47 additions & 7 deletions tests/contrib/openai_agents/test_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,9 +96,13 @@
TestModelProvider,
)
from temporalio.contrib.openai_agents._model_parameters import ModelSummaryProvider
from temporalio.contrib.openai_agents._temporal_model_stub import _extract_summary
from temporalio.contrib.openai_agents._openai_runner import _convert_agent
from temporalio.contrib.openai_agents._temporal_model_stub import (
_extract_summary,
_TemporalModelStub,
)
from temporalio.contrib.pydantic import pydantic_data_converter
from temporalio.exceptions import ApplicationError, CancelledError
from temporalio.exceptions import ApplicationError, CancelledError, TemporalError
from temporalio.testing import WorkflowEnvironment
from temporalio.workflow import ActivityConfig
from tests.contrib.openai_agents.research_agents.research_manager import (
Expand Down Expand Up @@ -897,7 +901,10 @@ async def update_seat(
async def on_seat_booking_handoff(
context: RunContextWrapper[AirlineAgentContext],
) -> None:
flight_number = f"FLT-{workflow.random().randint(100, 999)}"
try:
flight_number = f"FLT-{workflow.random().randint(100, 999)}"
except TemporalError:
flight_number = "FLT-100"
context.context.flight_number = flight_number


Expand Down Expand Up @@ -975,6 +982,8 @@ class CustomerServiceModel(StaticTestModel):
ResponseBuilders.output_message(
"Your seat has been updated to a window seat. If there's anything else you need, feel free to let me know!"
),
ResponseBuilders.tool_call("{}", "transfer_to_triage_agent"),
ResponseBuilders.output_message("You're welcome!"),
]


Expand All @@ -989,8 +998,7 @@ def __init__(self, input_items: list[TResponseInputItem] = []):
@workflow.run
async def run(self, input_items: list[TResponseInputItem] = []):
await workflow.wait_condition(
lambda: workflow.info().is_continue_as_new_suggested()
and workflow.all_handlers_finished()
lambda: False
)
workflow.continue_as_new(self.input_items)

Expand Down Expand Up @@ -1062,7 +1070,13 @@ async def test_customer_service_workflow(client: Client, use_local_model: bool):
]
client = Client(**new_config)

questions = ["Hello", "Book me a flight to PDX", "11111", "Any window seat"]
questions = [
"Hello",
"Book me a flight to PDX",
"11111",
"Any window seat",
"Take me back to the triage agent to say goodbye",
]

async with new_worker(
client,
Expand Down Expand Up @@ -1101,7 +1115,7 @@ async def test_customer_service_workflow(client: Client, use_local_model: bool):
if e.HasField("activity_task_completed_event_attributes"):
events.append(e)

assert len(events) == 6
assert len(events) == 8
assert (
"Hi there! How can I assist you today?"
in events[0]
Expand Down Expand Up @@ -1138,6 +1152,18 @@ async def test_customer_service_workflow(client: Client, use_local_model: bool):
.activity_task_completed_event_attributes.result.payloads[0]
.data.decode()
)
assert (
"transfer_to_triage_agent"
in events[6]
.activity_task_completed_event_attributes.result.payloads[0]
.data.decode()
)
assert (
"You're welcome!"
in events[7]
.activity_task_completed_event_attributes.result.payloads[0]
.data.decode()
)


class InputGuardrailModel(OpenAIResponsesModel):
Expand Down Expand Up @@ -2571,3 +2597,17 @@ def override_get_activities() -> Sequence[Callable]:
err.value.cause.message
== "MCP Stateful Server Worker failed to schedule activity."
)


async def test_model_conversion_loops():
agent = init_agents()
converted = _convert_agent(ModelActivityParameters(), agent, None)
seat_booking_handoff = converted.handoffs[1]
assert isinstance(seat_booking_handoff, Handoff)
context: RunContextWrapper[AirlineAgentContext] = RunContextWrapper(
context=AirlineAgentContext() # type: ignore
)
seat_booking_agent = await seat_booking_handoff.on_invoke_handoff(context, "")
triage_agent = seat_booking_agent.handoffs[0]
assert isinstance(triage_agent, Agent)
assert isinstance(triage_agent.model, _TemporalModelStub)
Loading