Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions temporalio/contrib/openai_agents/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,21 +10,20 @@

from temporalio.contrib.openai_agents._invoke_model_activity import ModelActivity
from temporalio.contrib.openai_agents._model_parameters import ModelActivityParameters
from temporalio.contrib.openai_agents._trace_interceptor import (
OpenAIAgentsTracingInterceptor,
)
from temporalio.contrib.openai_agents.temporal_openai_agents import (
from temporalio.contrib.openai_agents._temporal_openai_agents import (
TestModel,
TestModelProvider,
set_open_ai_agent_temporal_overrides,
)
from temporalio.contrib.openai_agents._trace_interceptor import (
OpenAIAgentsTracingInterceptor,
)

from . import workflow

__all__ = [
"ModelActivity",
"ModelActivityParameters",
"workflow",
"set_open_ai_agent_temporal_overrides",
"OpenAIAgentsTracingInterceptor",
"TestModel",
Expand Down
59 changes: 46 additions & 13 deletions temporalio/contrib/openai_agents/_invoke_model_activity.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,25 @@
ModelResponse,
ModelSettings,
ModelTracing,
OpenAIProvider,
RunContextWrapper,
Tool,
TResponseInputItem,
UserError,
WebSearchTool,
)
from agents.models.multi_provider import MultiProvider
from openai import (
APIStatusError,
AsyncOpenAI,
AuthenticationError,
PermissionDeniedError,
)
from typing_extensions import Required, TypedDict

from temporalio import activity
from temporalio.contrib.openai_agents._heartbeat_decorator import _auto_heartbeater
from temporalio.exceptions import ApplicationError


@dataclass
Expand Down Expand Up @@ -117,11 +125,15 @@ class ActivityModelInput(TypedDict, total=False):


class ModelActivity:
"""Class wrapper for model invocation activities to allow model customization."""
"""Class wrapper for model invocation activities to allow model customization. By default, we use an OpenAIProvider with retries disabled.
Disabling retries in your model of choice is recommended to allow activity retries to define the retry model.
"""

def __init__(self, model_provider: Optional[ModelProvider] = None):
"""Initialize the activity with a model provider."""
self._model_provider = model_provider or MultiProvider()
self._model_provider = model_provider or OpenAIProvider(
openai_client=AsyncOpenAI(max_retries=0)
)

@activity.defn
@_auto_heartbeater
Expand Down Expand Up @@ -171,14 +183,35 @@ def make_tool(tool: ToolInput) -> Tool:
)
for x in input.get("handoffs", [])
]
return await model.get_response(
system_instructions=input.get("system_instructions"),
input=input_input,
model_settings=input["model_settings"],
tools=tools,
output_schema=input.get("output_schema"),
handoffs=handoffs,
tracing=ModelTracing(input["tracing"]),
previous_response_id=input.get("previous_response_id"),
prompt=input.get("prompt"),
)

try:
return await model.get_response(
system_instructions=input.get("system_instructions"),
input=input_input,
model_settings=input["model_settings"],
tools=tools,
output_schema=input.get("output_schema"),
handoffs=handoffs,
tracing=ModelTracing(input["tracing"]),
previous_response_id=input.get("previous_response_id"),
prompt=input.get("prompt"),
)
except APIStatusError as e:
# Listen to server hint
should_retry_header = e.response.headers.get("x-should-retry")
if should_retry_header == "true":
raise e
if should_retry_header == "false":
raise ApplicationError(
"Non retryable openai error", non_retryable=True
) from e

# Specifically retryable status codes
if e.response.status_code in [408, 409, 429, 500]:
raise ApplicationError(
"Retryable openai status code", non_retryable=False
) from e

raise ApplicationError(
"Non retryable openai status code", non_retryable=True
) from e
4 changes: 2 additions & 2 deletions temporalio/contrib/openai_agents/_model_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,13 @@ class ModelActivityParameters:
task_queue: Optional[str] = None
"""Specific task queue to use for model activities."""

schedule_to_close_timeout: Optional[timedelta] = timedelta(seconds=60)
schedule_to_close_timeout: Optional[timedelta] = None
"""Maximum time from scheduling to completion."""

