Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion temporalio/contrib/openai_agents/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,14 @@
"""

from temporalio.contrib.openai_agents._model_parameters import ModelActivityParameters
from temporalio.contrib.openai_agents.temporal_openai_agents import (
from temporalio.contrib.openai_agents._temporal_openai_agents import (
OpenAIAgentsPlugin,
TestModel,
TestModelProvider,
)
from temporalio.contrib.openai_agents._trace_interceptor import (
OpenAIAgentsTracingInterceptor,
)

from . import workflow

Expand Down
76 changes: 63 additions & 13 deletions temporalio/contrib/openai_agents/_invoke_model_activity.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import enum
import json
from dataclasses import dataclass
from datetime import timedelta
from typing import Any, Optional, Union, cast

from agents import (
Expand All @@ -17,17 +18,25 @@
ModelResponse,
ModelSettings,
ModelTracing,
OpenAIProvider,
RunContextWrapper,
Tool,
TResponseInputItem,
UserError,
WebSearchTool,
)
from agents.models.multi_provider import MultiProvider
from openai import (
APIStatusError,
AsyncOpenAI,
AuthenticationError,
PermissionDeniedError,
)
from typing_extensions import Required, TypedDict

from temporalio import activity
from temporalio.contrib.openai_agents._heartbeat_decorator import _auto_heartbeater
from temporalio.exceptions import ApplicationError


@dataclass
Expand Down Expand Up @@ -117,11 +126,15 @@ class ActivityModelInput(TypedDict, total=False):


class ModelActivity:
"""Class wrapper for model invocation activities to allow model customization."""
"""Class wrapper for model invocation activities to allow model customization. By default, we use an OpenAIProvider with retries disabled.
Disabling retries in your model of choice is recommended to allow activity retries to define the retry model.
"""

def __init__(self, model_provider: Optional[ModelProvider] = None):
"""Initialize the activity with a model provider."""
self._model_provider = model_provider or MultiProvider()
self._model_provider = model_provider or OpenAIProvider(
openai_client=AsyncOpenAI(max_retries=0)
)

@activity.defn
@_auto_heartbeater
Expand Down Expand Up @@ -171,14 +184,51 @@ def make_tool(tool: ToolInput) -> Tool:
)
for x in input.get("handoffs", [])
]
return await model.get_response(
system_instructions=input.get("system_instructions"),
input=input_input,
model_settings=input["model_settings"],
tools=tools,
output_schema=input.get("output_schema"),
handoffs=handoffs,
tracing=ModelTracing(input["tracing"]),
previous_response_id=input.get("previous_response_id"),
prompt=input.get("prompt"),
)

try:
return await model.get_response(
system_instructions=input.get("system_instructions"),
input=input_input,
model_settings=input["model_settings"],
tools=tools,
output_schema=input.get("output_schema"),
handoffs=handoffs,
tracing=ModelTracing(input["tracing"]),
previous_response_id=input.get("previous_response_id"),
prompt=input.get("prompt"),
)
except APIStatusError as e:
# Listen to server hints
retry_after = None
retry_after_ms_header = e.response.headers.get("retry-after-ms")
if retry_after_ms_header is not None:
retry_after = timedelta(milliseconds=float(retry_after_ms_header))

if retry_after is None:
retry_after_header = e.response.headers.get("retry-after")
if retry_after_header is not None:
retry_after = timedelta(seconds=float(retry_after_header))

should_retry_header = e.response.headers.get("x-should-retry")
if should_retry_header == "true":
raise e
if should_retry_header == "false":
raise ApplicationError(
"Non retryable OpenAI error",
non_retryable=True,
next_retry_delay=retry_after,
) from e

# Specifically retryable status codes
if e.response.status_code in [408, 409, 429, 500]:
raise ApplicationError(
"Retryable OpenAI status code",
non_retryable=False,
next_retry_delay=retry_after,
) from e

raise ApplicationError(
"Non retryable OpenAI status code",
non_retryable=True,
next_retry_delay=retry_after,
) from e
4 changes: 2 additions & 2 deletions temporalio/contrib/openai_agents/_model_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,13 @@ class ModelActivityParameters:
task_queue: Optional[str] = None
"""Specific task queue to use for model activities."""

schedule_to_close_timeout: Optional[timedelta] = timedelta(seconds=60)
schedule_to_close_timeout: Optional[timedelta] = None
"""Maximum time from scheduling to completion."""

schedule_to_start_timeout: Optional[timedelta] = None
"""Maximum time from scheduling to starting."""

start_to_close_timeout: Optional[timedelta] = None
start_to_close_timeout: Optional[timedelta] = timedelta(seconds=60)
"""Maximum time for the activity to complete."""

heartbeat_timeout: Optional[timedelta] = None
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Initialize Temporal OpenAI Agents overrides."""

from contextlib import contextmanager
from datetime import timedelta
from typing import AsyncIterator, Callable, Optional, Union

