diff --git a/temporalio/client.py b/temporalio/client.py index 1e5a41464..7c517a3de 100644 --- a/temporalio/client.py +++ b/temporalio/client.py @@ -514,7 +514,7 @@ async def start_workflow( temporalio.common._warn_on_deprecated_search_attributes( search_attributes, stack_level=stack_level ) - name, result_type_from_run_fn = ( + name, result_type_from_type_hint = ( temporalio.workflow._Definition.get_name_and_result_type(workflow) ) @@ -539,7 +539,7 @@ async def start_workflow( static_details=static_details, start_signal=start_signal, start_signal_args=start_signal_args, - ret_type=result_type or result_type_from_run_fn, + ret_type=result_type or result_type_from_type_hint, rpc_metadata=rpc_metadata, rpc_timeout=rpc_timeout, request_eager_start=request_eager_start, @@ -1105,7 +1105,7 @@ def on_start_success( start_workflow_operation._start_workflow_input.id, first_execution_run_id=start_response.run_id, result_run_id=start_response.run_id, - result_type=result_type, + result_type=start_workflow_operation._start_workflow_input.ret_type, ) ) @@ -2335,17 +2335,10 @@ async def _start_update( ) -> WorkflowUpdateHandle[Any]: if wait_for_stage == WorkflowUpdateStage.ADMITTED: raise ValueError("ADMITTED wait stage not supported") - update_name: str - ret_type = result_type - if isinstance(update, temporalio.workflow.UpdateMethodMultiParam): - defn = update._defn - if not defn.name: - raise RuntimeError("Cannot invoke dynamic update definition") - # TODO(cretz): Check count/type of args at runtime? - update_name = defn.name - ret_type = defn.ret_type - else: - update_name = str(update) + + update_name, result_type_from_type_hint = ( + temporalio.workflow._UpdateDefinition.get_name_and_result_type(update) + ) return await self._client._impl.start_workflow_update( StartWorkflowUpdateInput( @@ -2356,7 +2349,7 @@ async def _start_update( update=update_name, args=temporalio.common._arg_or_args(arg, args), headers={}, - ret_type=ret_type, + ret_type=result_type or result_type_from_type_hint, rpc_metadata=rpc_metadata, rpc_timeout=rpc_timeout, wait_for_stage=wait_for_stage, @@ -6183,6 +6176,7 @@ async def _start_workflow_update_with_start( workflow_id=start_input.id, workflow_run_id=start_response.run_id, known_outcome=known_outcome, + result_type=update_input.ret_type, ) if update_input.wait_for_stage == WorkflowUpdateStage.COMPLETED: await handle._poll_until_outcome() diff --git a/temporalio/workflow.py b/temporalio/workflow.py index e78a9feea..e2d59b1a5 100644 --- a/temporalio/workflow.py +++ b/temporalio/workflow.py @@ -53,6 +53,7 @@ import temporalio.common import temporalio.converter import temporalio.exceptions +import temporalio.workflow from .types import ( AnyType, @@ -1783,6 +1784,20 @@ def set_validator(self, validator: Callable[..., None]) -> None: raise RuntimeError(f"Validator already set for update {self.name}") object.__setattr__(self, "validator", validator) + @classmethod + def get_name_and_result_type( + cls, + name_or_update_fn: Union[str, Callable[..., Any]], + ) -> Tuple[str, Optional[Type]]: + if isinstance(name_or_update_fn, temporalio.workflow.UpdateMethodMultiParam): + defn = name_or_update_fn._defn + if not defn.name: + raise RuntimeError("Cannot invoke dynamic update definition") + # TODO(cretz): Check count/type of args at runtime? + return defn.name, defn.ret_type + else: + return str(name_or_update_fn), None + # See https://mypy.readthedocs.io/en/latest/runtime_troubles.html#using-classes-that-are-generic-in-stubs-but-not-at-runtime if TYPE_CHECKING: diff --git a/tests/worker/test_update_with_start.py b/tests/worker/test_update_with_start.py index 160590e33..cf17f1239 100644 --- a/tests/worker/test_update_with_start.py +++ b/tests/worker/test_update_with_start.py @@ -2,6 +2,7 @@ import uuid from contextlib import contextmanager +from dataclasses import dataclass from datetime import timedelta from enum import Enum from typing import Any, Iterator @@ -515,3 +516,52 @@ def test_with_start_workflow_operation_requires_conflict_policy(): id="wid-1", task_queue="test-queue", ) + + +@dataclass +class DataClass1: + a: int + b: str + + +@dataclass +class DataClass2: + a: int + b: str + + +@workflow.defn +class WorkflowCanReturnDataClass: + def __init__(self) -> None: + self.received_update = False + + @workflow.run + async def run(self) -> DataClass1: + await workflow.wait_condition(lambda: self.received_update) + return DataClass1(a=1, b="workflow-result") + + @workflow.update + async def update(self) -> DataClass2: + self.received_update = True + return DataClass2(a=2, b="update-result") + + +async def test_workflow_and_update_can_return_dataclass(client: Client): + async with new_worker(client, WorkflowCanReturnDataClass) as worker: + start_op = WithStartWorkflowOperation( + WorkflowCanReturnDataClass.run, + id=f"workflow-{uuid.uuid4()}", + task_queue=worker.task_queue, + id_conflict_policy=WorkflowIDConflictPolicy.FAIL, + ) + + update_handle = await client.start_update_with_start_workflow( + WorkflowCanReturnDataClass.update, + wait_for_stage=WorkflowUpdateStage.COMPLETED, + start_workflow_operation=start_op, + ) + + assert await update_handle.result() == DataClass2(a=2, b="update-result") + + wf_handle = await start_op.workflow_handle() + assert await wf_handle.result() == DataClass1(a=1, b="workflow-result")