Skip to content

Commit 0fd5c7e

Browse files
committed
Reraise workflow failure errors from OpenAI's UserError
1 parent 2affa25 commit 0fd5c7e

File tree

6 files changed

+105
-16
lines changed

6 files changed

+105
-16
lines changed

temporalio/contrib/openai_agents/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
"""
1010

1111
from temporalio.contrib.openai_agents._model_parameters import ModelActivityParameters
12+
from temporalio.contrib.openai_agents._openai_runner import AgentsWorkflowFailure
1213
from temporalio.contrib.openai_agents._temporal_openai_agents import (
1314
OpenAIAgentsPlugin,
1415
TestModel,
@@ -21,6 +22,7 @@
2122
from . import workflow
2223

2324
__all__ = [
25+
"AgentsWorkflowFailure",
2426
"OpenAIAgentsPlugin",
2527
"ModelActivityParameters",
2628
"workflow",

temporalio/contrib/openai_agents/_openai_runner.py

Lines changed: 33 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
from agents import (
77
Agent,
8+
AgentsException,
89
Handoff,
910
RunConfig,
1011
RunContextWrapper,
@@ -21,6 +22,16 @@
2122
from temporalio import workflow
2223
from temporalio.contrib.openai_agents._model_parameters import ModelActivityParameters
2324
from temporalio.contrib.openai_agents._temporal_model_stub import _TemporalModelStub
25+
from temporalio.exceptions import ApplicationError, TemporalError
26+
27+
28+
class AgentsWorkflowFailure(TemporalError):
29+
"""Error that occurs when the agents SDK raises an error which should terminate the calling workflow.
30+
31+
.. warning::
32+
This exception is experimental and may change in future versions.
33+
Use with caution in production environments.
34+
"""
2435

2536

2637
class TemporalOpenAIRunner(AgentRunner):
@@ -136,16 +147,28 @@ async def on_invoke(
136147
handoffs=new_handoffs,
137148
)
138149

139-
return await self._runner.run(
140-
starting_agent=convert_agent(starting_agent, None),
141-
input=input,
142-
context=context,
143-
max_turns=max_turns,
144-
hooks=hooks,
145-
run_config=run_config,
146-
previous_response_id=previous_response_id,
147-
session=session,
148-
)
150+
try:
151+
return await self._runner.run(
152+
starting_agent=convert_agent(starting_agent, None),
153+
input=input,
154+
context=context,
155+
max_turns=max_turns,
156+
hooks=hooks,
157+
run_config=run_config,
158+
previous_response_id=previous_response_id,
159+
session=session,
160+
)
161+
except AgentsException as e:
162+
# In order for workflow failures to properly fail the workflow, we need to rewrap them in
163+
# a Temporal error
164+
if e.__cause__ and workflow.is_workflow_failure_exception(e.__cause__):
165+
reraise = AgentsWorkflowFailure(
166+
f"Workflow failure exception in Agents Framework: {e}"
167+
)
168+
reraise.__traceback__ = e.__traceback__
169+
raise reraise from e.__cause__
170+
else:
171+
raise e
149172

150173
def run_sync(
151174
self,

temporalio/contrib/openai_agents/_temporal_openai_agents.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,10 @@
2727
from temporalio.client import ClientConfig, Plugin
2828
from temporalio.contrib.openai_agents._invoke_model_activity import ModelActivity
2929
from temporalio.contrib.openai_agents._model_parameters import ModelActivityParameters
30-
from temporalio.contrib.openai_agents._openai_runner import TemporalOpenAIRunner
30+
from temporalio.contrib.openai_agents._openai_runner import (
31+
AgentsWorkflowFailure,
32+
TemporalOpenAIRunner,
33+
)
3134
from temporalio.contrib.openai_agents._temporal_trace_provider import (
3235
TemporalTraceProvider,
3336
)
@@ -284,6 +287,9 @@ def configure_worker(self, config: WorkerConfig) -> WorkerConfig:
284287
config["activities"] = list(config.get("activities") or []) + [
285288
ModelActivity(self._model_provider).invoke_model_activity
286289
]
290+
config["workflow_failure_exception_types"] = list(
291+
config.get("workflow_failure_exception_types") or []
292+
) + [AgentsWorkflowFailure]
287293
return self.next_worker_plugin.configure_worker(config)
288294

289295
async def run_worker(self, worker: Worker) -> None:

temporalio/worker/_workflow_instance.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -414,7 +414,7 @@ def activate(
414414
# We want some errors during activation, like those that can happen
415415
# during payload conversion, to be able to fail the workflow not the
416416
# task
417-
if self._is_workflow_failure_exception(err):
417+
if self.is_workflow_failure_exception(err):
418418
try:
419419
self._set_workflow_failure(err)
420420
except Exception as inner_err:
@@ -629,7 +629,7 @@ async def run_update() -> None:
629629
# Validation failures are always update failures. We reuse
630630
# workflow failure logic to decide task failure vs update
631631
# failure after validation.
632-
if not past_validation or self._is_workflow_failure_exception(err):
632+
if not past_validation or self.is_workflow_failure_exception(err):
633633
if command is None:
634634
command = self._add_command()
635635
command.update_response.protocol_instance_id = (
@@ -1939,7 +1939,7 @@ def _convert_payloads(
19391939
# Don't wrap payload conversion errors that would fail the workflow
19401940
raise
19411941
except Exception as err:
1942-
if self._is_workflow_failure_exception(err):
1942+
if self.is_workflow_failure_exception(err):
19431943
raise
19441944
raise RuntimeError("Failed decoding arguments") from err
19451945

@@ -1982,7 +1982,7 @@ def _instantiate_workflow_object(self) -> Any:
19821982

19831983
return workflow_instance
19841984

1985-
def _is_workflow_failure_exception(self, err: BaseException) -> bool:
1985+
def is_workflow_failure_exception(self, err: BaseException) -> bool:
19861986
# An exception is a failure instead of a task fail if it's already a
19871987
# failure error or if it is a timeout error or if it is an instance of
19881988
# any of the failure types in the worker or workflow-level setting
@@ -2192,7 +2192,7 @@ async def _run_top_level_workflow_function(self, coro: Awaitable[None]) -> None:
21922192
err
21932193
):
21942194
self._add_command().cancel_workflow_execution.SetInParent()
2195-
elif self._is_workflow_failure_exception(err):
2195+
elif self.is_workflow_failure_exception(err):
21962196
# All other failure errors fail the workflow
21972197
self._set_workflow_failure(err)
21982198
else:

temporalio/workflow.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -897,6 +897,9 @@ def workflow_get_current_details(self) -> str: ...
897897
@abstractmethod
898898
def workflow_set_current_details(self, details: str): ...
899899

900+
@abstractmethod
901+
def is_workflow_failure_exception(self, err: BaseException) -> bool: ...
902+
900903

901904
_current_update_info: contextvars.ContextVar[UpdateInfo] = contextvars.ContextVar(
902905
"__temporal_current_update_info"
@@ -981,6 +984,10 @@ def memo() -> Mapping[str, Any]:
981984
return _Runtime.current().workflow_memo()
982985

983986

987+
def is_workflow_failure_exception(err: BaseException) -> bool:
988+
return _Runtime.current().is_workflow_failure_exception(err)
989+
990+
984991
@overload
985992
def memo_value(key: str, default: Any = temporalio.common._arg_unset) -> Any: ...
986993

tests/contrib/openai_agents/test_openai.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -318,6 +318,10 @@ async def run(self, question: str) -> str:
318318
ActivityWeatherService.get_weather_method,
319319
start_to_close_timeout=timedelta(seconds=10),
320320
),
321+
openai_agents.workflow.activity_as_tool(
322+
get_weather_failure,
323+
start_to_close_timeout=timedelta(seconds=10),
324+
),
321325
],
322326
)
323327
result = await Runner.run(
@@ -462,6 +466,53 @@ async def test_tool_workflow(client: Client, use_local_model: bool):
462466
)
463467

464468

469+
@activity.defn
470+
async def get_weather_failure(city: str) -> Weather:
471+
"""
472+
Get the weather for a given city.
473+
"""
474+
raise ApplicationError("No weather", non_retryable=True)
475+
476+
477+
class TestWeatherFailureModel(StaticTestModel):
478+
responses = [
479+
ResponseBuilders.tool_call('{"city":"Tokyo"}', "get_weather_failure"),
480+
]
481+
482+
483+
async def test_tool_failure_workflow(client: Client):
484+
new_config = client.config()
485+
new_config["plugins"] = [
486+
openai_agents.OpenAIAgentsPlugin(
487+
model_params=ModelActivityParameters(
488+
start_to_close_timeout=timedelta(seconds=30)
489+
),
490+
model_provider=TestModelProvider(TestWeatherFailureModel()),
491+
)
492+
]
493+
client = Client(**new_config)
494+
495+
async with new_worker(
496+
client,
497+
ToolsWorkflow,
498+
activities=[
499+
get_weather_failure,
500+
],
501+
) as worker:
502+
workflow_handle = await client.start_workflow(
503+
ToolsWorkflow.run,
504+
"What is the weather in Tokio?",
505+
id=f"tools-failure-workflow-{uuid.uuid4()}",
506+
task_queue=worker.task_queue,
507+
execution_timeout=timedelta(seconds=2),
508+
)
509+
with pytest.raises(WorkflowFailureError) as e:
510+
result = await workflow_handle.result()
511+
cause = e.value.cause
512+
assert isinstance(cause, ApplicationError)
513+
assert "Workflow failure exception in Agents Framework" in cause.message
514+
515+
465516
@pytest.mark.parametrize("use_local_model", [True, False])
466517
async def test_nexus_tool_workflow(
467518
client: Client, env: WorkflowEnvironment, use_local_model: bool

0 commit comments

Comments
 (0)