Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 9 additions & 15 deletions temporalio/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -514,7 +514,7 @@ async def start_workflow(
temporalio.common._warn_on_deprecated_search_attributes(
search_attributes, stack_level=stack_level
)
name, result_type_from_run_fn = (
name, result_type_from_type_hint = (
temporalio.workflow._Definition.get_name_and_result_type(workflow)
)

Expand All @@ -539,7 +539,7 @@ async def start_workflow(
static_details=static_details,
start_signal=start_signal,
start_signal_args=start_signal_args,
ret_type=result_type or result_type_from_run_fn,
ret_type=result_type or result_type_from_type_hint,
rpc_metadata=rpc_metadata,
rpc_timeout=rpc_timeout,
request_eager_start=request_eager_start,
Expand Down Expand Up @@ -1105,7 +1105,7 @@ def on_start_success(
start_workflow_operation._start_workflow_input.id,
first_execution_run_id=start_response.run_id,
result_run_id=start_response.run_id,
result_type=result_type,
result_type=start_workflow_operation._start_workflow_input.ret_type,
)
)

Expand Down Expand Up @@ -2335,17 +2335,10 @@ async def _start_update(
) -> WorkflowUpdateHandle[Any]:
if wait_for_stage == WorkflowUpdateStage.ADMITTED:
raise ValueError("ADMITTED wait stage not supported")
update_name: str
ret_type = result_type
if isinstance(update, temporalio.workflow.UpdateMethodMultiParam):
defn = update._defn
if not defn.name:
raise RuntimeError("Cannot invoke dynamic update definition")
# TODO(cretz): Check count/type of args at runtime?
update_name = defn.name
ret_type = defn.ret_type
else:
update_name = str(update)

update_name, result_type_from_type_hint = (
temporalio.workflow._UpdateDefinition.get_name_and_result_type(update)
)

return await self._client._impl.start_workflow_update(
StartWorkflowUpdateInput(
Expand All @@ -2356,7 +2349,7 @@ async def _start_update(
update=update_name,
args=temporalio.common._arg_or_args(arg, args),
headers={},
ret_type=ret_type,
ret_type=result_type or result_type_from_type_hint,
rpc_metadata=rpc_metadata,
rpc_timeout=rpc_timeout,
wait_for_stage=wait_for_stage,
Expand Down Expand Up @@ -6183,6 +6176,7 @@ async def _start_workflow_update_with_start(
workflow_id=start_input.id,
workflow_run_id=start_response.run_id,
known_outcome=known_outcome,
result_type=update_input.ret_type,
)
if update_input.wait_for_stage == WorkflowUpdateStage.COMPLETED:
await handle._poll_until_outcome()
Expand Down
15 changes: 15 additions & 0 deletions temporalio/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
import temporalio.common
import temporalio.converter
import temporalio.exceptions
import temporalio.workflow

from .types import (
AnyType,
Expand Down Expand Up @@ -1783,6 +1784,20 @@ def set_validator(self, validator: Callable[..., None]) -> None:
raise RuntimeError(f"Validator already set for update {self.name}")
object.__setattr__(self, "validator", validator)

@classmethod
def get_name_and_result_type(
cls,
name_or_update_fn: Union[str, Callable[..., Any]],
) -> Tuple[str, Optional[Type]]:
if isinstance(name_or_update_fn, temporalio.workflow.UpdateMethodMultiParam):
defn = name_or_update_fn._defn
if not defn.name:
raise RuntimeError("Cannot invoke dynamic update definition")
# TODO(cretz): Check count/type of args at runtime?
return defn.name, defn.ret_type
else:
return str(name_or_update_fn), None


# See https://mypy.readthedocs.io/en/latest/runtime_troubles.html#using-classes-that-are-generic-in-stubs-but-not-at-runtime
if TYPE_CHECKING:
Expand Down
50 changes: 50 additions & 0 deletions tests/worker/test_update_with_start.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import uuid
from contextlib import contextmanager
from dataclasses import dataclass
from datetime import timedelta
from enum import Enum
from typing import Any, Iterator
Expand Down Expand Up @@ -515,3 +516,52 @@ def test_with_start_workflow_operation_requires_conflict_policy():
id="wid-1",
task_queue="test-queue",
)


@dataclass
class DataClass1:
a: int
b: str


@dataclass
class DataClass2:
a: int
b: str


@workflow.defn
class WorkflowCanReturnDataClass:
def __init__(self) -> None:
self.received_update = False

@workflow.run
async def run(self) -> DataClass1:
await workflow.wait_condition(lambda: self.received_update)
return DataClass1(a=1, b="workflow-result")

@workflow.update
async def update(self) -> DataClass2:
self.received_update = True
return DataClass2(a=2, b="update-result")


async def test_workflow_and_update_can_return_dataclass(client: Client):
async with new_worker(client, WorkflowCanReturnDataClass) as worker:
start_op = WithStartWorkflowOperation(
WorkflowCanReturnDataClass.run,
id=f"workflow-{uuid.uuid4()}",
task_queue=worker.task_queue,
id_conflict_policy=WorkflowIDConflictPolicy.FAIL,
)

update_handle = await client.start_update_with_start_workflow(
WorkflowCanReturnDataClass.update,
wait_for_stage=WorkflowUpdateStage.COMPLETED,
start_workflow_operation=start_op,
)

assert await update_handle.result() == DataClass2(a=2, b="update-result")

wf_handle = await start_op.workflow_handle()
assert await wf_handle.result() == DataClass1(a=1, b="workflow-result")
Loading