diff --git a/temporalio/worker/_workflow.py b/temporalio/worker/_workflow.py index c7b89206c..1e178f015 100644 --- a/temporalio/worker/_workflow.py +++ b/temporalio/worker/_workflow.py @@ -552,6 +552,10 @@ def _create_workflow_instance( priority=temporalio.common.Priority._from_proto(init.priority), ) + last_failure = ( + init.continued_failure if init.HasField("continued_failure") else None + ) + # Create instance from details det = WorkflowInstanceDetails( payload_converter_class=self._data_converter.payload_converter_class, @@ -563,6 +567,8 @@ def _create_workflow_instance( extern_functions=self._extern_functions, disable_eager_activity_execution=self._disable_eager_activity_execution, worker_level_failure_exception_types=self._workflow_failure_exception_types, + last_completion_result=init.last_completion_result, + last_failure=last_failure, ) if defn.sandboxed: return self._workflow_runner.create_instance(det) diff --git a/temporalio/worker/_workflow_instance.py b/temporalio/worker/_workflow_instance.py index f0984cc84..399531af5 100644 --- a/temporalio/worker/_workflow_instance.py +++ b/temporalio/worker/_workflow_instance.py @@ -64,6 +64,7 @@ import temporalio.workflow from temporalio.service import __version__ +from ..api.failure.v1.message_pb2 import Failure from ._interceptor import ( ContinueAsNewInput, ExecuteWorkflowInput, @@ -143,6 +144,8 @@ class WorkflowInstanceDetails: extern_functions: Mapping[str, Callable] disable_eager_activity_execution: bool worker_level_failure_exception_types: Sequence[Type[BaseException]] + last_completion_result: temporalio.api.common.v1.Payloads + last_failure: Optional[Failure] class WorkflowInstance(ABC): @@ -320,6 +323,9 @@ def __init__(self, det: WorkflowInstanceDetails) -> None: # metadata query self._current_details = "" + self._last_completion_result = det.last_completion_result + self._last_failure = det.last_failure + # The versioning behavior of this workflow, as established by annotation or by the dynamic # config function. Is only set once upon initialization. self._versioning_behavior: Optional[temporalio.common.VersioningBehavior] = None @@ -1703,6 +1709,37 @@ def workflow_is_failure_exception(self, err: BaseException) -> bool: ) ) + def workflow_has_last_completion_result(self) -> bool: + return len(self._last_completion_result.payloads) > 0 + + def workflow_last_completion_result( + self, type_hint: Optional[Type] + ) -> Optional[Any]: + if len(self._last_completion_result.payloads) == 0: + return None + elif len(self._last_completion_result.payloads) > 1: + warnings.warn( + f"Expected single last completion result, got {len(self._last_completion_result.payloads)}" + ) + return None + + if type_hint is None: + return self._payload_converter.from_payload( + self._last_completion_result.payloads[0] + ) + else: + return self._payload_converter.from_payload( + self._last_completion_result.payloads[0], type_hint + ) + + def workflow_last_failure(self) -> Optional[BaseException]: + if self._last_failure: + return self._failure_converter.from_failure( + self._last_failure, self._payload_converter + ) + + return None + #### Calls from outbound impl #### # These are in alphabetical order and all start with "_outbound_". @@ -2766,6 +2803,7 @@ def _apply_schedule_command( v.start_to_close_timeout.FromTimedelta(self._input.start_to_close_timeout) if self._input.retry_policy: self._input.retry_policy.apply_to_proto(v.retry_policy) + v.cancellation_type = cast( temporalio.bridge.proto.workflow_commands.ActivityCancellationType.ValueType, int(self._input.cancellation_type), diff --git a/temporalio/worker/workflow_sandbox/_runner.py b/temporalio/worker/workflow_sandbox/_runner.py index ba1a7f3ce..c656e3041 100644 --- a/temporalio/worker/workflow_sandbox/_runner.py +++ b/temporalio/worker/workflow_sandbox/_runner.py @@ -18,6 +18,9 @@ import temporalio.worker._workflow_instance import temporalio.workflow +from ...api.common.v1.message_pb2 import Payloads +from ...api.failure.v1.message_pb2 import Failure + # Workflow instance has to be relative import from .._workflow_instance import ( UnsandboxedWorkflowRunner, @@ -84,6 +87,8 @@ def prepare_workflow(self, defn: temporalio.workflow._Definition) -> None: extern_functions={}, disable_eager_activity_execution=False, worker_level_failure_exception_types=self._worker_level_failure_exception_types, + last_completion_result=Payloads(), + last_failure=Failure(), ), ) diff --git a/temporalio/workflow.py b/temporalio/workflow.py index 50118a2bb..4e374e0a9 100644 --- a/temporalio/workflow.py +++ b/temporalio/workflow.py @@ -61,6 +61,7 @@ import temporalio.workflow from temporalio.nexus._util import ServiceHandlerT +from .api.failure.v1.message_pb2 import Failure from .types import ( AnyType, CallableAsyncNoParam, @@ -900,6 +901,17 @@ def workflow_set_current_details(self, details: str): ... @abstractmethod def workflow_is_failure_exception(self, err: BaseException) -> bool: ... + @abstractmethod + def workflow_has_last_completion_result(self) -> bool: ... + + @abstractmethod + def workflow_last_completion_result( + self, type_hint: Optional[Type] + ) -> Optional[Any]: ... + + @abstractmethod + def workflow_last_failure(self) -> Optional[BaseException]: ... + _current_update_info: contextvars.ContextVar[UpdateInfo] = contextvars.ContextVar( "__temporal_current_update_info" @@ -1051,6 +1063,32 @@ def get_current_details() -> str: return _Runtime.current().workflow_get_current_details() +def has_last_completion_result() -> bool: + """Gets whether there is a last completion result of the workflow.""" + return _Runtime.current().workflow_has_last_completion_result() + + +@overload +def get_last_completion_result() -> Optional[Any]: ... + + +@overload +def get_last_completion_result(type_hint: Type[ParamType]) -> Optional[ParamType]: ... + + +def get_last_completion_result(type_hint: Optional[Type] = None) -> Optional[Any]: + """Get the result of the last run of the workflow. This will be None if there was + no previous completion or the result was None. has_last_completion_result() + can be used to differentiate. + """ + return _Runtime.current().workflow_last_completion_result(type_hint) + + +def get_last_failure() -> Optional[BaseException]: + """Get the last failure of the workflow if it has run previously.""" + return _Runtime.current().workflow_last_failure() + + def set_current_details(description: str) -> None: """Set the current details of the workflow which may appear in the UI/CLI. Unlike static details set at start, this value can be updated throughout diff --git a/tests/test_client.py b/tests/test_client.py index 9c33e9e1c..5671bc118 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -1,9 +1,10 @@ +import asyncio import dataclasses import json import os import uuid from datetime import datetime, timedelta, timezone -from typing import Any, List, Mapping, Optional, cast +from typing import Any, List, Mapping, Optional, Tuple, cast from unittest import mock import google.protobuf.any_pb2 @@ -91,6 +92,7 @@ from temporalio.testing import WorkflowEnvironment from tests.helpers import ( assert_eq_eventually, + assert_eventually, ensure_search_attributes_present, new_worker, worker_versioning_enabled, @@ -1501,3 +1503,58 @@ async def test_cloud_client_simple(): GetNamespaceRequest(namespace=os.environ["TEMPORAL_CLIENT_CLOUD_NAMESPACE"]) ) assert os.environ["TEMPORAL_CLIENT_CLOUD_NAMESPACE"] == result.namespace.namespace + + +@workflow.defn +class LastCompletionResultWorkflow: + @workflow.run + async def run(self) -> str: + last_result = workflow.get_last_completion_result(type_hint=str) + if last_result is not None: + return "From last completion: " + last_result + else: + return "My First Result" + + +async def test_schedule_last_completion_result( + client: Client, env: WorkflowEnvironment +): + if env.supports_time_skipping: + pytest.skip("Java test server doesn't support schedules") + + async with new_worker(client, LastCompletionResultWorkflow) as worker: + handle = await client.create_schedule( + f"schedule-{uuid.uuid4()}", + Schedule( + action=ScheduleActionStartWorkflow( + "LastCompletionResultWorkflow", + id=f"workflow-{uuid.uuid4()}", + task_queue=worker.task_queue, + ), + spec=ScheduleSpec(), + ), + ) + await handle.trigger() + + async def get_schedule_result() -> Tuple[int, Optional[str]]: + desc = await handle.describe() + length = len(desc.info.recent_actions) + if length == 0: + return length, None + else: + workflow_id = cast( + ScheduleActionExecutionStartWorkflow, + desc.info.recent_actions[-1].action, + ).workflow_id + workflow_handle = client.get_workflow_handle(workflow_id) + result = await workflow_handle.result() + return length, result + + assert await get_schedule_result() == (1, "My First Result") + await handle.trigger() + assert await get_schedule_result() == ( + 2, + "From last completion: My First Result", + ) + + await handle.delete() diff --git a/tests/worker/test_workflow.py b/tests/worker/test_workflow.py index e97bf3e02..87729528c 100644 --- a/tests/worker/test_workflow.py +++ b/tests/worker/test_workflow.py @@ -8327,3 +8327,29 @@ async def test_workflow_headers_with_codec( assert headers["foo"].data == b"bar" else: assert headers["foo"].data != b"bar" + + +@workflow.defn +class PreviousRunFailureWorkflow: + @workflow.run + async def run(self) -> str: + if workflow.info().attempt != 1: + previous_failure = workflow.get_last_failure() + assert isinstance(previous_failure, ApplicationError) + assert previous_failure.message == "Intentional Failure" + return "Done" + raise ApplicationError("Intentional Failure") + + +async def test_previous_run_failure(client: Client): + async with new_worker(client, PreviousRunFailureWorkflow) as worker: + handle = await client.start_workflow( + PreviousRunFailureWorkflow.run, + id=f"previous-run-failure-workflow-{uuid.uuid4()}", + task_queue=worker.task_queue, + retry_policy=RetryPolicy( + initial_interval=timedelta(milliseconds=10), + ), + ) + result = await handle.result() + assert result == "Done"