from agents import (
Expand Down Expand Up @@ -39,7 +40,7 @@

@contextmanager
def set_open_ai_agent_temporal_overrides(
model_params: Optional[ModelActivityParameters] = None,
model_params: ModelActivityParameters,
auto_close_tracing_in_workflows: bool = False,
):
"""Configure Temporal-specific overrides for OpenAI agents.
Expand Down Expand Up @@ -69,14 +70,6 @@ def set_open_ai_agent_temporal_overrides(
if model_params is None:
model_params = ModelActivityParameters()

if (
not model_params.start_to_close_timeout
and not model_params.schedule_to_close_timeout
):
raise ValueError(
"Activity must have start_to_close_timeout or schedule_to_close_timeout"
)

previous_runner = get_default_agent_runner()
previous_trace_provider = get_trace_provider()
provider = TemporalTraceProvider(
Expand Down Expand Up @@ -208,6 +201,22 @@ def __init__(
model_provider: Optional model provider for custom model implementations.
Useful for testing or custom model integrations.
"""
if model_params is None:
model_params = ModelActivityParameters()

# For the default provider, we provide a default start_to_close_timeout of 60 seconds.
# Other providers will need to define their own.
if (
model_params.start_to_close_timeout is None
and model_params.schedule_to_close_timeout is None
):
if model_provider is None:
model_params.start_to_close_timeout = timedelta(seconds=60)
else:
raise ValueError(
"When configuring a custom provider, the model activity must have start_to_close_timeout or schedule_to_close_timeout"
)

self._model_params = model_params
self._model_provider = model_provider

Expand Down
66 changes: 64 additions & 2 deletions tests/contrib/openai_agents/test_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
ToolCallItem,
ToolCallOutputItem,
)
from openai import AsyncOpenAI, BaseModel
from openai import APIStatusError, AsyncOpenAI, BaseModel
from openai.types.responses import (
ResponseFunctionToolCall,
ResponseFunctionWebSearch,
Expand All @@ -48,16 +48,18 @@
from openai.types.responses.response_prompt_param import ResponsePromptParam
from pydantic import ConfigDict, Field, TypeAdapter

import temporalio.api.cloud.namespace.v1
from temporalio import activity, workflow
from temporalio.client import Client, WorkflowFailureError, WorkflowHandle
from temporalio.common import RetryPolicy, SearchAttributeValueType
from temporalio.contrib import openai_agents
from temporalio.contrib.openai_agents import (
ModelActivityParameters,
TestModel,
TestModelProvider,
)
from temporalio.contrib.pydantic import pydantic_data_converter
from temporalio.exceptions import CancelledError
from temporalio.exceptions import ApplicationError, CancelledError
from temporalio.testing import WorkflowEnvironment
from tests.contrib.openai_agents.research_agents.research_manager import (
ResearchManager,
Expand Down Expand Up @@ -1777,3 +1779,63 @@ async def test_response_serialization():
response_id="",
)
encoded = await pydantic_data_converter.encode([model_response])


async def assert_status_retry_behavior(status: int, client: Client, should_retry: bool):
def status_error(status: int):
with workflow.unsafe.imports_passed_through():
with workflow.unsafe.sandbox_unrestricted():
import httpx
raise APIStatusError(
message="Something went wrong.",
response=httpx.Response(
status_code=status, request=httpx.Request("GET", url="")
),
body=None,
)

new_config = client.config()
new_config["plugins"] = [
openai_agents.OpenAIAgentsPlugin(
model_params=ModelActivityParameters(
retry_policy=RetryPolicy(maximum_attempts=2),
),
model_provider=TestModelProvider(TestModel(lambda: status_error(status))),
)
]
client = Client(**new_config)

async with new_worker(
client,
HelloWorldAgent,
) as worker:
workflow_handle = await client.start_workflow(
HelloWorldAgent.run,
"Input",
id=f"workflow-tool-{uuid.uuid4()}",
task_queue=worker.task_queue,
execution_timeout=timedelta(seconds=10),
)
with pytest.raises(WorkflowFailureError) as e:
await workflow_handle.result()

found = False
async for event in workflow_handle.fetch_history_events():
if event.HasField("activity_task_started_event_attributes"):
found = True
if should_retry:
assert event.activity_task_started_event_attributes.attempt == 2
else:
assert event.activity_task_started_event_attributes.attempt == 1
assert found


async def test_exception_handling(client: Client):
await assert_status_retry_behavior(408, client, should_retry=True)
await assert_status_retry_behavior(409, client, should_retry=True)
await assert_status_retry_behavior(429, client, should_retry=True)
await assert_status_retry_behavior(500, client, should_retry=True)

await assert_status_retry_behavior(400, client, should_retry=False)
await assert_status_retry_behavior(403, client, should_retry=False)
await assert_status_retry_behavior(404, client, should_retry=False)
5 changes: 3 additions & 2 deletions tests/contrib/openai_agents/test_openai_replay.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
import pytest

from temporalio.client import WorkflowHistory
from temporalio.contrib.openai_agents.temporal_openai_agents import (
from temporalio.contrib.openai_agents import ModelActivityParameters
from temporalio.contrib.openai_agents._temporal_openai_agents import (
set_open_ai_agent_temporal_overrides,
)
from temporalio.contrib.pydantic import pydantic_data_converter
Expand Down Expand Up @@ -35,7 +36,7 @@ async def test_replay(file_name: str) -> None:
with (Path(__file__).with_name("histories") / file_name).open("r") as f:
history_json = f.read()

with set_open_ai_agent_temporal_overrides():
with set_open_ai_agent_temporal_overrides(ModelActivityParameters()):
await Replayer(
workflows=[
ResearchWorkflow,
Expand Down
Loading