schedule_to_start_timeout: Optional[timedelta] = None
"""Maximum time from scheduling to starting."""

start_to_close_timeout: Optional[timedelta] = None
start_to_close_timeout: Optional[timedelta] = timedelta(seconds=60)
"""Maximum time for the activity to complete."""

heartbeat_timeout: Optional[timedelta] = None
Expand Down
67 changes: 65 additions & 2 deletions tests/contrib/openai_agents/test_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
ToolCallItem,
ToolCallOutputItem,
)
from openai import AsyncOpenAI, BaseModel
from openai import APIStatusError, AsyncOpenAI, BaseModel
from openai.types.responses import (
ResponseFunctionToolCall,
ResponseFunctionWebSearch,
Expand All @@ -47,8 +47,10 @@
from openai.types.responses.response_prompt_param import ResponsePromptParam
from pydantic import ConfigDict, Field

import temporalio.api.cloud.namespace.v1
from temporalio import activity, workflow
from temporalio.client import Client, WorkflowFailureError, WorkflowHandle
from temporalio.common import RetryPolicy, SearchAttributeValueType
from temporalio.contrib import openai_agents
from temporalio.contrib.openai_agents import (
ModelActivity,
Expand All @@ -59,7 +61,7 @@
set_open_ai_agent_temporal_overrides,
)
from temporalio.contrib.pydantic import pydantic_data_converter
from temporalio.exceptions import CancelledError
from temporalio.exceptions import ApplicationError, CancelledError
from temporalio.testing import WorkflowEnvironment
from tests.contrib.openai_agents.research_agents.research_manager import (
ResearchManager,
Expand Down Expand Up @@ -1778,3 +1780,64 @@ async def test_workflow_method_tools(client: Client):
execution_timeout=timedelta(seconds=10),
)
await workflow_handle.result()


async def assert_status_retry_behavior(status: int, client: Client, should_retry: bool):
with workflow.unsafe.sandbox_unrestricted():
import httpx

def status_error(status: int):
raise APIStatusError(
message="Something went wrong.",
response=httpx.Response(
status_code=status, request=httpx.Request("GET", url="")
),
body=None,
)

# Test error 500 retries
model_activity = ModelActivity(
TestModelProvider(TestModel(lambda: status_error(status)))
)
async with new_worker(
client,
HelloWorldAgent,
activities=[model_activity.invoke_model_activity],
interceptors=[OpenAIAgentsTracingInterceptor()],
) as worker:
workflow_handle = await client.start_workflow(
HelloWorldAgent.run,
"Input",
id=f"workflow-tool-{uuid.uuid4()}",
task_queue=worker.task_queue,
execution_timeout=timedelta(seconds=10),
)
with pytest.raises(WorkflowFailureError) as e:
await workflow_handle.result()

async for event in workflow_handle.fetch_history_events():
if event.HasField("activity_task_started_event_attributes"):
if should_retry:
assert event.activity_task_started_event_attributes.attempt == 2
else:
assert event.activity_task_started_event_attributes.attempt == 1


async def test_exception_handling(client: Client):
new_config = client.config()
new_config["data_converter"] = pydantic_data_converter
client = Client(**new_config)

with set_open_ai_agent_temporal_overrides(
model_params=ModelActivityParameters(
retry_policy=RetryPolicy(maximum_attempts=2)
)
):
await assert_status_retry_behavior(408, client, should_retry=True)
await assert_status_retry_behavior(409, client, should_retry=True)
await assert_status_retry_behavior(429, client, should_retry=True)
await assert_status_retry_behavior(500, client, should_retry=True)

await assert_status_retry_behavior(400, client, should_retry=False)
await assert_status_retry_behavior(403, client, should_retry=False)
await assert_status_retry_behavior(404, client, should_retry=False)
2 changes: 1 addition & 1 deletion tests/contrib/openai_agents/test_openai_replay.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import pytest

from temporalio.client import WorkflowHistory
from temporalio.contrib.openai_agents.temporal_openai_agents import (
from temporalio.contrib.openai_agents._temporal_openai_agents import (
set_open_ai_agent_temporal_overrides,
)
from temporalio.contrib.pydantic import pydantic_data_converter
Expand Down