diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 6289dbcd0..44e3741d0 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -148,6 +148,8 @@ jobs: with: submodules: recursive - uses: dtolnay/rust-toolchain@stable + with: + components: "clippy" - uses: Swatinem/rust-cache@v2 with: workspaces: temporalio/bridge -> target diff --git a/scripts/gen_payload_visitor.py b/scripts/gen_payload_visitor.py index 790169afd..e8ddf38bd 100644 --- a/scripts/gen_payload_visitor.py +++ b/scripts/gen_payload_visitor.py @@ -1,7 +1,7 @@ import subprocess import sys from pathlib import Path -from typing import Optional, Tuple +from typing import Optional from google.protobuf.descriptor import Descriptor, FieldDescriptor @@ -89,6 +89,7 @@ def generate(self, roots: list[Descriptor]) -> str: from temporalio.api.common.v1.message_pb2 import Payload + class VisitorFunctions(abc.ABC): \"\"\"Set of functions which can be called by the visitor. Allows handling payloads as a sequence. diff --git a/temporalio/activity.py b/temporalio/activity.py index 72ee81ac0..d726b9ef2 100644 --- a/temporalio/activity.py +++ b/temporalio/activity.py @@ -470,6 +470,7 @@ class _CompleteAsyncError(BaseException): def payload_converter() -> temporalio.converter.PayloadConverter: """Get the payload converter for the current activity. + The returned converter has :py:class:`temporalio.converter.ActivitySerializationContext` set. This is often used for dynamic activities to convert payloads. """ return _Context.current().payload_converter diff --git a/temporalio/bridge/worker.py b/temporalio/bridge/worker.py index 6fff9878c..8e20b670a 100644 --- a/temporalio/bridge/worker.py +++ b/temporalio/bridge/worker.py @@ -7,11 +7,9 @@ from dataclasses import dataclass from typing import ( - TYPE_CHECKING, Awaitable, Callable, List, - Mapping, MutableSequence, Optional, Sequence, @@ -20,7 +18,6 @@ Union, ) -import google.protobuf.internal.containers from typing_extensions import TypeAlias import temporalio.api.common.v1 @@ -35,12 +32,13 @@ import temporalio.bridge.temporal_sdk_bridge import temporalio.converter import temporalio.exceptions -from temporalio.api.common.v1.message_pb2 import Payload, Payloads -from temporalio.bridge._visitor import PayloadVisitor, VisitorFunctions +from temporalio.api.common.v1.message_pb2 import Payload +from temporalio.bridge._visitor import VisitorFunctions from temporalio.bridge.temporal_sdk_bridge import ( CustomSlotSupplier as BridgeCustomSlotSupplier, ) from temporalio.bridge.temporal_sdk_bridge import PollShutdownError # type: ignore +from temporalio.worker._command_aware_visitor import CommandAwarePayloadVisitor @dataclass @@ -299,22 +297,22 @@ async def visit_payloads(self, payloads: MutableSequence[Payload]) -> None: async def decode_activation( - act: temporalio.bridge.proto.workflow_activation.WorkflowActivation, + activation: temporalio.bridge.proto.workflow_activation.WorkflowActivation, codec: temporalio.converter.PayloadCodec, decode_headers: bool, ) -> None: - """Decode the given activation with the codec.""" - await PayloadVisitor( + """Decode all payloads in the activation.""" + await CommandAwarePayloadVisitor( skip_search_attributes=True, skip_headers=not decode_headers - ).visit(_Visitor(codec.decode), act) + ).visit(_Visitor(codec.decode), activation) async def encode_completion( - comp: temporalio.bridge.proto.workflow_completion.WorkflowActivationCompletion, + completion: temporalio.bridge.proto.workflow_completion.WorkflowActivationCompletion, codec: 
temporalio.converter.PayloadCodec, encode_headers: bool, ) -> None: - """Recursively encode the given completion with the codec.""" - await PayloadVisitor( + """Encode all payloads in the completion.""" + await CommandAwarePayloadVisitor( skip_search_attributes=True, skip_headers=not encode_headers - ).visit(_Visitor(codec.encode), comp) + ).visit(_Visitor(codec.encode), completion) diff --git a/temporalio/client.py b/temporalio/client.py index f9735cfb2..20a9b3c6d 100644 --- a/temporalio/client.py +++ b/temporalio/client.py @@ -6,6 +6,7 @@ import asyncio import copy import dataclasses +import functools import inspect import json import re @@ -40,7 +41,7 @@ import google.protobuf.json_format import google.protobuf.timestamp_pb2 from google.protobuf.internal.containers import MessageMap -from typing_extensions import Concatenate, Required, TypedDict +from typing_extensions import Concatenate, Required, Self, TypedDict import temporalio.api.common.v1 import temporalio.api.enums.v1 @@ -62,6 +63,12 @@ import temporalio.service import temporalio.workflow from temporalio.activity import ActivityCancellationDetails +from temporalio.converter import ( + DataConverter, + SerializationContext, + WithSerializationContext, + WorkflowSerializationContext, +) from temporalio.service import ( HttpConnectProxyConfig, KeepAliveConfig, @@ -1600,6 +1607,14 @@ def __init__( self._start_workflow_response = start_workflow_response self.__temporal_eagerly_started = False + @functools.cached_property + def _data_converter(self) -> temporalio.converter.DataConverter: + return self._client.data_converter.with_context( + temporalio.converter.WorkflowSerializationContext( + namespace=self._client.namespace, workflow_id=self._id + ) + ) + @property def id(self) -> str: """ID for the workflow.""" @@ -1701,7 +1716,7 @@ async def result( break # Ignoring anything after the first response like TypeScript type_hints = [self._result_type] if self._result_type else None - results = await self._client.data_converter.decode_wrapper( + results = await self._data_converter.decode_wrapper( complete_attr.result, type_hints, ) @@ -1717,7 +1732,7 @@ async def result( hist_run_id = fail_attr.new_execution_run_id break raise WorkflowFailureError( - cause=await self._client.data_converter.decode_failure( + cause=await self._data_converter.decode_failure( fail_attr.failure ), ) @@ -1727,7 +1742,7 @@ async def result( cause=temporalio.exceptions.CancelledError( "Workflow cancelled", *( - await self._client.data_converter.decode_wrapper( + await self._data_converter.decode_wrapper( cancel_attr.details ) ), @@ -1739,7 +1754,7 @@ async def result( cause=temporalio.exceptions.TerminatedError( term_attr.reason or "Workflow terminated", *( - await self._client.data_converter.decode_wrapper( + await self._data_converter.decode_wrapper( term_attr.details ) ), @@ -2722,15 +2737,19 @@ class AsyncActivityIDReference: activity_id: str -class AsyncActivityHandle: +class AsyncActivityHandle(WithSerializationContext): """Handle representing an external activity for completion and heartbeat.""" def __init__( - self, client: Client, id_or_token: Union[AsyncActivityIDReference, bytes] + self, + client: Client, + id_or_token: Union[AsyncActivityIDReference, bytes], + data_converter_override: Optional[DataConverter] = None, ) -> None: """Create an async activity handle.""" self._client = client self._id_or_token = id_or_token + self._data_converter_override = data_converter_override async def heartbeat( self, @@ -2752,6 +2771,7 @@ async def heartbeat( 
details=details, rpc_metadata=rpc_metadata, rpc_timeout=rpc_timeout, + data_converter_override=self._data_converter_override, ), ) @@ -2776,6 +2796,7 @@ async def complete( result=result, rpc_metadata=rpc_metadata, rpc_timeout=rpc_timeout, + data_converter_override=self._data_converter_override, ), ) @@ -2803,6 +2824,7 @@ async def fail( last_heartbeat_details=last_heartbeat_details, rpc_metadata=rpc_metadata, rpc_timeout=rpc_timeout, + data_converter_override=self._data_converter_override, ), ) @@ -2826,9 +2848,33 @@ async def report_cancellation( details=details, rpc_metadata=rpc_metadata, rpc_timeout=rpc_timeout, + data_converter_override=self._data_converter_override, ), ) + def with_context(self, context: SerializationContext) -> Self: + """Create a new AsyncActivityHandle with a different serialization context. + + Payloads received by the activity will be decoded and deserialized using a data converter + with :py:class:`ActivitySerializationContext` set as context. If you are using a custom data + converter that makes use of this context then you can use this method to supply matching + context data to the data converter used to serialize and encode the outbound payloads. + """ + data_converter = self._client.data_converter.with_context(context) + if data_converter is self._client.data_converter: + return self + cls = type(self) + if cls.__init__ is not AsyncActivityHandle.__init__: + raise TypeError( + "If you have subclassed AsyncActivityHandle and overridden the __init__ method " + "then you must override with_context to return an instance of your class." + ) + return cls( + self._client, + self._id_or_token, + data_converter, + ) + @dataclass class WorkflowExecution: @@ -2837,9 +2883,6 @@ class WorkflowExecution: close_time: Optional[datetime] """When the workflow was closed if closed.""" - data_converter: temporalio.converter.DataConverter - """Data converter from when this description was created.""" - execution_time: Optional[datetime] """When this workflow run started or should start.""" @@ -2849,6 +2892,9 @@ class WorkflowExecution: id: str """ID for the workflow.""" + namespace: str + """Namespace for the workflow.""" + parent_id: Optional[str] """ID for the parent workflow if this was started as a child.""" @@ -2889,35 +2935,58 @@ class WorkflowExecution: workflow_type: str """Type name for the workflow.""" + _context_free_data_converter: temporalio.converter.DataConverter + + @property + def data_converter(self) -> temporalio.converter.DataConverter: + """Data converter for the workflow.""" + return self._context_free_data_converter.with_context( + WorkflowSerializationContext( + namespace=self.namespace, + workflow_id=self.id, + ) + ) + @classmethod def _from_raw_info( cls, info: temporalio.api.workflow.v1.WorkflowExecutionInfo, + namespace: str, converter: temporalio.converter.DataConverter, **additional_fields: Any, - ) -> WorkflowExecution: + ) -> Self: return cls( - close_time=info.close_time.ToDatetime().replace(tzinfo=timezone.utc) - if info.HasField("close_time") - else None, - data_converter=converter, - execution_time=info.execution_time.ToDatetime().replace(tzinfo=timezone.utc) - if info.HasField("execution_time") - else None, + close_time=( + info.close_time.ToDatetime().replace(tzinfo=timezone.utc) + if info.HasField("close_time") + else None + ), + execution_time=( + info.execution_time.ToDatetime().replace(tzinfo=timezone.utc) + if info.HasField("execution_time") + else None + ), history_length=info.history_length, id=info.execution.workflow_id, - 
parent_id=info.parent_execution.workflow_id - if info.HasField("parent_execution") - else None, - parent_run_id=info.parent_execution.run_id - if info.HasField("parent_execution") - else None, - root_id=info.root_execution.workflow_id - if info.HasField("root_execution") - else None, - root_run_id=info.root_execution.run_id - if info.HasField("root_execution") - else None, + namespace=namespace, + parent_id=( + info.parent_execution.workflow_id + if info.HasField("parent_execution") + else None + ), + parent_run_id=( + info.parent_execution.run_id + if info.HasField("parent_execution") + else None + ), + root_id=( + info.root_execution.workflow_id + if info.HasField("root_execution") + else None + ), + root_run_id=( + info.root_execution.run_id if info.HasField("root_execution") else None + ), raw_info=info, run_id=info.execution.run_id, search_attributes=temporalio.converter.decode_search_attributes( @@ -2930,6 +2999,7 @@ def _from_raw_info( info.search_attributes ), workflow_type=info.type.name, + _context_free_data_converter=converter, **additional_fields, ) @@ -3035,11 +3105,13 @@ async def _decode_metadata(self) -> None: @staticmethod async def _from_raw_description( description: temporalio.api.workflowservice.v1.DescribeWorkflowExecutionResponse, + namespace: str, converter: temporalio.converter.DataConverter, ) -> WorkflowExecutionDescription: - return WorkflowExecutionDescription._from_raw_info( # type: ignore + return WorkflowExecutionDescription._from_raw_info( description.workflow_execution_info, - converter, + namespace=namespace, + converter=converter, raw_description=description, ) @@ -3189,8 +3261,11 @@ async def fetch_next_page(self, *, page_size: Optional[int] = None) -> None: metadata=self._input.rpc_metadata, timeout=self._input.rpc_timeout, ) + self._current_page = [ - WorkflowExecution._from_raw_info(v, self._client.data_converter) + WorkflowExecution._from_raw_info( + v, self._client.namespace, self._client.data_converter + ) for v in resp.executions ] self._current_page_index = 0 @@ -4148,37 +4223,47 @@ async def _to_proto( priority: Optional[temporalio.api.common.v1.Priority] = None if self.priority: priority = self.priority._to_proto() + data_converter = client.data_converter.with_context( + WorkflowSerializationContext( + namespace=client.namespace, + workflow_id=self.id, + ) + ) action = temporalio.api.schedule.v1.ScheduleAction( start_workflow=temporalio.api.workflow.v1.NewWorkflowExecutionInfo( workflow_id=self.id, workflow_type=temporalio.api.common.v1.WorkflowType(name=self.workflow), task_queue=temporalio.api.taskqueue.v1.TaskQueue(name=self.task_queue), - input=None - if not self.args - else temporalio.api.common.v1.Payloads( - payloads=[ - a - if isinstance(a, temporalio.api.common.v1.Payload) - else (await client.data_converter.encode([a]))[0] - for a in self.args - ] + input=( + temporalio.api.common.v1.Payloads( + payloads=[ + a + if isinstance(a, temporalio.api.common.v1.Payload) + else (await data_converter.encode([a]))[0] + for a in self.args + ] + ) + if self.args + else None ), workflow_execution_timeout=execution_timeout, workflow_run_timeout=run_timeout, workflow_task_timeout=task_timeout, retry_policy=retry_policy, - memo=None - if not self.memo - else temporalio.api.common.v1.Memo( - fields={ - k: v - if isinstance(v, temporalio.api.common.v1.Payload) - else (await client.data_converter.encode([v]))[0] - for k, v in self.memo.items() - }, + memo=( + temporalio.api.common.v1.Memo( + fields={ + k: v + if isinstance(v, 
temporalio.api.common.v1.Payload) + else (await data_converter.encode([v]))[0] + for k, v in self.memo.items() + }, + ) + if self.memo + else None ), user_metadata=await _encode_user_metadata( - client.data_converter, self.static_summary, self.static_details + data_converter, self.static_summary, self.static_details ), priority=priority, ), @@ -4995,6 +5080,15 @@ def __init__( self._result_type = result_type self._known_outcome = known_outcome + @functools.cached_property + def _data_converter(self) -> temporalio.converter.DataConverter: + return self._client.data_converter.with_context( + WorkflowSerializationContext( + namespace=self._client.namespace, + workflow_id=self.workflow_id, + ) + ) + @property def id(self) -> str: """ID of this Update request.""" @@ -5041,14 +5135,12 @@ async def result( assert self._known_outcome if self._known_outcome.HasField("failure"): raise WorkflowUpdateFailedError( - await self._client.data_converter.decode_failure( - self._known_outcome.failure - ), + await self._data_converter.decode_failure(self._known_outcome.failure), ) if not self._known_outcome.success.payloads: return None # type: ignore type_hints = [self._result_type] if self._result_type else None - results = await self._client.data_converter.decode( + results = await self._data_converter.decode( self._known_outcome.success.payloads, type_hints ) if not results: @@ -5454,6 +5546,7 @@ class HeartbeatAsyncActivityInput: details: Sequence[Any] rpc_metadata: Mapping[str, Union[str, bytes]] rpc_timeout: Optional[timedelta] + data_converter_override: Optional[DataConverter] = None @dataclass @@ -5464,6 +5557,7 @@ class CompleteAsyncActivityInput: result: Optional[Any] rpc_metadata: Mapping[str, Union[str, bytes]] rpc_timeout: Optional[timedelta] + data_converter_override: Optional[DataConverter] = None @dataclass @@ -5475,6 +5569,7 @@ class FailAsyncActivityInput: last_heartbeat_details: Sequence[Any] rpc_metadata: Mapping[str, Union[str, bytes]] rpc_timeout: Optional[timedelta] + data_converter_override: Optional[DataConverter] = None @dataclass @@ -5485,6 +5580,7 @@ class ReportCancellationAsyncActivityInput: details: Sequence[Any] rpc_metadata: Mapping[str, Union[str, bytes]] rpc_timeout: Optional[timedelta] + data_converter_override: Optional[DataConverter] = None @dataclass @@ -5900,12 +5996,18 @@ async def _build_signal_with_start_workflow_execution_request( self, input: StartWorkflowInput ) -> temporalio.api.workflowservice.v1.SignalWithStartWorkflowExecutionRequest: assert input.start_signal + data_converter = self._client.data_converter.with_context( + WorkflowSerializationContext( + namespace=self._client.namespace, + workflow_id=input.id, + ) + ) req = temporalio.api.workflowservice.v1.SignalWithStartWorkflowExecutionRequest( signal_name=input.start_signal ) if input.start_signal_args: req.signal_input.payloads.extend( - await self._client.data_converter.encode(input.start_signal_args) + await data_converter.encode(input.start_signal_args) ) await self._populate_start_workflow_execution_request(req, input) return req @@ -5925,14 +6027,18 @@ async def _populate_start_workflow_execution_request( ], input: Union[StartWorkflowInput, UpdateWithStartStartWorkflowInput], ) -> None: + data_converter = self._client.data_converter.with_context( + WorkflowSerializationContext( + namespace=self._client.namespace, + workflow_id=input.id, + ) + ) req.namespace = self._client.namespace req.workflow_id = input.id req.workflow_type.name = input.workflow req.task_queue.name = input.task_queue if 
input.args: - req.input.payloads.extend( - await self._client.data_converter.encode(input.args) - ) + req.input.payloads.extend(await data_converter.encode(input.args)) if input.execution_timeout is not None: req.workflow_execution_timeout.FromTimedelta(input.execution_timeout) if input.run_timeout is not None: @@ -5955,15 +6061,13 @@ async def _populate_start_workflow_execution_request( req.cron_schedule = input.cron_schedule if input.memo is not None: for k, v in input.memo.items(): - req.memo.fields[k].CopyFrom( - (await self._client.data_converter.encode([v]))[0] - ) + req.memo.fields[k].CopyFrom((await data_converter.encode([v]))[0]) if input.search_attributes is not None: temporalio.converter.encode_search_attributes( input.search_attributes, req.search_attributes ) metadata = await _encode_user_metadata( - self._client.data_converter, input.static_summary, input.static_details + data_converter, input.static_summary, input.static_details ) if metadata is not None: req.user_metadata.CopyFrom(metadata) @@ -6009,7 +6113,13 @@ async def describe_workflow( metadata=input.rpc_metadata, timeout=input.rpc_timeout, ), - self._client.data_converter, + namespace=self._client.namespace, + converter=self._client.data_converter.with_context( + WorkflowSerializationContext( + namespace=self._client.namespace, + workflow_id=input.id, + ) + ), ) def fetch_workflow_history_events( @@ -6038,6 +6148,12 @@ async def count_workflows( ) async def query_workflow(self, input: QueryWorkflowInput) -> Any: + data_converter = self._client.data_converter.with_context( + WorkflowSerializationContext( + namespace=self._client.namespace, + workflow_id=input.id, + ) + ) req = temporalio.api.workflowservice.v1.QueryWorkflowRequest( namespace=self._client.namespace, execution=temporalio.api.common.v1.WorkflowExecution( @@ -6053,7 +6169,7 @@ async def query_workflow(self, input: QueryWorkflowInput) -> Any: req.query.query_type = input.query if input.args: req.query.query_args.payloads.extend( - await self._client.data_converter.encode(input.args) + await data_converter.encode(input.args) ) if input.headers is not None: await self._apply_headers(input.headers, req.query.header.fields) @@ -6077,9 +6193,7 @@ async def query_workflow(self, input: QueryWorkflowInput) -> Any: if not resp.query_result.payloads: return None type_hints = [input.ret_type] if input.ret_type else None - results = await self._client.data_converter.decode( - resp.query_result.payloads, type_hints - ) + results = await data_converter.decode(resp.query_result.payloads, type_hints) if not results: return None elif len(results) > 1: @@ -6087,6 +6201,12 @@ async def query_workflow(self, input: QueryWorkflowInput) -> Any: return results[0] async def signal_workflow(self, input: SignalWorkflowInput) -> None: + data_converter = self._client.data_converter.with_context( + WorkflowSerializationContext( + namespace=self._client.namespace, + workflow_id=input.id, + ) + ) req = temporalio.api.workflowservice.v1.SignalWorkflowExecutionRequest( namespace=self._client.namespace, workflow_execution=temporalio.api.common.v1.WorkflowExecution( @@ -6098,9 +6218,7 @@ async def signal_workflow(self, input: SignalWorkflowInput) -> None: request_id=str(uuid.uuid4()), ) if input.args: - req.input.payloads.extend( - await self._client.data_converter.encode(input.args) - ) + req.input.payloads.extend(await data_converter.encode(input.args)) if input.headers is not None: await self._apply_headers(input.headers, req.header.fields) await 
self._client.workflow_service.signal_workflow_execution( @@ -6108,6 +6226,12 @@ async def signal_workflow(self, input: SignalWorkflowInput) -> None: ) async def terminate_workflow(self, input: TerminateWorkflowInput) -> None: + data_converter = self._client.data_converter.with_context( + WorkflowSerializationContext( + namespace=self._client.namespace, + workflow_id=input.id, + ) + ) req = temporalio.api.workflowservice.v1.TerminateWorkflowExecutionRequest( namespace=self._client.namespace, workflow_execution=temporalio.api.common.v1.WorkflowExecution( @@ -6119,9 +6243,7 @@ async def terminate_workflow(self, input: TerminateWorkflowInput) -> None: first_execution_run_id=input.first_execution_run_id or "", ) if input.args: - req.details.payloads.extend( - await self._client.data_converter.encode(input.args) - ) + req.details.payloads.extend(await data_converter.encode(input.args)) await self._client.workflow_service.terminate_workflow_execution( req, retry=True, metadata=input.rpc_metadata, timeout=input.rpc_timeout ) @@ -6178,6 +6300,12 @@ async def _build_update_workflow_execution_request( input: Union[StartWorkflowUpdateInput, UpdateWithStartUpdateWorkflowInput], workflow_id: str, ) -> temporalio.api.workflowservice.v1.UpdateWorkflowExecutionRequest: + data_converter = self._client.data_converter.with_context( + WorkflowSerializationContext( + namespace=self._client.namespace, + workflow_id=workflow_id, + ) + ) run_id, first_execution_run_id = ( ( input.run_id, @@ -6210,7 +6338,7 @@ async def _build_update_workflow_execution_request( ) if input.args: req.request.input.args.payloads.extend( - await self._client.data_converter.encode(input.args) + await data_converter.encode(input.args) ) if input.headers is not None: await self._apply_headers(input.headers, req.request.input.header.fields) @@ -6354,10 +6482,11 @@ async def _start_workflow_update_with_start( async def heartbeat_async_activity( self, input: HeartbeatAsyncActivityInput ) -> None: + data_converter = input.data_converter_override or self._client.data_converter details = ( None if not input.details - else await self._client.data_converter.encode_wrapper(input.details) + else await data_converter.encode_wrapper(input.details) ) if isinstance(input.id_or_token, AsyncActivityIDReference): resp_by_id = await self._client.workflow_service.record_activity_task_heartbeat_by_id( @@ -6408,10 +6537,11 @@ async def heartbeat_async_activity( ) async def complete_async_activity(self, input: CompleteAsyncActivityInput) -> None: + data_converter = input.data_converter_override or self._client.data_converter result = ( None if input.result is temporalio.common._arg_unset - else await self._client.data_converter.encode_wrapper([input.result]) + else await data_converter.encode_wrapper([input.result]) ) if isinstance(input.id_or_token, AsyncActivityIDReference): await self._client.workflow_service.respond_activity_task_completed_by_id( @@ -6441,14 +6571,14 @@ async def complete_async_activity(self, input: CompleteAsyncActivityInput) -> No ) async def fail_async_activity(self, input: FailAsyncActivityInput) -> None: + data_converter = input.data_converter_override or self._client.data_converter + failure = temporalio.api.failure.v1.Failure() - await self._client.data_converter.encode_failure(input.error, failure) + await data_converter.encode_failure(input.error, failure) last_heartbeat_details = ( - None - if not input.last_heartbeat_details - else await self._client.data_converter.encode_wrapper( - input.last_heartbeat_details - ) + await 
data_converter.encode_wrapper(input.last_heartbeat_details) + if input.last_heartbeat_details + else None ) if isinstance(input.id_or_token, AsyncActivityIDReference): await self._client.workflow_service.respond_activity_task_failed_by_id( @@ -6482,10 +6612,11 @@ async def fail_async_activity(self, input: FailAsyncActivityInput) -> None: async def report_cancellation_async_activity( self, input: ReportCancellationAsyncActivityInput ) -> None: + data_converter = input.data_converter_override or self._client.data_converter details = ( None if not input.details - else await self._client.data_converter.encode_wrapper(input.details) + else await data_converter.encode_wrapper(input.details) ) if isinstance(input.id_or_token, AsyncActivityIDReference): await self._client.workflow_service.respond_activity_task_canceled_by_id( diff --git a/temporalio/converter.py b/temporalio/converter.py index a9f8c0c98..29eb35566 100644 --- a/temporalio/converter.py +++ b/temporalio/converter.py @@ -5,6 +5,7 @@ import collections import collections.abc import dataclasses +import functools import inspect import json import sys @@ -43,6 +44,7 @@ import google.protobuf.symbol_database import nexusrpc import typing_extensions +from typing_extensions import Self import temporalio.api.common.v1 import temporalio.api.enums.v1 @@ -65,6 +67,106 @@ logger = getLogger(__name__) +class SerializationContext(ABC): + """Base serialization context. + + Provides contextual information during serialization and deserialization operations. + + Examples: + In client code, when starting a workflow, or sending a signal/update/query to a workflow, + or receiving the result of an update/query, or handling an exception from a workflow, the + context type is :py:class:`WorkflowSerializationContext` and the workflow ID + of the target workflow will be set in the context. + + In workflow code, when operating on a payload being sent/received to/from a child workflow, + or handling an exception from a child workflow, the context type is + :py:class:`WorkflowSerializationContext` and the workflow ID is that of the child workflow, + not of the currently executing (i.e. parent) workflow. + + In workflow code, when operating on a payload to be sent/received to/from an activity, the + context type is :py:class:`ActivitySerializationContext` and the workflow ID is that of the + currently-executing workflow. ActivitySerializationContext is also set on data converter + operations in the activity context. + """ + + pass + + +@dataclass(frozen=True) +class BaseWorkflowSerializationContext(SerializationContext): + """Base serialization context shared by workflow and activity serialization contexts.""" + + namespace: str + workflow_id: str + + +@dataclass(frozen=True) +class WorkflowSerializationContext(BaseWorkflowSerializationContext): + """Serialization context for workflows. + + See :py:class:`SerializationContext` for more details. + + Attributes: + namespace: The namespace the workflow is running in. + workflow_id: The ID of the workflow. Note that this is the ID of the workflow for which the + payload being operated on is an input or output. Note also that when creating/describing + schedules, this may be the workflow ID prefix as configured, not the final workflow ID + when the workflow is created by the schedule. + """ + + pass + + +@dataclass(frozen=True) +class ActivitySerializationContext(BaseWorkflowSerializationContext): + """Serialization context for activities. + + See :py:class:`SerializationContext` for more details. 
+ + Attributes: + namespace: Workflow/activity namespace. + workflow_id: Workflow ID. Note that when creating/describing schedules, + this may be the workflow ID prefix as configured, not the final workflow ID when the + workflow is created by the schedule. + workflow_type: Workflow type. + activity_type: Activity type. + activity_task_queue: Activity task queue. + is_local: Whether the activity is a local activity. + """ + + workflow_type: str + activity_type: str + activity_task_queue: str + is_local: bool + + +class WithSerializationContext(ABC): + """Interface for classes that can use serialization context. + + The following classes may implement this interface: + - :py:class:`PayloadConverter` + - :py:class:`PayloadCodec` + - :py:class:`FailureConverter` + - :py:class:`EncodingPayloadConverter` + + During data converter operations (encoding/decoding, serialization/deserialization, and failure + conversion), instances of classes implementing this interface will be replaced by the result of + calling with_context(context). This allows overridden methods (encode/decode, + to_payload/from_payload, etc.) to use the context. + """ + + def with_context(self, context: SerializationContext) -> Self: + """Return a copy of this object configured to use the given context. + + Args: + context: The serialization context to use. + + Returns: + A new instance configured with the context. + """ + raise NotImplementedError() + + class PayloadConverter(ABC): """Base payload converter to/from multiple payloads/values.""" @@ -232,7 +334,7 @@ def from_payload( raise NotImplementedError -class CompositePayloadConverter(PayloadConverter): +class CompositePayloadConverter(PayloadConverter, WithSerializationContext): """Composite payload converter that delegates to a list of encoding payload converters. Encoding/decoding are attempted on each payload converter successively until @@ -250,7 +352,9 @@ def __init__(self, *converters: EncodingPayloadConverter) -> None: Args: converters: Payload converters to delegate to, in order. """ - # Insertion order preserved here since Python 3.7 + self._set_converters(*converters) + + def _set_converters(self, *converters: EncodingPayloadConverter) -> None: self.converters = {c.encoding.encode(): c for c in converters} def to_payloads( @@ -315,6 +419,44 @@ def from_payloads( ) from err return values + def with_context(self, context: SerializationContext) -> Self: + """Return a new instance with context set on the component converters. + + Return self if none of the component converters returns a new instance. + """ + converters = self.get_converters_with_context(context) + if converters is None: + return self + new_instance = type(self)() # Must have a nullary constructor + new_instance._set_converters(*converters) + return new_instance + + def get_converters_with_context( + self, context: SerializationContext + ) -> Optional[list[EncodingPayloadConverter]]: + """Return converter instances with context set. + + If no converter uses context, return None. 
+ """ + if not self._any_converter_takes_context: + return None + converters: list[EncodingPayloadConverter] = [] + any_with_context = False + for c in self.converters.values(): + if isinstance(c, WithSerializationContext): + converters.append(c.with_context(context)) + any_with_context |= converters[-1] is not c + else: + converters.append(c) + + return converters if any_with_context else None + + @functools.cached_property + def _any_converter_takes_context(self) -> bool: + return any( + isinstance(c, WithSerializationContext) for c in self.converters.values() + ) + class DefaultPayloadConverter(CompositePayloadConverter): """Default payload converter compatible with other Temporal SDKs. @@ -1108,7 +1250,7 @@ def __init__(self) -> None: @dataclass(frozen=True) -class DataConverter: +class DataConverter(WithSerializationContext): """Data converter for converting and encoding payloads to/from Python values. This combines :py:class:`PayloadConverter` which converts values with @@ -1212,6 +1354,32 @@ async def decode_failure( await self.payload_codec.decode_failure(failure) return self.failure_converter.from_failure(failure, self.payload_converter) + def with_context(self, context: SerializationContext) -> Self: + """Return an instance with context set on the component converters.""" + payload_converter = self.payload_converter + payload_codec = self.payload_codec + failure_converter = self.failure_converter + if isinstance(payload_converter, WithSerializationContext): + payload_converter = payload_converter.with_context(context) + if isinstance(payload_codec, WithSerializationContext): + payload_codec = payload_codec.with_context(context) + if isinstance(failure_converter, WithSerializationContext): + failure_converter = failure_converter.with_context(context) + if all( + new is orig + for new, orig in [ + (payload_converter, self.payload_converter), + (payload_codec, self.payload_codec), + (failure_converter, self.failure_converter), + ] + ): + return self + cloned = dataclasses.replace(self) + object.__setattr__(cloned, "payload_converter", payload_converter) + object.__setattr__(cloned, "payload_codec", payload_codec) + object.__setattr__(cloned, "failure_converter", failure_converter) + return cloned + DefaultPayloadConverter.default_encoding_payload_converters = ( BinaryNullPayloadConverter(), diff --git a/temporalio/worker/_activity.py b/temporalio/worker/_activity.py index 4ccb56ca6..44bfb6910 100644 --- a/temporalio/worker/_activity.py +++ b/temporalio/worker/_activity.py @@ -252,6 +252,18 @@ async def _heartbeat_async( if details is None: return + data_converter = self._data_converter + if activity.info: + context = temporalio.converter.ActivitySerializationContext( + namespace=activity.info.workflow_namespace, + workflow_id=activity.info.workflow_id, + workflow_type=activity.info.workflow_type, + activity_type=activity.info.activity_type, + activity_task_queue=self._task_queue, + is_local=activity.info.is_local, + ) + data_converter = data_converter.with_context(context) + # Perform the heartbeat try: heartbeat = temporalio.bridge.proto.ActivityHeartbeat( # type: ignore[reportAttributeAccessIssue] @@ -259,7 +271,7 @@ async def _heartbeat_async( ) if details: # Convert to core payloads - heartbeat.details.extend(await self._data_converter.encode(details)) + heartbeat.details.extend(await data_converter.encode(details)) logger.debug("Recording heartbeat with details %s", details) self._bridge_worker().record_activity_heartbeat(heartbeat) except Exception as err: @@ -293,9 +305,21 @@ 
async def _handle_start_activity_task( completion = temporalio.bridge.proto.ActivityTaskCompletion( # type: ignore[reportAttributeAccessIssue] task_token=task_token ) + # Create serialization context for the activity + context = temporalio.converter.ActivitySerializationContext( + namespace=start.workflow_namespace, + workflow_id=start.workflow_execution.workflow_id, + workflow_type=start.workflow_type, + activity_type=start.activity_type, + activity_task_queue=self._task_queue, + is_local=start.is_local, + ) + data_converter = self._data_converter.with_context(context) try: - result = await self._execute_activity(start, running_activity, task_token) - [payload] = await self._data_converter.encode([result]) + result = await self._execute_activity( + start, running_activity, task_token, data_converter + ) + [payload] = await data_converter.encode([result]) completion.result.completed.result.CopyFrom(payload) except BaseException as err: try: @@ -313,7 +337,7 @@ async def _handle_start_activity_task( temporalio.activity.logger.warning( f"Completing as failure during heartbeat with error of type {type(err)}: {err}", ) - await self._data_converter.encode_failure( + await data_converter.encode_failure( err, completion.result.failed.failure ) elif ( @@ -327,7 +351,7 @@ async def _handle_start_activity_task( temporalio.activity.logger.warning( "Completing as failure due to unhandled cancel error produced by activity pause", ) - await self._data_converter.encode_failure( + await data_converter.encode_failure( temporalio.exceptions.ApplicationError( type="ActivityPause", message="Unhandled activity cancel error produced by activity pause", @@ -345,7 +369,7 @@ async def _handle_start_activity_task( temporalio.activity.logger.warning( "Completing as failure due to unhandled cancel error produced by activity reset", ) - await self._data_converter.encode_failure( + await data_converter.encode_failure( temporalio.exceptions.ApplicationError( type="ActivityReset", message="Unhandled activity cancel error produced by activity reset", @@ -360,7 +384,7 @@ async def _handle_start_activity_task( and running_activity.cancelled_by_request ): temporalio.activity.logger.debug("Completing as cancelled") - await self._data_converter.encode_failure( + await data_converter.encode_failure( # TODO(cretz): Should use some other message? temporalio.exceptions.CancelledError("Cancelled"), completion.result.cancelled.failure, @@ -386,7 +410,7 @@ async def _handle_start_activity_task( exc_info=True, extra={"__temporal_error_identifier": "ActivityFailure"}, ) - await self._data_converter.encode_failure( + await data_converter.encode_failure( err, completion.result.failed.failure ) # For broken executors, we have to fail the entire worker @@ -428,6 +452,7 @@ async def _execute_activity( start: temporalio.bridge.proto.activity_task.Start, # type: ignore[reportAttributeAccessIssue] running_activity: _RunningActivity, task_token: bytes, + data_converter: temporalio.converter.DataConverter, ) -> Any: """Invoke the user's activity function. 
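+ The data_converter argument carries the activity-scoped serialization context built by the caller; argument decoding, heartbeat-detail decoding, and the activity's payload converter all use it.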
@@ -501,9 +526,7 @@ async def _execute_activity( args = ( [] if not start.input - else await self._data_converter.decode( - start.input, type_hints=arg_types - ) + else await data_converter.decode(start.input, type_hints=arg_types) ) except Exception as err: raise temporalio.exceptions.ApplicationError( @@ -519,7 +542,7 @@ async def _execute_activity( heartbeat_details = ( [] if not start.heartbeat_details - else await self._data_converter.decode(start.heartbeat_details) + else await data_converter.decode(start.heartbeat_details) ) except Exception as err: raise temporalio.exceptions.ApplicationError( @@ -563,11 +586,9 @@ async def _execute_activity( else None, ) - if self._encode_headers and self._data_converter.payload_codec is not None: + if self._encode_headers and data_converter.payload_codec is not None: for payload in start.header_fields.values(): - new_payload = ( - await self._data_converter.payload_codec.decode([payload]) - )[0] + new_payload = (await data_converter.payload_codec.decode([payload]))[0] payload.CopyFrom(new_payload) running_activity.info = info @@ -591,7 +612,7 @@ async def _execute_activity( if not running_activity.cancel_thread_raiser else running_activity.cancel_thread_raiser.shielded ), - payload_converter_class_or_instance=self._data_converter.payload_converter, + payload_converter_class_or_instance=data_converter.payload_converter, runtime_metric_meter=None if sync_non_threaded else self._metric_meter, client=self._client if not running_activity.sync else None, cancellation_details=running_activity.cancellation_details, diff --git a/temporalio/worker/_command_aware_visitor.py b/temporalio/worker/_command_aware_visitor.py new file mode 100644 index 000000000..450445999 --- /dev/null +++ b/temporalio/worker/_command_aware_visitor.py @@ -0,0 +1,163 @@ +"""Visitor that sets command context during payload traversal.""" + +import contextvars +from contextlib import contextmanager +from dataclasses import dataclass +from typing import Iterator, Optional + +from temporalio.api.enums.v1.command_type_pb2 import CommandType +from temporalio.bridge._visitor import PayloadVisitor, VisitorFunctions +from temporalio.bridge.proto.workflow_activation.workflow_activation_pb2 import ( + ResolveActivity, + ResolveChildWorkflowExecution, + ResolveChildWorkflowExecutionStart, + ResolveNexusOperation, + ResolveNexusOperationStart, + ResolveRequestCancelExternalWorkflow, + ResolveSignalExternalWorkflow, +) +from temporalio.bridge.proto.workflow_commands.workflow_commands_pb2 import ( + ScheduleActivity, + ScheduleLocalActivity, + ScheduleNexusOperation, + SignalExternalWorkflowExecution, + StartChildWorkflowExecution, +) + + +@dataclass(frozen=True) +class CommandInfo: + """Information identifying a specific command instance.""" + + command_type: CommandType.ValueType + command_seq: int + + +current_command_info: contextvars.ContextVar[Optional[CommandInfo]] = ( + contextvars.ContextVar("current_command_info", default=None) +) + + +class CommandAwarePayloadVisitor(PayloadVisitor): + """Payload visitor that sets command context during traversal. + + Override methods are explicitly defined for workflow commands and + activation jobs that have both a 'seq' field and payloads to visit. 
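+ + While one of these messages is being visited, a contextvar identifying its command type and 'seq' is set (see current_command below) so that the payload codec can derive a per-command serialization context.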
+ """ + + # Workflow commands with payloads + async def _visit_coresdk_workflow_commands_ScheduleActivity( + self, fs: VisitorFunctions, o: ScheduleActivity + ) -> None: + with current_command(CommandType.COMMAND_TYPE_SCHEDULE_ACTIVITY_TASK, o.seq): + await super()._visit_coresdk_workflow_commands_ScheduleActivity(fs, o) + + async def _visit_coresdk_workflow_commands_ScheduleLocalActivity( + self, fs: VisitorFunctions, o: ScheduleLocalActivity + ) -> None: + with current_command(CommandType.COMMAND_TYPE_SCHEDULE_ACTIVITY_TASK, o.seq): + await super()._visit_coresdk_workflow_commands_ScheduleLocalActivity(fs, o) + + async def _visit_coresdk_workflow_commands_StartChildWorkflowExecution( + self, fs: VisitorFunctions, o: StartChildWorkflowExecution + ) -> None: + with current_command( + CommandType.COMMAND_TYPE_START_CHILD_WORKFLOW_EXECUTION, o.seq + ): + await super()._visit_coresdk_workflow_commands_StartChildWorkflowExecution( + fs, o + ) + + async def _visit_coresdk_workflow_commands_SignalExternalWorkflowExecution( + self, fs: VisitorFunctions, o: SignalExternalWorkflowExecution + ) -> None: + with current_command( + CommandType.COMMAND_TYPE_SIGNAL_EXTERNAL_WORKFLOW_EXECUTION, o.seq + ): + await super()._visit_coresdk_workflow_commands_SignalExternalWorkflowExecution( + fs, o + ) + + async def _visit_coresdk_workflow_commands_ScheduleNexusOperation( + self, fs: VisitorFunctions, o: ScheduleNexusOperation + ) -> None: + with current_command(CommandType.COMMAND_TYPE_SCHEDULE_NEXUS_OPERATION, o.seq): + await super()._visit_coresdk_workflow_commands_ScheduleNexusOperation(fs, o) + + # Workflow activation jobs with payloads + async def _visit_coresdk_workflow_activation_ResolveActivity( + self, fs: VisitorFunctions, o: ResolveActivity + ) -> None: + with current_command(CommandType.COMMAND_TYPE_SCHEDULE_ACTIVITY_TASK, o.seq): + await super()._visit_coresdk_workflow_activation_ResolveActivity(fs, o) + + async def _visit_coresdk_workflow_activation_ResolveChildWorkflowExecutionStart( + self, fs: VisitorFunctions, o: ResolveChildWorkflowExecutionStart + ) -> None: + with current_command( + CommandType.COMMAND_TYPE_START_CHILD_WORKFLOW_EXECUTION, o.seq + ): + await super()._visit_coresdk_workflow_activation_ResolveChildWorkflowExecutionStart( + fs, o + ) + + async def _visit_coresdk_workflow_activation_ResolveChildWorkflowExecution( + self, fs: VisitorFunctions, o: ResolveChildWorkflowExecution + ) -> None: + with current_command( + CommandType.COMMAND_TYPE_START_CHILD_WORKFLOW_EXECUTION, o.seq + ): + await super()._visit_coresdk_workflow_activation_ResolveChildWorkflowExecution( + fs, o + ) + + async def _visit_coresdk_workflow_activation_ResolveSignalExternalWorkflow( + self, fs: VisitorFunctions, o: ResolveSignalExternalWorkflow + ) -> None: + with current_command( + CommandType.COMMAND_TYPE_SIGNAL_EXTERNAL_WORKFLOW_EXECUTION, o.seq + ): + await super()._visit_coresdk_workflow_activation_ResolveSignalExternalWorkflow( + fs, o + ) + + async def _visit_coresdk_workflow_activation_ResolveRequestCancelExternalWorkflow( + self, fs: VisitorFunctions, o: ResolveRequestCancelExternalWorkflow + ) -> None: + with current_command( + CommandType.COMMAND_TYPE_REQUEST_CANCEL_EXTERNAL_WORKFLOW_EXECUTION, o.seq + ): + await super()._visit_coresdk_workflow_activation_ResolveRequestCancelExternalWorkflow( + fs, o + ) + + async def _visit_coresdk_workflow_activation_ResolveNexusOperationStart( + self, fs: VisitorFunctions, o: ResolveNexusOperationStart + ) -> None: + with 
current_command(CommandType.COMMAND_TYPE_SCHEDULE_NEXUS_OPERATION, o.seq): + await super()._visit_coresdk_workflow_activation_ResolveNexusOperationStart( + fs, o + ) + + async def _visit_coresdk_workflow_activation_ResolveNexusOperation( + self, fs: VisitorFunctions, o: ResolveNexusOperation + ) -> None: + with current_command(CommandType.COMMAND_TYPE_SCHEDULE_NEXUS_OPERATION, o.seq): + await super()._visit_coresdk_workflow_activation_ResolveNexusOperation( + fs, o + ) + + +@contextmanager +def current_command( + command_type: CommandType.ValueType, command_seq: int +) -> Iterator[None]: + """Context manager for setting command info.""" + token = current_command_info.set( + CommandInfo(command_type=command_type, command_seq=command_seq) + ) + try: + yield + finally: + if token: + current_command_info.reset(token) diff --git a/temporalio/worker/_workflow.py b/temporalio/worker/_workflow.py index 1e178f015..6e7c254aa 100644 --- a/temporalio/worker/_workflow.py +++ b/temporalio/worker/_workflow.py @@ -8,6 +8,7 @@ import os import sys import threading +from dataclasses import dataclass from datetime import timezone from types import TracebackType from typing import ( @@ -24,6 +25,7 @@ import temporalio.activity import temporalio.api.common.v1 +import temporalio.bridge._visitor import temporalio.bridge.client import temporalio.bridge.proto.workflow_activation import temporalio.bridge.proto.workflow_completion @@ -35,6 +37,7 @@ import temporalio.exceptions import temporalio.workflow +from . import _command_aware_visitor from ._interceptor import ( Interceptor, WorkflowInboundInterceptor, @@ -253,66 +256,85 @@ async def _handle_activation( temporalio.bridge.proto.workflow_completion.WorkflowActivationCompletion() ) completion.successful.SetInParent() + workflow = None + data_converter = self._data_converter try: - # Decode the activation if there's a codec and not cache remove job - if self._data_converter.payload_codec: - await temporalio.bridge.worker.decode_activation( - act, - self._data_converter.payload_codec, - decode_headers=self._encode_headers, - ) - if LOG_PROTOS: logger.debug("Received workflow activation:\n%s", act) - # If the workflow is not running yet, create it workflow = self._running_workflows.get(act.run_id) if not workflow: - # Must have a initialize job to create instance if not init_job: raise RuntimeError( "Missing initialize workflow, workflow could have unexpectedly been removed from cache" ) + workflow_id = init_job.workflow_id + else: + workflow_id = workflow.workflow_id + if init_job: + # Should never happen + logger.warning( + "Cache already exists for activation with initialize job" + ) + + workflow_context = temporalio.converter.WorkflowSerializationContext( + namespace=self._namespace, + workflow_id=workflow_id, + ) + data_converter = self._data_converter.with_context(workflow_context) + if self._data_converter.payload_codec: + assert data_converter.payload_codec + if not workflow: + payload_codec = data_converter.payload_codec + else: + payload_codec = _CommandAwarePayloadCodec( + workflow.instance, + context_free_payload_codec=self._data_converter.payload_codec, + workflow_context_payload_codec=data_converter.payload_codec, + workflow_context=workflow_context, + ) + await temporalio.bridge.worker.decode_activation( + act, + payload_codec, + decode_headers=self._encode_headers, + ) + if not workflow: + assert init_job workflow = _RunningWorkflow( - self._create_workflow_instance(act, init_job) + self._create_workflow_instance(act, init_job), + workflow_id, ) 
self._running_workflows[act.run_id] = workflow - elif init_job: - # This should never happen - logger.warning( - "Cache already exists for activation with initialize job" - ) # Run activation in separate thread so we can check if it's # deadlocked - if workflow: - activate_task = asyncio.get_running_loop().run_in_executor( - self._workflow_task_executor, - workflow.activate, - act, - ) + activate_task = asyncio.get_running_loop().run_in_executor( + self._workflow_task_executor, + workflow.activate, + act, + ) - # Run activation task with deadlock timeout - try: - completion = await asyncio.wait_for( - activate_task, self._deadlock_timeout_seconds - ) - except asyncio.TimeoutError: - # Need to create the deadlock exception up here so it - # captures the trace now instead of later after we may have - # interrupted it - deadlock_exc = _DeadlockError.from_deadlocked_workflow( - workflow.instance, self._deadlock_timeout_seconds - ) - # When we deadlock, we will raise an exception to fail - # the task. But before we do that, we want to try to - # interrupt the thread and put this activation task on - # the workflow so that the successive eviction can wait - # on it before trying to evict. - workflow.attempt_deadlock_interruption() - # Set the task and raise - workflow.deadlocked_activation_task = activate_task - raise deadlock_exc from None + # Run activation task with deadlock timeout + try: + completion = await asyncio.wait_for( + activate_task, self._deadlock_timeout_seconds + ) + except asyncio.TimeoutError: + # Need to create the deadlock exception up here so it + # captures the trace now instead of later after we may have + # interrupted it + deadlock_exc = _DeadlockError.from_deadlocked_workflow( + workflow.instance, self._deadlock_timeout_seconds + ) + # When we deadlock, we will raise an exception to fail + # the task. But before we do that, we want to try to + # interrupt the thread and put this activation task on + # the workflow so that the successive eviction can wait + # on it before trying to evict. 
+ workflow.attempt_deadlock_interruption() + # Set the task and raise + workflow.deadlocked_activation_task = activate_task + raise deadlock_exc from None except Exception as err: if isinstance(err, _DeadlockError): @@ -322,12 +344,11 @@ async def _handle_activation( "Failed handling activation on workflow with run ID %s", act.run_id ) - # Set completion failure completion.failed.failure.SetInParent() try: - self._data_converter.failure_converter.to_failure( + data_converter.failure_converter.to_failure( err, - self._data_converter.payload_converter, + data_converter.payload_converter, completion.failed.failure, ) except Exception as inner_err: @@ -339,15 +360,24 @@ async def _handle_activation( f"Failed converting activation exception: {inner_err}" ) - # Always set the run ID on the completion completion.run_id = act.run_id - # Encode the completion if there's a codec and not cache remove job - if self._data_converter.payload_codec: + # Encode completion + if self._data_converter.payload_codec and workflow: + assert data_converter.payload_codec + payload_codec = _CommandAwarePayloadCodec( + workflow.instance, + context_free_payload_codec=self._data_converter.payload_codec, + workflow_context_payload_codec=data_converter.payload_codec, + workflow_context=temporalio.converter.WorkflowSerializationContext( + namespace=self._namespace, + workflow_id=workflow.workflow_id, + ), + ) try: await temporalio.bridge.worker.encode_completion( completion, - self._data_converter.payload_codec, + payload_codec, encode_headers=self._encode_headers, ) except Exception as err: @@ -667,8 +697,9 @@ def _gen_tb_helper( class _RunningWorkflow: - def __init__(self, instance: WorkflowInstance): + def __init__(self, instance: WorkflowInstance, workflow_id: str): self.instance = instance + self.workflow_id = workflow_id self.deadlocked_activation_task: Optional[Awaitable] = None self._deadlock_can_be_interrupted_lock = threading.Lock() self._deadlock_can_be_interrupted = False @@ -698,5 +729,47 @@ def attempt_deadlock_interruption(self) -> None: ) +@dataclass(frozen=True) +class _CommandAwarePayloadCodec(temporalio.converter.PayloadCodec): + """A payload codec that sets serialization context for the command associated with each payload. + + This codec responds to the context variable set by + :py:class:`_command_aware_visitor.CommandAwarePayloadVisitor`. 
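+ + If the context derived for the current command equals the workflow's own context, the pre-built workflow_context_payload_codec is reused; otherwise the context-free codec is specialized for that command via with_context.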
+ """ + + instance: WorkflowInstance + context_free_payload_codec: temporalio.converter.PayloadCodec + workflow_context_payload_codec: temporalio.converter.PayloadCodec + workflow_context: temporalio.converter.WorkflowSerializationContext + + async def encode( + self, + payloads: Sequence[temporalio.api.common.v1.Payload], + ) -> List[temporalio.api.common.v1.Payload]: + return await self._get_current_command_codec().encode(payloads) + + async def decode( + self, + payloads: Sequence[temporalio.api.common.v1.Payload], + ) -> List[temporalio.api.common.v1.Payload]: + return await self._get_current_command_codec().decode(payloads) + + def _get_current_command_codec(self) -> temporalio.converter.PayloadCodec: + if not isinstance( + self.context_free_payload_codec, + temporalio.converter.WithSerializationContext, + ): + return self.context_free_payload_codec + + if context := self.instance.get_serialization_context( + _command_aware_visitor.current_command_info.get(), + ): + if context == self.workflow_context: + return self.workflow_context_payload_codec + return self.context_free_payload_codec.with_context(context) + + return self.context_free_payload_codec + + class _InterruptDeadlockError(BaseException): pass diff --git a/temporalio/worker/_workflow_instance.py b/temporalio/worker/_workflow_instance.py index 118966b34..44eb443ff 100644 --- a/temporalio/worker/_workflow_instance.py +++ b/temporalio/worker/_workflow_instance.py @@ -65,6 +65,7 @@ from temporalio.service import __version__ from ..api.failure.v1.message_pb2 import Failure +from . import _command_aware_visitor from ._interceptor import ( ContinueAsNewInput, ExecuteWorkflowInput, @@ -168,6 +169,22 @@ def activate( """ raise NotImplementedError + @abstractmethod + def get_serialization_context( + self, + command_info: Optional[_command_aware_visitor.CommandInfo], + ) -> Optional[temporalio.converter.SerializationContext]: + """Return appropriate serialization context. + + Args: + command_info: Optional information identifying the associated command. If set, the payload + codec will have serialization context set appropriately for that command. + + Returns: + The serialization context, or None if no context should be set. + """ + raise NotImplementedError + def get_thread_id(self) -> Optional[int]: """Return the thread identifier that this workflow is running on. 
@@ -207,11 +224,22 @@ def __init__(self, det: WorkflowInstanceDetails) -> None: # No init for AbstractEventLoop WorkflowInstance.__init__(self) temporalio.workflow._Runtime.__init__(self) - self._payload_converter = det.payload_converter_class() - self._failure_converter = det.failure_converter_class() self._defn = det.defn self._workflow_input: Optional[ExecuteWorkflowInput] = None self._info = det.info + self._context_free_payload_converter = det.payload_converter_class() + self._context_free_failure_converter = det.failure_converter_class() + workflow_context = temporalio.converter.WorkflowSerializationContext( + namespace=det.info.namespace, + workflow_id=det.info.workflow_id, + ) + self._workflow_context_payload_converter = self._payload_converter_with_context( + workflow_context + ) + self._workflow_context_failure_converter = self._failure_converter_with_context( + workflow_context + ) + self._extern_functions = det.extern_functions self._disable_eager_activity_execution = det.disable_eager_activity_execution self._worker_level_failure_exception_types = ( @@ -236,8 +264,8 @@ def __init__(self, det: WorkflowInstanceDetails) -> None: self._pending_activities: Dict[int, _ActivityHandle] = {} self._pending_child_workflows: Dict[int, _ChildWorkflowHandle] = {} self._pending_nexus_operations: Dict[int, _NexusOperationHandle] = {} - self._pending_external_signals: Dict[int, asyncio.Future] = {} - self._pending_external_cancels: Dict[int, asyncio.Future] = {} + self._pending_external_signals: Dict[int, Tuple[asyncio.Future, str]] = {} + self._pending_external_cancels: Dict[int, Tuple[asyncio.Future, str]] = {} # Keyed by type self._curr_seqs: Dict[str, int] = {} # TODO(cretz): Any concerns about not sharing this? Maybe the types I @@ -466,9 +494,9 @@ def activate( # Set completion failure self._current_completion.failed.failure.SetInParent() try: - self._failure_converter.to_failure( + self._workflow_context_failure_converter.to_failure( activation_err, - self._payload_converter, + self._workflow_context_payload_converter, self._current_completion.failed.failure, ) except Exception as inner_err: @@ -610,7 +638,9 @@ async def run_update() -> None: # Run the handler success = await self._inbound.handle_update_handler(handler_input) - result_payloads = self._payload_converter.to_payloads([success]) + result_payloads = self._workflow_context_payload_converter.to_payloads( + [success] + ) if len(result_payloads) != 1: raise ValueError( f"Expected 1 result payload, got {len(result_payloads)}" @@ -642,9 +672,9 @@ async def run_update() -> None: job.protocol_instance_id ) command.update_response.rejected.SetInParent() - self._failure_converter.to_failure( + self._workflow_context_failure_converter.to_failure( err, - self._payload_converter, + self._workflow_context_payload_converter, command.update_response.rejected, ) else: @@ -706,7 +736,9 @@ async def run_query() -> None: headers=job.headers, ) success = await self._inbound.handle_query(input) - result_payloads = self._payload_converter.to_payloads([success]) + result_payloads = ( + self._workflow_context_payload_converter.to_payloads([success]) + ) if len(result_payloads) != 1: raise ValueError( f"Expected 1 result payload, got {len(result_payloads)}" @@ -718,9 +750,9 @@ async def run_query() -> None: try: command = self._add_command() command.respond_to_query.query_id = job.query_id - self._failure_converter.to_failure( + self._workflow_context_failure_converter.to_failure( err, - self._payload_converter, + 
self._workflow_context_payload_converter, command.respond_to_query.failed, ) except Exception as inner_err: @@ -754,6 +786,20 @@ def _apply_resolve_activity( handle = self._pending_activities.pop(job.seq, None) if not handle: raise RuntimeError(f"Failed finding activity handle for sequence {job.seq}") + activity_context = temporalio.converter.ActivitySerializationContext( + namespace=self._info.namespace, + workflow_id=self._info.workflow_id, + workflow_type=self._info.workflow_type, + activity_type=handle._input.activity, + activity_task_queue=( + handle._input.task_queue or self._info.task_queue + if isinstance(handle._input, StartActivityInput) + else self._info.task_queue + ), + is_local=isinstance(handle._input, StartLocalActivityInput), + ) + payload_converter = self._payload_converter_with_context(activity_context) + failure_converter = self._failure_converter_with_context(activity_context) if job.result.HasField("completed"): ret: Optional[Any] = None if job.result.completed.HasField("result"): @@ -761,19 +807,20 @@ def _apply_resolve_activity( ret_vals = self._convert_payloads( [job.result.completed.result], ret_types, + payload_converter, ) ret = ret_vals[0] handle._resolve_success(ret) elif job.result.HasField("failed"): handle._resolve_failure( - self._failure_converter.from_failure( - job.result.failed.failure, self._payload_converter + failure_converter.from_failure( + job.result.failed.failure, payload_converter ) ) elif job.result.HasField("cancelled"): handle._resolve_failure( - self._failure_converter.from_failure( - job.result.cancelled.failure, self._payload_converter + failure_converter.from_failure( + job.result.cancelled.failure, payload_converter ) ) elif job.result.HasField("backoff"): @@ -790,6 +837,7 @@ def _apply_resolve_child_workflow_execution( raise RuntimeError( f"Failed finding child workflow handle for sequence {job.seq}" ) + if job.result.HasField("completed"): ret: Optional[Any] = None if job.result.completed.HasField("result"): @@ -797,19 +845,20 @@ def _apply_resolve_child_workflow_execution( ret_vals = self._convert_payloads( [job.result.completed.result], ret_types, + handle._payload_converter, ) ret = ret_vals[0] handle._resolve_success(ret) elif job.result.HasField("failed"): handle._resolve_failure( - self._failure_converter.from_failure( - job.result.failed.failure, self._payload_converter + handle._failure_converter.from_failure( + job.result.failed.failure, handle._payload_converter ) ) elif job.result.HasField("cancelled"): handle._resolve_failure( - self._failure_converter.from_failure( - job.result.cancelled.failure, self._payload_converter + handle._failure_converter.from_failure( + job.result.cancelled.failure, handle._payload_converter ) ) else: @@ -846,8 +895,8 @@ def _apply_resolve_child_workflow_execution_start( elif job.HasField("cancelled"): self._pending_child_workflows.pop(job.seq) handle._resolve_failure( - self._failure_converter.from_failure( - job.cancelled.failure, self._payload_converter + handle._failure_converter.from_failure( + job.cancelled.failure, handle._payload_converter ) ) else: @@ -874,8 +923,8 @@ def _apply_resolve_nexus_operation_start( # The nexus operation start failed; no ResolveNexusOperation will follow. 
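# [Editor's note] Nexus failures are deliberately decoded with the handle's
# context-free converters; see the rationale in get_serialization_context
# later in this diff (the caller workflow's context is not available on the
# handler side).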
self._pending_nexus_operations.pop(job.seq, None) handle._resolve_failure( - self._failure_converter.from_failure( - job.failed, self._payload_converter + handle._failure_converter.from_failure( + job.failed, handle._payload_converter ) ) else: @@ -905,24 +954,25 @@ def _apply_resolve_nexus_operation( [output] = self._convert_payloads( [result.completed], [handle._input.output_type] if handle._input.output_type else None, + handle._payload_converter, ) handle._resolve_success(output) elif result.HasField("failed"): handle._resolve_failure( - self._failure_converter.from_failure( - result.failed, self._payload_converter + handle._failure_converter.from_failure( + result.failed, handle._payload_converter ) ) elif result.HasField("cancelled"): handle._resolve_failure( - self._failure_converter.from_failure( - result.cancelled, self._payload_converter + handle._failure_converter.from_failure( + result.cancelled, handle._payload_converter ) ) elif result.HasField("timed_out"): handle._resolve_failure( - self._failure_converter.from_failure( - result.timed_out, self._payload_converter + handle._failure_converter.from_failure( + result.timed_out, handle._payload_converter ) ) else: @@ -932,17 +982,22 @@ def _apply_resolve_request_cancel_external_workflow( self, job: temporalio.bridge.proto.workflow_activation.ResolveRequestCancelExternalWorkflow, ) -> None: - fut = self._pending_external_cancels.pop(job.seq, None) - if not fut: + pending = self._pending_external_cancels.pop(job.seq, None) + if not pending: raise RuntimeError( f"Failed finding pending external cancel for sequence {job.seq}" ) + fut, external_workflow_id = pending # We intentionally let this error if future is already done if job.HasField("failure"): + workflow_context = temporalio.converter.WorkflowSerializationContext( + namespace=self._info.namespace, + workflow_id=external_workflow_id, + ) + payload_converter = self._payload_converter_with_context(workflow_context) + failure_converter = self._failure_converter_with_context(workflow_context) fut.set_exception( - self._failure_converter.from_failure( - job.failure, self._payload_converter - ) + failure_converter.from_failure(job.failure, payload_converter) ) else: fut.set_result(None) @@ -951,17 +1006,22 @@ def _apply_resolve_signal_external_workflow( self, job: temporalio.bridge.proto.workflow_activation.ResolveSignalExternalWorkflow, ) -> None: - fut = self._pending_external_signals.pop(job.seq, None) - if not fut: + pending = self._pending_external_signals.pop(job.seq, None) + if not pending: raise RuntimeError( f"Failed finding pending external signal for sequence {job.seq}" ) + fut, external_workflow_id = pending # We intentionally let this error if future is already done if job.HasField("failure"): + workflow_context = temporalio.converter.WorkflowSerializationContext( + namespace=self._info.namespace, + workflow_id=external_workflow_id, + ) + payload_converter = self._payload_converter_with_context(workflow_context) + failure_converter = self._failure_converter_with_context(workflow_context) fut.set_exception( - self._failure_converter.from_failure( - job.failure, self._payload_converter - ) + failure_converter.from_failure(job.failure, payload_converter) ) else: fut.set_result(None) @@ -984,7 +1044,9 @@ def _apply_initialize_workflow( async def run_workflow(input: ExecuteWorkflowInput) -> None: try: result = await self._inbound.execute_workflow(input) - result_payloads = self._payload_converter.to_payloads([result]) + result_payloads = 
self._workflow_context_payload_converter.to_payloads( + [result] + ) if len(result_payloads) != 1: raise ValueError( f"Expected 1 result payload, got {len(result_payloads)}" @@ -1022,7 +1084,10 @@ def _make_workflow_input( if not self._defn.name: # Dynamic is just the raw value for each input value arg_types = [temporalio.common.RawValue] * len(init_job.arguments) - args = self._convert_payloads(init_job.arguments, arg_types) + + args = self._convert_payloads( + init_job.arguments, arg_types, self._workflow_context_payload_converter + ) # Put args in a list if dynamic if not self._defn.name: args = [args] @@ -1161,7 +1226,7 @@ def workflow_is_replaying(self) -> bool: def workflow_memo(self) -> Mapping[str, Any]: if self._untyped_converted_memo is None: self._untyped_converted_memo = { - k: self._payload_converter.from_payload(v) + k: self._workflow_context_payload_converter.from_payload(v) for k, v in self._info.raw_memo.items() } return self._untyped_converted_memo @@ -1174,7 +1239,7 @@ def workflow_memo_value( if default is temporalio.common._arg_unset: raise KeyError(f"Memo does not have a value for key {key}") return default - return self._payload_converter.from_payload( + return self._workflow_context_payload_converter.from_payload( payload, type_hint, # type: ignore[arg-type] ) @@ -1188,7 +1253,9 @@ def workflow_upsert_memo(self, updates: Mapping[str, Any]) -> None: # Intentionally not checking if memo exists, so that no-op removals show up in history too. removals.append(k) else: - update_payloads[k] = self._payload_converter.to_payload(v) + update_payloads[k] = ( + self._workflow_context_payload_converter.to_payload(v) + ) if not update_payloads and not removals: return @@ -1207,7 +1274,7 @@ def workflow_upsert_memo(self, updates: Mapping[str, Any]) -> None: mut_raw_memo[k] = v if removals: - null_payload = self._payload_converter.to_payload(None) + null_payload = self._workflow_context_payload_converter.to_payload(None) for k in removals: fields[k].CopyFrom(null_payload) mut_raw_memo.pop(k, None) @@ -1215,8 +1282,8 @@ def workflow_upsert_memo(self, updates: Mapping[str, Any]) -> None: # Keeping deserialized memo dict in sync, if exists if self._untyped_converted_memo is not None: for k, v in update_payloads.items(): - self._untyped_converted_memo[k] = self._payload_converter.from_payload( - v + self._untyped_converted_memo[k] = ( + self._workflow_context_payload_converter.from_payload(v) ) for k in removals: self._untyped_converted_memo.pop(k, None) @@ -1254,7 +1321,7 @@ def workflow_patch(self, id: str, *, deprecated: bool) -> bool: return use_patch def workflow_payload_converter(self) -> temporalio.converter.PayloadConverter: - return self._payload_converter + return self._workflow_context_payload_converter def workflow_random(self) -> random.Random: self._assert_not_read_only("random") @@ -1649,7 +1716,7 @@ async def workflow_sleep( ) -> None: user_metadata = ( temporalio.api.sdk.v1.UserMetadata( - summary=self._payload_converter.to_payload(summary) + summary=self._workflow_context_payload_converter.to_payload(summary) ) if summary else None @@ -1674,7 +1741,9 @@ async def workflow_wait_condition( self._conditions.append((fn, fut)) user_metadata = ( temporalio.api.sdk.v1.UserMetadata( - summary=self._payload_converter.to_payload(timeout_summary) + summary=self._workflow_context_payload_converter.to_payload( + timeout_summary + ) ) if timeout_summary else None @@ -1726,18 +1795,18 @@ def workflow_last_completion_result( return None if type_hint is None: - return 
self._payload_converter.from_payload( + return self._workflow_context_payload_converter.from_payload( self._last_completion_result.payloads[0] ) else: - return self._payload_converter.from_payload( + return self._workflow_context_payload_converter.from_payload( self._last_completion_result.payloads[0], type_hint ) def workflow_last_failure(self) -> Optional[BaseException]: if self._last_failure: - return self._failure_converter.from_failure( - self._last_failure, self._payload_converter + return self._workflow_context_failure_converter.from_failure( + self._last_failure, self._workflow_context_payload_converter ) return None @@ -1806,9 +1875,13 @@ async def run_activity() -> Any: async def _outbound_signal_child_workflow( self, input: SignalChildWorkflowInput ) -> None: - payloads = ( - self._payload_converter.to_payloads(input.args) if input.args else None + payload_converter = self._payload_converter_with_context( + temporalio.converter.WorkflowSerializationContext( + namespace=self._info.namespace, + workflow_id=input.child_workflow_id, + ) ) + payloads = payload_converter.to_payloads(input.args) if input.args else None command = self._add_command() v = command.signal_external_workflow_execution v.child_workflow_id = input.child_workflow_id @@ -1822,9 +1895,13 @@ async def _outbound_signal_child_workflow( async def _outbound_signal_external_workflow( self, input: SignalExternalWorkflowInput ) -> None: - payloads = ( - self._payload_converter.to_payloads(input.args) if input.args else None + payload_converter = self._payload_converter_with_context( + temporalio.converter.WorkflowSerializationContext( + namespace=input.namespace, + workflow_id=input.workflow_id, + ) ) + payloads = payload_converter.to_payloads(input.args) if input.args else None command = self._add_command() v = command.signal_external_workflow_execution v.workflow_execution.namespace = input.namespace @@ -1858,7 +1935,10 @@ def apply_child_cancel_error() -> None: # TODO(cretz): Nothing waits on this future, so how # if at all should we report child-workflow cancel # request failure? 
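# [Editor's note] The pending-cancel future is now stored together with the
# target workflow ID (input.id below) so that, when the resolution arrives,
# any failure can be decoded under that workflow's serialization context; see
# _apply_resolve_request_cancel_external_workflow earlier in this diff.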
- self._pending_external_cancels[cancel_seq] = self.create_future() + self._pending_external_cancels[cancel_seq] = ( + self.create_future(), + input.id, + ) # Function that runs in the handle async def run_child() -> Any: @@ -1964,8 +2044,9 @@ async def _cancel_external_workflow( done_fut = self.create_future() command.request_cancel_external_workflow_execution.seq = seq - # Set as pending - self._pending_external_cancels[seq] = done_fut + # Set as pending with the target workflow ID for later context use + target_workflow_id = command.request_cancel_external_workflow_execution.workflow_execution.workflow_id + self._pending_external_cancels[seq] = (done_fut, target_workflow_id) # Wait until done (there is no cancelling a cancel request) await done_fut @@ -1980,6 +2061,7 @@ def _convert_payloads( self, payloads: Sequence[temporalio.api.common.v1.Payload], types: Optional[List[Type]], + payload_converter: temporalio.converter.PayloadConverter, ) -> List[Any]: if not payloads: return [] @@ -1987,10 +2069,7 @@ def _convert_payloads( if types and len(types) != len(payloads): types = None try: - return self._payload_converter.from_payloads( - payloads, - type_hints=types, - ) + return payload_converter.from_payloads(payloads, type_hints=types) except temporalio.exceptions.FailureError: # Don't wrap payload conversion errors that would fail the workflow raise @@ -1999,6 +2078,110 @@ def _convert_payloads( raise raise RuntimeError("Failed decoding arguments") from err + def _payload_converter_with_context( + self, + context: temporalio.converter.SerializationContext, + ) -> temporalio.converter.PayloadConverter: + """Construct workflow payload converter with the given context. + + This plays a similar role to DataConverter._with_context, but operates on PayloadConverter + only (payload encoding/decoding is done by the worker, outside the workflow sandbox). + """ + payload_converter = self._context_free_payload_converter + if isinstance(payload_converter, temporalio.converter.WithSerializationContext): + payload_converter = payload_converter.with_context(context) + return payload_converter + + def _failure_converter_with_context( + self, + context: temporalio.converter.SerializationContext, + ) -> temporalio.converter.FailureConverter: + """Construct workflow failure converter with the given context. + + This plays a similar role to DataConverter._with_context, but operates on FailureConverter + only (payload encoding/decoding is done by the worker, outside the workflow sandbox). + """ + failure_converter = self._context_free_failure_converter + if isinstance(failure_converter, temporalio.converter.WithSerializationContext): + failure_converter = failure_converter.with_context(context) + return failure_converter + + def get_serialization_context( + self, + command_info: Optional[_command_aware_visitor.CommandInfo], + ) -> Optional[temporalio.converter.SerializationContext]: + if command_info is None: + # Use payload codec with workflow context by default (i.e. 
for payloads not associated + # with a pending command) + return temporalio.converter.WorkflowSerializationContext( + namespace=self._info.namespace, + workflow_id=self._info.workflow_id, + ) + + if ( + command_info.command_type + == temporalio.api.enums.v1.command_type_pb2.CommandType.COMMAND_TYPE_SCHEDULE_ACTIVITY_TASK + and command_info.command_seq in self._pending_activities + ): + # Use the activity's context + activity_handle = self._pending_activities[command_info.command_seq] + return temporalio.converter.ActivitySerializationContext( + namespace=self._info.namespace, + workflow_id=self._info.workflow_id, + workflow_type=self._info.workflow_type, + activity_type=activity_handle._input.activity, + activity_task_queue=( + activity_handle._input.task_queue + if isinstance(activity_handle._input, StartActivityInput) + and activity_handle._input.task_queue + else self._info.task_queue + ), + is_local=isinstance(activity_handle._input, StartLocalActivityInput), + ) + + elif ( + command_info.command_type + == temporalio.api.enums.v1.command_type_pb2.CommandType.COMMAND_TYPE_START_CHILD_WORKFLOW_EXECUTION + and command_info.command_seq in self._pending_child_workflows + ): + # Use the child workflow's context + child_wf_handle = self._pending_child_workflows[command_info.command_seq] + return temporalio.converter.WorkflowSerializationContext( + namespace=self._info.namespace, + workflow_id=child_wf_handle._input.id, + ) + + elif ( + command_info.command_type + == temporalio.api.enums.v1.command_type_pb2.CommandType.COMMAND_TYPE_SIGNAL_EXTERNAL_WORKFLOW_EXECUTION + and command_info.command_seq in self._pending_external_signals + ): + # Use the target workflow's context + _, target_workflow_id = self._pending_external_signals[ + command_info.command_seq + ] + return temporalio.converter.WorkflowSerializationContext( + namespace=self._info.namespace, + workflow_id=target_workflow_id, + ) + + elif ( + command_info.command_type + == temporalio.api.enums.v1.command_type_pb2.CommandType.COMMAND_TYPE_SCHEDULE_NEXUS_OPERATION + and command_info.command_seq in self._pending_nexus_operations + ): + # Use empty context for nexus operations: users will never want to encrypt using a + # key derived from caller workflow context because the caller workflow context is + # not available on the handler side for decryption. + return None + + else: + # Use payload codec with workflow context for all other payloads + return temporalio.converter.WorkflowSerializationContext( + namespace=self._info.namespace, + workflow_id=self._info.workflow_id, + ) + def _instantiate_workflow_object(self) -> Any: if not self._workflow_input: raise RuntimeError("Expected workflow input. 
This is a Python SDK bug.") @@ -2082,15 +2265,21 @@ def _process_handler_args( if not defn_name and defn_dynamic_vararg: # Take off the string type hint for conversion arg_types = defn_arg_types[1:] if defn_arg_types else None - return [job_name] + self._convert_payloads(job_input, arg_types) + return [job_name] + self._convert_payloads( + job_input, arg_types, self._workflow_context_payload_converter + ) if not defn_name: return [ job_name, self._convert_payloads( - job_input, [temporalio.common.RawValue] * len(job_input) + job_input, + [temporalio.common.RawValue] * len(job_input), + self._workflow_context_payload_converter, ), ] - return self._convert_payloads(job_input, defn_arg_types) + return self._convert_payloads( + job_input, defn_arg_types, self._workflow_context_payload_converter + ) def _process_signal_job( self, @@ -2243,7 +2432,9 @@ def _set_workflow_failure(self, err: BaseException) -> None: failure = self._add_command().fail_workflow_execution.failure failure.SetInParent() try: - self._failure_converter.to_failure(err, self._payload_converter, failure) + self._workflow_context_failure_converter.to_failure( + err, self._workflow_context_payload_converter, failure + ) except Exception as inner_err: raise ValueError("Failed converting workflow exception") from inner_err @@ -2256,8 +2447,11 @@ async def _signal_external_workflow( done_fut = self.create_future() command.signal_external_workflow_execution.seq = seq - # Set as pending - self._pending_external_signals[seq] = done_fut + target_workflow_id = ( + command.signal_external_workflow_execution.child_workflow_id + or command.signal_external_workflow_execution.workflow_execution.workflow_id + ) + self._pending_external_signals[seq] = (done_fut, target_workflow_id) # Wait until completed or cancelled while True: @@ -2724,6 +2918,20 @@ def __init__( self._result_fut = instance.create_future() self._started = False instance._register_task(self, name=f"activity: {input.activity}") + self._payload_converter = self._instance._payload_converter_with_context( + temporalio.converter.ActivitySerializationContext( + namespace=self._instance._info.namespace, + workflow_id=self._instance._info.workflow_id, + workflow_type=self._instance._info.workflow_type, + activity_type=self._input.activity, + activity_task_queue=( + self._input.task_queue or self._instance._info.task_queue + if isinstance(self._input, StartActivityInput) + else self._instance._info.task_queue + ), + is_local=isinstance(self._input, StartLocalActivityInput), + ) + ) def cancel(self, msg: Optional[Any] = None) -> bool: # Allow the cancel to go through for the task even if we're deleting, @@ -2771,7 +2979,7 @@ def _apply_schedule_command( ) -> None: # Convert arguments before creating command in case it raises error payloads = ( - self._instance._payload_converter.to_payloads(self._input.args) + self._payload_converter.to_payloads(self._input.args) if self._input.args else None ) @@ -2807,7 +3015,7 @@ def _apply_schedule_command( self._input.retry_policy.apply_to_proto(v.retry_policy) if self._input.summary: command.user_metadata.summary.CopyFrom( - self._instance._payload_converter.to_payload(self._input.summary) + self._payload_converter.to_payload(self._input.summary) ) v.cancellation_type = cast( temporalio.bridge.proto.workflow_commands.ActivityCancellationType.ValueType, @@ -2871,6 +3079,16 @@ def __init__( self._result_fut: asyncio.Future[Any] = instance.create_future() self._first_execution_run_id = "" instance._register_task(self, name=f"child: 
{input.workflow}") + workflow_context = temporalio.converter.WorkflowSerializationContext( + namespace=self._instance._info.namespace, + workflow_id=self._input.id, + ) + self._payload_converter = self._instance._payload_converter_with_context( + workflow_context + ) + self._failure_converter = self._instance._failure_converter_with_context( + workflow_context + ) @property def id(self) -> str: @@ -2921,7 +3139,7 @@ def _resolve_failure(self, err: BaseException) -> None: def _apply_start_command(self) -> None: # Convert arguments before creating command in case it raises error payloads = ( - self._instance._payload_converter.to_payloads(self._input.args) + self._payload_converter.to_payloads(self._input.args) if self._input.args else None ) @@ -2956,9 +3174,7 @@ def _apply_start_command(self) -> None: temporalio.common._apply_headers(self._input.headers, v.headers) if self._input.memo: for k, val in self._input.memo.items(): - v.memo[k].CopyFrom( - self._instance._payload_converter.to_payloads([val])[0] - ) + v.memo[k].CopyFrom(self._payload_converter.to_payloads([val])[0]) if self._input.search_attributes: _encode_search_attributes( self._input.search_attributes, v.search_attributes @@ -2971,11 +3187,11 @@ def _apply_start_command(self) -> None: v.versioning_intent = self._input.versioning_intent._to_proto() if self._input.static_summary: command.user_metadata.summary.CopyFrom( - self._instance._payload_converter.to_payload(self._input.static_summary) + self._payload_converter.to_payload(self._input.static_summary) ) if self._input.static_details: command.user_metadata.details.CopyFrom( - self._instance._payload_converter.to_payload(self._input.static_details) + self._payload_converter.to_payload(self._input.static_details) ) if self._input.priority: v.priority.CopyFrom(self._input.priority._to_proto()) @@ -3057,6 +3273,8 @@ def __init__( self._task = asyncio.Task(fn) self._start_fut: asyncio.Future[Optional[str]] = instance.create_future() self._result_fut: asyncio.Future[Optional[OutputT]] = instance.create_future() + self._payload_converter = self._instance._context_free_payload_converter + self._failure_converter = self._instance._context_free_failure_converter @property def operation_token(self) -> Optional[str]: @@ -3089,7 +3307,7 @@ def _resolve_failure(self, err: BaseException) -> None: self._result_fut.set_result(None) def _apply_schedule_command(self) -> None: - payload = self._instance._payload_converter.to_payload(self._input.input) + payload = self._payload_converter.to_payload(self._input.input) command = self._instance._add_command() v = command.schedule_nexus_operation v.seq = self._seq @@ -3128,13 +3346,17 @@ def __init__( def _apply_command(self) -> None: # Convert arguments before creating command in case it raises error payloads = ( - self._instance._payload_converter.to_payloads(self._input.args) + self._instance._workflow_context_payload_converter.to_payloads( + self._input.args + ) if self._input.args else None ) memo_payloads = ( { - k: self._instance._payload_converter.to_payloads([val])[0] + k: self._instance._workflow_context_payload_converter.to_payloads( + [val] + )[0] for k, val in self._input.memo.items() } if self._input.memo diff --git a/temporalio/worker/workflow_sandbox/_in_sandbox.py b/temporalio/worker/workflow_sandbox/_in_sandbox.py index 3091cef1d..17a5c5742 100644 --- a/temporalio/worker/workflow_sandbox/_in_sandbox.py +++ b/temporalio/worker/workflow_sandbox/_in_sandbox.py @@ -6,12 +6,14 @@ import dataclasses import logging -from typing import Any, 
Type +from typing import Any, Optional, Type import temporalio.bridge.proto.workflow_activation import temporalio.bridge.proto.workflow_completion +import temporalio.converter import temporalio.worker._workflow_instance import temporalio.workflow +from temporalio.worker import _command_aware_visitor logger = logging.getLogger(__name__) @@ -79,3 +81,10 @@ def activate( ) -> temporalio.bridge.proto.workflow_completion.WorkflowActivationCompletion: """Send activation to this instance.""" return self.instance.activate(act) + + def get_serialization_context( + self, + command_info: Optional[_command_aware_visitor.CommandInfo], + ) -> Optional[temporalio.converter.SerializationContext]: + """Get serialization context.""" + return self.instance.get_serialization_context(command_info) diff --git a/temporalio/worker/workflow_sandbox/_runner.py b/temporalio/worker/workflow_sandbox/_runner.py index c656e3041..e1a48871d 100644 --- a/temporalio/worker/workflow_sandbox/_runner.py +++ b/temporalio/worker/workflow_sandbox/_runner.py @@ -15,8 +15,8 @@ import temporalio.bridge.proto.workflow_completion import temporalio.common import temporalio.converter -import temporalio.worker._workflow_instance import temporalio.workflow +from temporalio.worker import _command_aware_visitor from ...api.common.v1.message_pb2 import Payloads from ...api.failure.v1.message_pb2 import Failure @@ -185,3 +185,20 @@ def _run_code(self, code: str, **extra_globals: Any) -> None: def get_thread_id(self) -> Optional[int]: return self._current_thread_id + + def get_serialization_context( + self, + command_info: Optional[_command_aware_visitor.CommandInfo], + ) -> Optional[temporalio.converter.SerializationContext]: + # Forward call to the sandboxed instance + self.importer.restriction_context.is_runtime = True + try: + self._run_code( + "with __temporal_importer.applied():\n" + " __temporal_context = __temporal_in_sandbox.get_serialization_context(__temporal_command_info)\n", + __temporal_importer=self.importer, + __temporal_command_info=command_info, + ) + return self.globals_and_locals.pop("__temporal_context", None) # type: ignore + finally: + self.importer.restriction_context.is_runtime = False diff --git a/temporalio/workflow.py b/temporalio/workflow.py index 98b45e367..e5bb25ba6 100644 --- a/temporalio/workflow.py +++ b/temporalio/workflow.py @@ -61,7 +61,6 @@ import temporalio.workflow from temporalio.nexus._util import ServiceHandlerT -from .api.failure.v1.message_pb2 import Failure from .types import ( AnyType, CallableAsyncNoParam, @@ -1148,6 +1147,7 @@ def patched(id: str) -> bool: def payload_converter() -> temporalio.converter.PayloadConverter: """Get the payload converter for the current workflow. + The returned converter has :py:class:`temporalio.converter.WorkflowSerializationContext` set. This is often used for dynamic workflows/signals/queries to convert payloads. """ diff --git a/tests/test_serialization_context.py b/tests/test_serialization_context.py new file mode 100644 index 000000000..ee7be8684 --- /dev/null +++ b/tests/test_serialization_context.py @@ -0,0 +1,1900 @@ +""" +Test context-aware serde/codec operations. + +Serialization context should be available on all serde/codec operations, but testing all of them is +infeasible; this test suite only covers a selection. 
+""" + +from __future__ import annotations + +import asyncio +import dataclasses +import json +import uuid +from collections import defaultdict +from dataclasses import dataclass, field +from datetime import timedelta +from typing import Any, List, Literal, Optional, Sequence, Type + +import nexusrpc +import pytest +from pydantic import BaseModel +from typing_extensions import Never + +import temporalio.api.common.v1 +import temporalio.api.failure.v1 +from temporalio import activity, workflow +from temporalio.client import ( + AsyncActivityHandle, + Client, + WorkflowFailureError, + WorkflowUpdateFailedError, +) +from temporalio.common import RetryPolicy +from temporalio.contrib.pydantic import PydanticJSONPlainPayloadConverter +from temporalio.converter import ( + ActivitySerializationContext, + CompositePayloadConverter, + DataConverter, + DefaultFailureConverter, + DefaultPayloadConverter, + EncodingPayloadConverter, + JSONPlainPayloadConverter, + PayloadCodec, + PayloadConverter, + SerializationContext, + WithSerializationContext, + WorkflowSerializationContext, +) +from temporalio.exceptions import ApplicationError +from temporalio.testing import WorkflowEnvironment +from temporalio.worker import Worker +from temporalio.worker._workflow_instance import UnsandboxedWorkflowRunner +from tests.helpers.nexus import create_nexus_endpoint, make_nexus_endpoint_name + + +@dataclass +class TraceItem: + method: Literal[ + "to_payload", + "from_payload", + "to_failure", + "from_failure", + "encode", + "decode", + ] + context: dict[str, Any] + + +@dataclass +class TraceData: + items: list[TraceItem] = field(default_factory=list) + + +class SerializationContextPayloadConverter( + EncodingPayloadConverter, WithSerializationContext +): + def __init__(self): + self.context: Optional[SerializationContext] = None + + @property + def encoding(self) -> str: + return "test-serialization-context" + + def with_context( + self, context: Optional[SerializationContext] + ) -> SerializationContextPayloadConverter: + converter = SerializationContextPayloadConverter() + converter.context = context + return converter + + def to_payload(self, value: Any) -> Optional[temporalio.api.common.v1.Payload]: + if not isinstance(value, TraceData): + return None + if isinstance(self.context, WorkflowSerializationContext): + value.items.append( + TraceItem( + method="to_payload", + context=dataclasses.asdict(self.context), + ) + ) + elif isinstance(self.context, ActivitySerializationContext): + value.items.append( + TraceItem( + method="to_payload", + context=dataclasses.asdict(self.context), + ) + ) + else: + raise Exception(f"Unexpected context type: {type(self.context)}") + payload = JSONPlainPayloadConverter().to_payload(value) + assert payload + payload.metadata["encoding"] = self.encoding.encode() + return payload + + def from_payload( + self, + payload: temporalio.api.common.v1.Payload, + type_hint: Optional[Type] = None, + ) -> Any: + value = JSONPlainPayloadConverter().from_payload(payload, TraceData) + assert isinstance(value, TraceData) + if isinstance(self.context, WorkflowSerializationContext): + value.items.append( + TraceItem( + method="from_payload", + context=dataclasses.asdict(self.context), + ) + ) + elif isinstance(self.context, ActivitySerializationContext): + value.items.append( + TraceItem( + method="from_payload", + context=dataclasses.asdict(self.context), + ) + ) + else: + raise Exception(f"Unexpected context type: {type(self.context)}") + return value + + +class 
SerializationContextCompositePayloadConverter( + CompositePayloadConverter, WithSerializationContext +): + def __init__(self): + super().__init__( + SerializationContextPayloadConverter(), + *DefaultPayloadConverter.default_encoding_payload_converters, + ) + + +# Payload conversion tests + +## Misc payload conversion + + +@activity.defn +async def passthrough_activity(input: TraceData) -> TraceData: + activity.payload_converter().to_payload(input) + activity.heartbeat(input) + # Wait for the heartbeat to be processed so that it modifies the data before the activity returns + await asyncio.sleep(0.2) + return input + + +@workflow.defn +class EchoWorkflow: + @workflow.run + async def run(self, data: TraceData) -> TraceData: + return data + + +@workflow.defn +class PayloadConversionWorkflow: + @workflow.run + async def run(self, data: TraceData) -> TraceData: + workflow.payload_converter().to_payload(data) + data = await workflow.execute_activity( + passthrough_activity, + data, + start_to_close_timeout=timedelta(seconds=10), + heartbeat_timeout=timedelta(seconds=2), + ) + data = await workflow.execute_child_workflow( + EchoWorkflow.run, data, id=f"{workflow.info().workflow_id}_child" + ) + return data + + +async def test_payload_conversion_calls_follow_expected_sequence_and_contexts( + client: Client, +): + workflow_id = str(uuid.uuid4()) + task_queue = str(uuid.uuid4()) + + config = client.config() + config["data_converter"] = dataclasses.replace( + DataConverter.default, + payload_converter_class=SerializationContextCompositePayloadConverter, + ) + client = Client(**config) + + async with Worker( + client, + task_queue=task_queue, + workflows=[PayloadConversionWorkflow, EchoWorkflow], + activities=[passthrough_activity], + workflow_runner=UnsandboxedWorkflowRunner(), # so that we can use isinstance + ): + result = await client.execute_workflow( + PayloadConversionWorkflow.run, + TraceData(), + id=workflow_id, + task_queue=task_queue, + ) + + workflow_context = dataclasses.asdict( + WorkflowSerializationContext( + namespace="default", + workflow_id=workflow_id, + ) + ) + child_workflow_context = dataclasses.asdict( + WorkflowSerializationContext( + namespace="default", + workflow_id=f"{workflow_id}_child", + ) + ) + activity_context = dataclasses.asdict( + ActivitySerializationContext( + namespace="default", + workflow_id=workflow_id, + workflow_type=PayloadConversionWorkflow.__name__, + activity_type=passthrough_activity.__name__, + activity_task_queue=task_queue, + is_local=False, + ) + ) + assert result.items == [ + TraceItem( + method="to_payload", + context=workflow_context, # Outbound workflow input + ), + TraceItem( + method="from_payload", + context=workflow_context, # Inbound workflow input + ), + TraceItem( + method="to_payload", + context=workflow_context, # workflow payload converter + ), + TraceItem( + method="to_payload", + context=activity_context, # Outbound activity input + ), + TraceItem( + method="from_payload", + context=activity_context, # Inbound activity input + ), + TraceItem( + method="to_payload", + context=activity_context, # activity payload converter + ), + TraceItem( + method="to_payload", + context=activity_context, # Outbound heartbeat + ), + TraceItem( + method="to_payload", + context=activity_context, # Outbound activity result + ), + TraceItem( + method="from_payload", + context=activity_context, # Inbound activity result + ), + TraceItem( + method="to_payload", + context=child_workflow_context, # Outbound child workflow input + ), + TraceItem( + 
method="from_payload", + context=child_workflow_context, # Inbound child workflow input + ), + TraceItem( + method="to_payload", + context=child_workflow_context, # Outbound child workflow result + ), + TraceItem( + method="from_payload", + context=child_workflow_context, # Inbound child workflow result + ), + TraceItem( + method="to_payload", + context=workflow_context, # Outbound workflow result + ), + TraceItem( + method="from_payload", + context=workflow_context, # Inbound workflow result + ), + ] + + +## Activity heartbeat payload conversion + + +@activity.defn +async def activity_with_heartbeat_details() -> TraceData: + info = activity.info() + if info.attempt == 1: + data = TraceData() + activity.heartbeat(data) + raise Exception("Intentional error to force retry") + elif info.attempt == 2: + [heartbeat_data] = info.heartbeat_details + assert isinstance(heartbeat_data, TraceData) + return heartbeat_data + else: + raise AssertionError(f"Unexpected attempt number: {info.attempt}") + + +@workflow.defn +class HeartbeatDetailsSerializationContextTestWorkflow: + @workflow.run + async def run(self) -> TraceData: + return await workflow.execute_activity( + activity_with_heartbeat_details, + start_to_close_timeout=timedelta(seconds=10), + retry_policy=RetryPolicy( + initial_interval=timedelta(milliseconds=100), + maximum_attempts=2, + ), + ) + + +async def test_heartbeat_details_payload_conversion(client: Client): + """Test that heartbeat details are decoded with activity context.""" + workflow_id = str(uuid.uuid4()) + task_queue = str(uuid.uuid4()) + + config = client.config() + config["data_converter"] = dataclasses.replace( + DataConverter.default, + payload_converter_class=SerializationContextCompositePayloadConverter, + ) + + client = Client(**config) + + async with Worker( + client, + task_queue=task_queue, + workflows=[HeartbeatDetailsSerializationContextTestWorkflow], + activities=[activity_with_heartbeat_details], + workflow_runner=UnsandboxedWorkflowRunner(), # so that we can use isinstance + ): + result = await client.execute_workflow( + HeartbeatDetailsSerializationContextTestWorkflow.run, + id=workflow_id, + task_queue=task_queue, + ) + + workflow_context = dataclasses.asdict( + WorkflowSerializationContext( + namespace="default", + workflow_id=workflow_id, + ) + ) + + activity_context = dataclasses.asdict( + ActivitySerializationContext( + namespace="default", + workflow_id=workflow_id, + workflow_type=HeartbeatDetailsSerializationContextTestWorkflow.__name__, + activity_type=activity_with_heartbeat_details.__name__, + activity_task_queue=task_queue, + is_local=False, + ) + ) + + assert result.items == [ + TraceItem( + method="to_payload", + context=activity_context, # Outbound heartbeat + ), + TraceItem( + method="from_payload", + context=activity_context, # Inbound heartbeart detail + ), + TraceItem( + method="to_payload", + context=activity_context, # Outbound activity result + ), + TraceItem( + method="from_payload", + context=activity_context, # Inbound activity result + ), + TraceItem( + method="to_payload", + context=workflow_context, # Outbound workflow result + ), + TraceItem( + method="from_payload", + context=workflow_context, # Inbound workflow result + ), + ] + + +## Local activity payload conversion + + +@activity.defn +async def local_activity(input: TraceData) -> TraceData: + return input + + +@workflow.defn +class LocalActivityWorkflow: + @workflow.run + async def run(self, data: TraceData) -> TraceData: + return await workflow.execute_local_activity( + 
local_activity, + data, + start_to_close_timeout=timedelta(seconds=10), + ) + + +async def test_local_activity_payload_conversion(client: Client): + workflow_id = str(uuid.uuid4()) + task_queue = str(uuid.uuid4()) + + config = client.config() + config["data_converter"] = dataclasses.replace( + DataConverter.default, + payload_converter_class=SerializationContextCompositePayloadConverter, + ) + client = Client(**config) + + async with Worker( + client, + task_queue=task_queue, + workflows=[LocalActivityWorkflow], + activities=[local_activity], + workflow_runner=UnsandboxedWorkflowRunner(), # so that we can use isinstance + ): + result = await client.execute_workflow( + LocalActivityWorkflow.run, + TraceData(), + id=workflow_id, + task_queue=task_queue, + ) + + workflow_context = dataclasses.asdict( + WorkflowSerializationContext( + namespace="default", + workflow_id=workflow_id, + ) + ) + local_activity_context = dataclasses.asdict( + ActivitySerializationContext( + namespace="default", + workflow_id=workflow_id, + workflow_type=LocalActivityWorkflow.__name__, + activity_type=local_activity.__name__, + activity_task_queue=task_queue, + is_local=True, + ) + ) + + assert result.items == [ + TraceItem( + method="to_payload", + context=workflow_context, # Outbound workflow input + ), + TraceItem( + method="from_payload", + context=workflow_context, # Inbound workflow input + ), + TraceItem( + method="to_payload", + context=local_activity_context, # Outbound local activity input + ), + TraceItem( + method="from_payload", + context=local_activity_context, # Inbound local activity input + ), + TraceItem( + method="to_payload", + context=local_activity_context, # Outbound local activity result + ), + TraceItem( + method="from_payload", + context=local_activity_context, # Inbound local activity result + ), + TraceItem( + method="to_payload", + context=workflow_context, # Outbound workflow result + ), + TraceItem( + method="from_payload", + context=workflow_context, # Inbound workflow result + ), + ] + + +## Async activity completion payload conversion + + +@workflow.defn +class EventWorkflow: + # Like a global asyncio.Event() + + def __init__(self) -> None: + self.signal_received = asyncio.Event() + + @workflow.run + async def run(self) -> None: + await self.signal_received.wait() + + @workflow.signal + def signal(self) -> None: + self.signal_received.set() + + +@activity.defn +async def async_activity() -> TraceData: + await ( + activity.client() + .get_workflow_handle("activity-started-wf-id") + .signal(EventWorkflow.signal) + ) + activity.raise_complete_async() + + +@workflow.defn +class AsyncActivityCompletionSerializationContextTestWorkflow: + @workflow.run + async def run(self) -> TraceData: + return await workflow.execute_activity( + async_activity, + start_to_close_timeout=timedelta(seconds=10), + activity_id="async-activity-id", + ) + + +async def test_async_activity_completion_payload_conversion( + client: Client, +): + workflow_id = str(uuid.uuid4()) + task_queue = str(uuid.uuid4()) + + config = client.config() + config["data_converter"] = dataclasses.replace( + DataConverter.default, + payload_converter_class=SerializationContextCompositePayloadConverter, + ) + client = Client(**config) + + async with Worker( + client, + task_queue=task_queue, + workflows=[ + AsyncActivityCompletionSerializationContextTestWorkflow, + EventWorkflow, + ], + activities=[async_activity], + workflow_runner=UnsandboxedWorkflowRunner(), # so that we can use isinstance + ): + workflow_context = 
WorkflowSerializationContext(
+            namespace="default",
+            workflow_id=workflow_id,
+        )
+        activity_context = ActivitySerializationContext(
+            namespace="default",
+            workflow_id=workflow_id,
+            workflow_type=AsyncActivityCompletionSerializationContextTestWorkflow.__name__,
+            activity_type=async_activity.__name__,
+            activity_task_queue=task_queue,
+            is_local=False,
+        )
+
+        act_started_wf_handle = await client.start_workflow(
+            EventWorkflow.run,
+            id="activity-started-wf-id",
+            task_queue=task_queue,
+        )
+        wf_handle = await client.start_workflow(
+            AsyncActivityCompletionSerializationContextTestWorkflow.run,
+            id=workflow_id,
+            task_queue=task_queue,
+        )
+        activity_handle = client.get_async_activity_handle(
+            workflow_id=workflow_id,
+            run_id=wf_handle.first_execution_run_id,
+            activity_id="async-activity-id",
+        ).with_context(activity_context)
+
+        await act_started_wf_handle.result()
+        data = TraceData()
+        await activity_handle.heartbeat(data)
+        await activity_handle.complete(data)
+        result = await wf_handle.result()
+
+    activity_context_dict = dataclasses.asdict(activity_context)
+    workflow_context_dict = dataclasses.asdict(workflow_context)
+
+    assert result.items == [
+        TraceItem(
+            method="to_payload",
+            context=activity_context_dict,  # Outbound activity heartbeat
+        ),
+        TraceItem(
+            method="to_payload",
+            context=activity_context_dict,  # Outbound activity completion
+        ),
+        TraceItem(
+            method="from_payload",
+            context=activity_context_dict,  # Inbound activity result
+        ),
+        TraceItem(
+            method="to_payload",
+            context=workflow_context_dict,  # Outbound workflow result
+        ),
+        TraceItem(
+            method="from_payload",
+            context=workflow_context_dict,  # Inbound workflow result
+        ),
+    ]
+
+
+class MyAsyncActivityHandle(AsyncActivityHandle):
+    def my_method(self) -> None:
+        pass
+
+
+class MyAsyncActivityHandleWithOverriddenConstructor(AsyncActivityHandle):
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
+        super().__init__(*args, **kwargs)
+
+    def my_method(self) -> None:
+        pass
+
+
+def test_subclassed_async_activity_handle(client: Client):
+    activity_context = ActivitySerializationContext(
+        namespace="default",
+        workflow_id="workflow-id",
+        workflow_type="workflow-type",
+        activity_type="activity-type",
+        activity_task_queue="activity-task-queue",
+        is_local=False,
+    )
+    handle = MyAsyncActivityHandle(client=client, id_or_token=b"task-token")
+    # This works because the data converter does not use context, so
+    # AsyncActivityHandle.with_context returns self
+    assert isinstance(handle.with_context(activity_context), MyAsyncActivityHandle)
+
+    # This time the data converter uses context so AsyncActivityHandle.with_context attempts to
+    # return a new instance of the user's subclass. It works because they have not overridden the
+    # constructor.
+    client_config = client.config()
+    client_config["data_converter"] = dataclasses.replace(
+        DataConverter.default,
+        payload_converter_class=SerializationContextCompositePayloadConverter,
+    )
+    client = Client(**client_config)
+    handle = MyAsyncActivityHandle(client=client, id_or_token=b"task-token")
+    assert isinstance(handle.with_context(activity_context), MyAsyncActivityHandle)
+
+    # Finally, a user attempts the same, but with an overridden constructor. This fails:
+    # AsyncActivityHandle.with_context refuses to attempt to create an instance of their subclass.
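    # [Editor's sketch, not part of this diff] Such a subclass must override
    # with_context itself, re-supplying its own constructor arguments, e.g.:
    #
    #     def with_context(self, context):
    #         return MyAsyncActivityHandleWithOverriddenConstructor(
    #             client=self._client,
    #             id_or_token=self._id_or_token,
    #             data_converter_override=self._client.data_converter.with_context(context),
    #         )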
+ handle2 = MyAsyncActivityHandleWithOverriddenConstructor( + client=client, id_or_token=b"task-token" + ) + with pytest.raises( + TypeError, + match="you must override with_context to return an instance of your class", + ): + assert isinstance( + handle2.with_context(activity_context), + MyAsyncActivityHandleWithOverriddenConstructor, + ) + + +# Signal test + + +@workflow.defn(sandboxed=False) # so that we can use isinstance +class SignalSerializationContextTestWorkflow: + def __init__(self) -> None: + self.signal_received: Optional[TraceData] = None + + @workflow.run + async def run(self) -> TraceData: + await workflow.wait_condition(lambda: self.signal_received is not None) + assert self.signal_received is not None + return self.signal_received + + @workflow.signal + async def my_signal(self, data: TraceData) -> None: + self.signal_received = data + + +async def test_signal_payload_conversion( + client: Client, +): + workflow_id = str(uuid.uuid4()) + task_queue = str(uuid.uuid4()) + + config = client.config() + config["data_converter"] = dataclasses.replace( + DataConverter.default, + payload_converter_class=SerializationContextCompositePayloadConverter, + ) + + custom_client = Client(**config) + + async with Worker( + custom_client, + task_queue=task_queue, + workflows=[SignalSerializationContextTestWorkflow], + activities=[], + workflow_runner=UnsandboxedWorkflowRunner(), # so that we can use isinstance + ): + handle = await custom_client.start_workflow( + SignalSerializationContextTestWorkflow.run, + id=workflow_id, + task_queue=task_queue, + ) + await handle.signal( + SignalSerializationContextTestWorkflow.my_signal, + TraceData(), + ) + result = await handle.result() + + workflow_context = dataclasses.asdict( + WorkflowSerializationContext( + namespace="default", + workflow_id=workflow_id, + ) + ) + assert result.items == [ + TraceItem( + method="to_payload", + context=workflow_context, # Outbound signal input + ), + TraceItem( + method="from_payload", + context=workflow_context, # Inbound signal input + ), + TraceItem( + method="to_payload", + context=workflow_context, # Outbound workflow result + ), + TraceItem( + method="from_payload", + context=workflow_context, # Inbound workflow result + ), + ] + + +# Query test + + +@workflow.defn +class QuerySerializationContextTestWorkflow: + @workflow.run + async def run(self) -> None: + await asyncio.Event().wait() + + @workflow.query + def my_query(self, input: TraceData) -> TraceData: + return input + + +async def test_query_payload_conversion( + client: Client, +): + workflow_id = str(uuid.uuid4()) + task_queue = str(uuid.uuid4()) + + config = client.config() + config["data_converter"] = dataclasses.replace( + DataConverter.default, + payload_converter_class=SerializationContextCompositePayloadConverter, + ) + custom_client = Client(**config) + + async with Worker( + custom_client, + task_queue=task_queue, + workflows=[QuerySerializationContextTestWorkflow], + activities=[], + workflow_runner=UnsandboxedWorkflowRunner(), # so that we can use isinstance + ): + handle = await custom_client.start_workflow( + QuerySerializationContextTestWorkflow.run, + id=workflow_id, + task_queue=task_queue, + ) + result = await handle.query( + QuerySerializationContextTestWorkflow.my_query, TraceData() + ) + + workflow_context = dataclasses.asdict( + WorkflowSerializationContext( + namespace="default", + workflow_id=workflow_id, + ) + ) + assert result.items == [ + TraceItem( + method="to_payload", + context=workflow_context, # Outbound query input + 
), + TraceItem( + method="from_payload", + context=workflow_context, # Inbound query input + ), + TraceItem( + method="to_payload", + context=workflow_context, # Outbound query result + ), + TraceItem( + method="from_payload", + context=workflow_context, # Inbound query result + ), + ] + + +# Update test + + +@workflow.defn +class UpdateSerializationContextTestWorkflow: + @workflow.init + def __init__(self, pass_validation: bool) -> None: + self.pass_validation = pass_validation + self.input: Optional[TraceData] = None + + @workflow.run + async def run(self, pass_validation: bool) -> TraceData: + await workflow.wait_condition(lambda: self.input is not None) + assert self.input + return self.input + + @workflow.update + def my_update(self, input: TraceData) -> TraceData: + return input + + @my_update.validator + def my_update_validator(self, input: TraceData) -> None: + self.input = input # for test purposes; update validators should not mutate workflow state + if not self.pass_validation: + raise ValueError("Rejected") + + +@pytest.mark.parametrize("pass_validation", [True, False]) +async def test_update_payload_conversion( + client: Client, + pass_validation: bool, +): + workflow_id = str(uuid.uuid4()) + task_queue = str(uuid.uuid4()) + + config = client.config() + config["data_converter"] = dataclasses.replace( + DataConverter.default, + payload_converter_class=SerializationContextCompositePayloadConverter, + ) + custom_client = Client(**config) + + async with Worker( + custom_client, + task_queue=task_queue, + workflows=[UpdateSerializationContextTestWorkflow], + activities=[], + workflow_runner=UnsandboxedWorkflowRunner(), # so that we can use isinstance + ): + wf_handle = await custom_client.start_workflow( + UpdateSerializationContextTestWorkflow.run, + pass_validation, + id=workflow_id, + task_queue=task_queue, + ) + if pass_validation: + result = await wf_handle.execute_update( + UpdateSerializationContextTestWorkflow.my_update, TraceData() + ) + else: + try: + await wf_handle.execute_update( + UpdateSerializationContextTestWorkflow.my_update, TraceData() + ) + raise AssertionError("Expected WorkflowUpdateFailedError") + except WorkflowUpdateFailedError: + pass + + result = await wf_handle.result() + + workflow_context = dataclasses.asdict( + WorkflowSerializationContext( + namespace="default", + workflow_id=workflow_id, + ) + ) + assert result.items == [ + TraceItem( + method="to_payload", + context=workflow_context, # Outbound update input + ), + TraceItem( + method="from_payload", + context=workflow_context, # Inbound update input + ), + TraceItem( + method="to_payload", + context=workflow_context, # Outbound update/workflow result + ), + TraceItem( + method="from_payload", + context=workflow_context, # Inbound update/workflow result + ), + ] + + +# External workflow test + + +@workflow.defn +class ExternalWorkflowTarget: + def __init__(self) -> None: + self.signal_received: Optional[TraceData] = None + + @workflow.run + async def run(self) -> TraceData: + try: + await workflow.wait_condition(lambda: self.signal_received is not None) + return self.signal_received or TraceData() + except asyncio.CancelledError: + return TraceData() + + @workflow.signal + async def external_signal(self, data: TraceData) -> None: + self.signal_received = data + + +@workflow.defn +class ExternalWorkflowSignaler: + @workflow.run + async def run(self, target_id: str, data: TraceData) -> TraceData: + handle = workflow.get_external_workflow_handle(target_id) + await 
handle.signal(ExternalWorkflowTarget.external_signal, data) + return data + + +@workflow.defn +class ExternalWorkflowCanceller: + @workflow.run + async def run(self, target_id: str) -> TraceData: + handle = workflow.get_external_workflow_handle(target_id) + await handle.cancel() + return TraceData() + + +@pytest.mark.timeout(10) +async def test_external_workflow_signal_and_cancel_payload_conversion( + client: Client, +): + target_workflow_id = str(uuid.uuid4()) + signaler_workflow_id = str(uuid.uuid4()) + task_queue = str(uuid.uuid4()) + + config = client.config() + config["data_converter"] = dataclasses.replace( + DataConverter.default, + payload_converter_class=SerializationContextCompositePayloadConverter, + ) + client = Client(**config) + + async with Worker( + client, + task_queue=task_queue, + workflows=[ + ExternalWorkflowTarget, + ExternalWorkflowSignaler, + ExternalWorkflowCanceller, + ], + activities=[], + workflow_runner=UnsandboxedWorkflowRunner(), # so that we can use isinstance + ): + target_handle = await client.start_workflow( + ExternalWorkflowTarget.run, + id=target_workflow_id, + task_queue=task_queue, + ) + + signaler_handle = await client.start_workflow( + ExternalWorkflowSignaler.run, + args=[target_workflow_id, TraceData()], + id=signaler_workflow_id, + task_queue=task_queue, + ) + + signaler_result = await signaler_handle.result() + await target_handle.result() + + signaler_context = dataclasses.asdict( + WorkflowSerializationContext( + namespace="default", + workflow_id=signaler_workflow_id, + ) + ) + target_context = dataclasses.asdict( + WorkflowSerializationContext( + namespace="default", + workflow_id=target_workflow_id, + ) + ) + + assert ( + signaler_result.items + == [ + TraceItem( + method="to_payload", + context=signaler_context, # Outbound signaler workflow input + ), + TraceItem( + method="from_payload", + context=signaler_context, # Inbound signaler workflow input + ), + TraceItem( + method="to_payload", + context=target_context, # Should use target workflow's context for external signal + ), + TraceItem( + method="to_payload", + context=signaler_context, # Outbound signaler workflow result + ), + TraceItem( + method="from_payload", + context=signaler_context, # Inbound signaler workflow result + ), + ] + ) + + +# Failure conversion + + +@activity.defn +async def failing_activity() -> Never: + raise ApplicationError("test error", dataclasses.asdict(TraceData())) + + +@workflow.defn +class FailureConverterTestWorkflow: + @workflow.run + async def run(self) -> Never: + await workflow.execute_activity( + failing_activity, + start_to_close_timeout=timedelta(seconds=10), + retry_policy=RetryPolicy(maximum_attempts=1), + ) + raise Exception("Unreachable") + + +test_traces: dict[str, list[TraceItem]] = defaultdict(list) + + +class FailureConverterWithContext(DefaultFailureConverter, WithSerializationContext): + def __init__(self): + super().__init__(encode_common_attributes=False) + self.context: Optional[SerializationContext] = None + + def with_context( + self, context: Optional[SerializationContext] + ) -> FailureConverterWithContext: + converter = FailureConverterWithContext() + converter.context = context + return converter + + def to_failure( + self, + exception: BaseException, + payload_converter: PayloadConverter, + failure: temporalio.api.failure.v1.Failure, + ) -> None: + assert isinstance( + self.context, (WorkflowSerializationContext, ActivitySerializationContext) + ) + test_traces[self.context.workflow_id].append( + TraceItem( + 
method="to_failure", + context=dataclasses.asdict(self.context), + ) + ) + super().to_failure(exception, payload_converter, failure) + + def from_failure( + self, + failure: temporalio.api.failure.v1.Failure, + payload_converter: PayloadConverter, + ) -> BaseException: + assert isinstance( + self.context, (WorkflowSerializationContext, ActivitySerializationContext) + ) + test_traces[self.context.workflow_id].append( + TraceItem( + method="from_failure", + context=dataclasses.asdict(self.context), + ) + ) + return super().from_failure(failure, payload_converter) + + +async def test_failure_converter_with_context(client: Client): + workflow_id = str(uuid.uuid4()) + task_queue = str(uuid.uuid4()) + + data_converter = dataclasses.replace( + DataConverter.default, + failure_converter_class=FailureConverterWithContext, + ) + config = client.config() + config["data_converter"] = data_converter + client = Client(**config) + + async with Worker( + client, + task_queue=task_queue, + workflows=[FailureConverterTestWorkflow], + activities=[failing_activity], + workflow_runner=UnsandboxedWorkflowRunner(), + ): + try: + await client.execute_workflow( + FailureConverterTestWorkflow.run, + id=workflow_id, + task_queue=task_queue, + ) + raise AssertionError("unreachable") + except WorkflowFailureError: + pass + + assert isinstance(data_converter.failure_converter, FailureConverterWithContext) + + workflow_context = dataclasses.asdict( + WorkflowSerializationContext( + namespace="default", + workflow_id=workflow_id, + ) + ) + activity_context = dataclasses.asdict( + ActivitySerializationContext( + namespace="default", + workflow_id=workflow_id, + workflow_type=FailureConverterTestWorkflow.__name__, + activity_type=failing_activity.__name__, + activity_task_queue=task_queue, + is_local=False, + ) + ) + assert test_traces[workflow_id] == ( + [ + TraceItem( + context=activity_context, + method="to_failure", # outbound activity result + ) + ] + + ( + [ + TraceItem( + context=activity_context, + method="from_failure", # inbound activity result + ) + ] + * 2 # from_failure deserializes the error and error cause + ) + + [ + TraceItem( + context=workflow_context, + method="to_failure", # outbound workflow result + ) + ] + + ( + [ + TraceItem( + context=workflow_context, + method="from_failure", # inbound workflow result + ) + ] + * 2 # from_failure deserializes the error and error cause + ) + ) + del test_traces[workflow_id] + + +# Test payload codec + + +class PayloadCodecWithContext(PayloadCodec, WithSerializationContext): + def __init__(self): + self.context: Optional[SerializationContext] = None + self.encode_called_with_context = False + self.decode_called_with_context = False + + def with_context( + self, context: Optional[SerializationContext] + ) -> PayloadCodecWithContext: + codec = PayloadCodecWithContext() + codec.context = context + return codec + + async def encode( + self, payloads: Sequence[temporalio.api.common.v1.Payload] + ) -> List[temporalio.api.common.v1.Payload]: + assert self.context + if isinstance(self.context, ActivitySerializationContext): + test_traces[self.context.workflow_id].append( + TraceItem( + context=dataclasses.asdict(self.context), + method="encode", + ) + ) + else: + assert isinstance(self.context, WorkflowSerializationContext) + test_traces[self.context.workflow_id].append( + TraceItem( + context=dataclasses.asdict(self.context), + method="encode", + ) + ) + return list(payloads) + + async def decode( + self, payloads: Sequence[temporalio.api.common.v1.Payload] + ) -> 
List[temporalio.api.common.v1.Payload]: + assert self.context + if isinstance(self.context, ActivitySerializationContext): + test_traces[self.context.workflow_id].append( + TraceItem( + context=dataclasses.asdict(self.context), + method="decode", + ) + ) + else: + assert isinstance(self.context, WorkflowSerializationContext) + test_traces[self.context.workflow_id].append( + TraceItem( + context=dataclasses.asdict(self.context), + method="decode", + ) + ) + return list(payloads) + + +@workflow.defn +class CodecTestWorkflow: + @workflow.run + async def run(self, data: str) -> str: + return data + + +async def test_codec_with_context(client: Client): + workflow_id = str(uuid.uuid4()) + task_queue = str(uuid.uuid4()) + + client_config = client.config() + client_config["data_converter"] = dataclasses.replace( + DataConverter.default, payload_codec=PayloadCodecWithContext() + ) + client = Client(**client_config) + async with Worker( + client, + task_queue=task_queue, + workflows=[CodecTestWorkflow], + ): + await client.execute_workflow( + CodecTestWorkflow.run, + "data", + id=workflow_id, + task_queue=task_queue, + ) + workflow_context = dataclasses.asdict( + WorkflowSerializationContext( + namespace=client.namespace, + workflow_id=workflow_id, + ) + ) + assert test_traces[workflow_id] == [ + TraceItem( + context=workflow_context, + method="encode", + ), + TraceItem( + context=workflow_context, + method="decode", + ), + TraceItem( + context=workflow_context, + method="encode", + ), + TraceItem( + context=workflow_context, + method="decode", + ), + ] + del test_traces[workflow_id] + + +# Local activity codec test + + +@activity.defn +async def codec_test_local_activity(data: str) -> str: + return data + + +@workflow.defn +class LocalActivityCodecTestWorkflow: + @workflow.run + async def run(self, data: str) -> str: + return await workflow.execute_local_activity( + codec_test_local_activity, + data, + start_to_close_timeout=timedelta(seconds=10), + ) + + +async def test_local_activity_codec_with_context(client: Client): + """Test that codec gets correct context with is_local=True for local activities.""" + workflow_id = str(uuid.uuid4()) + task_queue = str(uuid.uuid4()) + + client_config = client.config() + client_config["data_converter"] = dataclasses.replace( + DataConverter.default, payload_codec=PayloadCodecWithContext() + ) + client = Client(**client_config) + async with Worker( + client, + task_queue=task_queue, + workflows=[LocalActivityCodecTestWorkflow], + activities=[codec_test_local_activity], + ): + await client.execute_workflow( + LocalActivityCodecTestWorkflow.run, + "data", + id=workflow_id, + task_queue=task_queue, + ) + + workflow_context = dataclasses.asdict( + WorkflowSerializationContext( + namespace=client.namespace, + workflow_id=workflow_id, + ) + ) + local_activity_context = dataclasses.asdict( + ActivitySerializationContext( + namespace=client.namespace, + workflow_id=workflow_id, + workflow_type=LocalActivityCodecTestWorkflow.__name__, + activity_type=codec_test_local_activity.__name__, + activity_task_queue=task_queue, + is_local=True, + ) + ) + + assert test_traces[workflow_id] == [ + TraceItem( + context=workflow_context, + method="encode", # outbound workflow input + ), + TraceItem( + context=workflow_context, + method="decode", # inbound workflow input + ), + TraceItem( + context=local_activity_context, + method="encode", # outbound local activity input + ), + TraceItem( + context=local_activity_context, + method="decode", # inbound local activity input + ), + 
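+ # The local-activity entries use the is_local=True context built above,
+ # distinguishing them from the surrounding workflow-context entries.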
TraceItem( + context=local_activity_context, + method="encode", # outbound local activity result + ), + TraceItem( + context=local_activity_context, + method="decode", # inbound local activity result + ), + TraceItem( + context=workflow_context, + method="encode", # outbound workflow result + ), + TraceItem( + context=workflow_context, + method="decode", # inbound workflow result + ), + ] + del test_traces[workflow_id] + + +# Child workflow codec test + + +@workflow.defn +class ChildWorkflowCodecTestWorkflow: + @workflow.run + async def run(self, data: TraceData) -> TraceData: + return await workflow.execute_child_workflow( + EchoWorkflow.run, + data, + id=f"{workflow.info().workflow_id}-child", + ) + + +async def test_child_workflow_codec_with_context(client: Client): + workflow_id = str(uuid.uuid4()) + task_queue = str(uuid.uuid4()) + child_workflow_id = f"{workflow_id}-child" + + config = client.config() + config["data_converter"] = dataclasses.replace( + DataConverter.default, + payload_codec=PayloadCodecWithContext(), + ) + client = Client(**config) + + async with Worker( + client, + task_queue=task_queue, + workflows=[ChildWorkflowCodecTestWorkflow, EchoWorkflow], + workflow_runner=UnsandboxedWorkflowRunner(), + ): + await client.execute_workflow( + ChildWorkflowCodecTestWorkflow.run, + TraceData(), + id=workflow_id, + task_queue=task_queue, + ) + + parent_workflow_context = dataclasses.asdict( + WorkflowSerializationContext( + namespace=client.namespace, + workflow_id=workflow_id, + ) + ) + child_workflow_context = dataclasses.asdict( + WorkflowSerializationContext( + namespace=client.namespace, + workflow_id=child_workflow_id, + ) + ) + + assert test_traces[workflow_id] == [ + TraceItem( + context=parent_workflow_context, + method="encode", # outbound workflow input + ), + TraceItem( + context=parent_workflow_context, + method="decode", # inbound workflow input + ), + TraceItem( + context=parent_workflow_context, + method="encode", # outbound workflow result + ), + TraceItem( + context=parent_workflow_context, + method="decode", # inbound workflow result + ), + ] + assert test_traces[child_workflow_id] == [ + TraceItem( + context=child_workflow_context, + method="encode", # outbound child workflow input + ), + TraceItem( + context=child_workflow_context, + method="decode", # inbound child workflow input + ), + TraceItem( + context=child_workflow_context, + method="encode", # outbound child workflow result + ), + TraceItem( + context=child_workflow_context, + method="decode", # inbound child workflow result + ), + ] + del test_traces[workflow_id] + del test_traces[child_workflow_id] + + +# Payload codec: test decode context matches encode context + + +class PayloadEncryptionCodec(PayloadCodec, WithSerializationContext): + """ + The outbound data for encoding must always be the string "outbound". "Encrypt" it by replacing + it with a key that is derived from the context available during encoding. On decryption, assert + that the same key can be derived from the context available during decoding, and return the + string "inbound". 
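+ The "key" is simply the JSON encoding of the sorted context fields (see
+ _get_encryption_key below), so any mismatch between the encode-time and
+ decode-time context fails the equality assertion in decode().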
+ """ + + def __init__(self): + self.context: Optional[SerializationContext] = None + + def with_context( + self, context: Optional[SerializationContext] + ) -> PayloadEncryptionCodec: + codec = PayloadEncryptionCodec() + codec.context = context + return codec + + async def encode( + self, payloads: Sequence[temporalio.api.common.v1.Payload] + ) -> List[temporalio.api.common.v1.Payload]: + [payload] = payloads + return [ + temporalio.api.common.v1.Payload( + metadata=payload.metadata, + data=json.dumps(self._get_encryption_key()).encode(), + ) + ] + + async def decode( + self, payloads: Sequence[temporalio.api.common.v1.Payload] + ) -> List[temporalio.api.common.v1.Payload]: + [payload] = payloads + assert json.loads(payload.data.decode()) == self._get_encryption_key() + metadata = dict(payload.metadata) + return [temporalio.api.common.v1.Payload(metadata=metadata, data=b'"inbound"')] + + def _get_encryption_key(self) -> str: + context = ( + dataclasses.asdict(self.context) + if isinstance( + self.context, + (WorkflowSerializationContext, ActivitySerializationContext), + ) + else {} + ) + return json.dumps({k: v for k, v in sorted(context.items())}) + + +@activity.defn +async def payload_encryption_activity(data: str) -> str: + assert data == "inbound" + return "outbound" + + +@workflow.defn +class PayloadEncryptionChildWorkflow: + @workflow.run + async def run(self, data: str) -> str: + assert data == "inbound" + return "outbound" + + +@nexusrpc.service +class PayloadEncryptionService: + payload_encryption_operation: nexusrpc.Operation[str, str] + + +@nexusrpc.handler.service_handler +class PayloadEncryptionServiceHandler: + @nexusrpc.handler.sync_operation + async def payload_encryption_operation( + self, _: nexusrpc.handler.StartOperationContext, data: str + ) -> str: + assert data == "inbound" + return "outbound" + + +@workflow.defn +class PayloadEncryptionWorkflow: + def __init__(self): + self.received_signal = False + self.received_update = False + + @workflow.run + async def run(self, data: str) -> str: + await workflow.wait_condition( + lambda: (self.received_signal and self.received_update) + ) + # Run them in parallel to check that data converter operations do not mix up contexts when + # there are multiple concurrent payload types. + coros = [ + workflow.execute_activity( + payload_encryption_activity, + "outbound", + start_to_close_timeout=timedelta(seconds=10), + ), + workflow.execute_child_workflow( + PayloadEncryptionChildWorkflow.run, + "outbound", + id=f"{workflow.info().workflow_id}_child", + ), + ] + [act_result, cw_result], _ = await workflow.wait( + [asyncio.create_task(c) for c in coros] + ) + assert await act_result == "inbound" + assert await cw_result == "inbound" + return "outbound" + + @workflow.query + def query(self, data: str) -> str: + assert data == "inbound" + return "outbound" + + @workflow.signal + def signal(self, data: str) -> None: + assert data == "inbound" + self.received_signal = True + + @workflow.update + def update(self, data: str) -> str: + assert data == "inbound" + self.received_update = True + return "outbound" + + @update.validator + def update_validator(self, data: str) -> None: + assert data == "inbound" + + +async def test_decode_context_matches_encode_context( + client: Client, +): + """ + Encode outbound payloads with a key using all available context fields, in order to demonstrate + that the same context is available to decode inbound payloads. 
+ """ + workflow_id = str(uuid.uuid4()) + task_queue = str(uuid.uuid4()) + + config = client.config() + config["data_converter"] = dataclasses.replace( + DataConverter.default, + payload_codec=PayloadEncryptionCodec(), + ) + client = Client(**config) + + async with Worker( + client, + task_queue=task_queue, + workflows=[PayloadEncryptionWorkflow, PayloadEncryptionChildWorkflow], + activities=[payload_encryption_activity], + nexus_service_handlers=[PayloadEncryptionServiceHandler()], + ): + wf_handle = await client.start_workflow( + PayloadEncryptionWorkflow.run, + "outbound", + id=workflow_id, + task_queue=task_queue, + ) + assert "inbound" == await wf_handle.query( + PayloadEncryptionWorkflow.query, "outbound" + ) + await wf_handle.signal(PayloadEncryptionWorkflow.signal, "outbound") + assert "inbound" == await wf_handle.execute_update( + PayloadEncryptionWorkflow.update, "outbound" + ) + assert "inbound" == await wf_handle.result() + + +# Test nexus payload codec + + +class AssertNexusLacksContextPayloadCodec(PayloadCodec, WithSerializationContext): + def __init__(self): + self.context = None + + def with_context( + self, context: SerializationContext + ) -> AssertNexusLacksContextPayloadCodec: + codec = AssertNexusLacksContextPayloadCodec() + codec.context = context + return codec + + async def _assert_context_iff_not_nexus( + self, payloads: Sequence[temporalio.api.common.v1.Payload] + ) -> List[temporalio.api.common.v1.Payload]: + [payload] = payloads + assert bool(self.context) == (payload.data.decode() != '"nexus-data"') + return list(payloads) + + encode = decode = _assert_context_iff_not_nexus + + +@nexusrpc.handler.service_handler +class NexusOperationTestServiceHandler: + @nexusrpc.handler.sync_operation + async def operation( + self, _: nexusrpc.handler.StartOperationContext, data: str + ) -> str: + return data + + +@workflow.defn +class NexusOperationTestWorkflow: + @workflow.run + async def run(self, data: str) -> None: + nexus_client = workflow.create_nexus_client( + service=NexusOperationTestServiceHandler, + endpoint=make_nexus_endpoint_name(workflow.info().task_queue), + ) + await nexus_client.start_operation( + NexusOperationTestServiceHandler.operation, input="nexus-data" + ) + + +async def test_nexus_payload_codec_operations_lack_context( + env: WorkflowEnvironment, +): + """ + encode() and decode() on nexus payloads should not have any context set. 
+ """ + if env.supports_time_skipping: + pytest.skip("Nexus tests don't work with the Java test server") + + config = env.client.config() + config["data_converter"] = dataclasses.replace( + DataConverter.default, + payload_codec=AssertNexusLacksContextPayloadCodec(), + ) + client = Client(**config) + + async with Worker( + client, + task_queue=str(uuid.uuid4()), + workflows=[NexusOperationTestWorkflow], + nexus_service_handlers=[NexusOperationTestServiceHandler()], + ) as worker: + await create_nexus_endpoint(worker.task_queue, client) + await client.execute_workflow( + NexusOperationTestWorkflow.run, + "workflow-data", + id=str(uuid.uuid4()), + task_queue=worker.task_queue, + ) + + +# Test pydantic converter with context + + +class PydanticData(BaseModel): + value: str + trace: List[str] = [] + + +class PydanticJSONConverterWithContext( + PydanticJSONPlainPayloadConverter, WithSerializationContext +): + def __init__(self): + super().__init__() + self.context: Optional[SerializationContext] = None + + def with_context( + self, context: Optional[SerializationContext] + ) -> PydanticJSONConverterWithContext: + converter = PydanticJSONConverterWithContext() + converter.context = context + return converter + + def to_payload(self, value: Any) -> Optional[temporalio.api.common.v1.Payload]: + if isinstance(value, PydanticData) and self.context: + if isinstance(self.context, WorkflowSerializationContext): + value.trace.append(f"wf_{self.context.workflow_id}") + return super().to_payload(value) + + +class PydanticConverterWithContext(CompositePayloadConverter, WithSerializationContext): + def __init__(self): + super().__init__( + *( + c + if not isinstance(c, JSONPlainPayloadConverter) + else PydanticJSONConverterWithContext() + for c in DefaultPayloadConverter.default_encoding_payload_converters + ) + ) + self.context: Optional[SerializationContext] = None + + +@workflow.defn +class PydanticContextWorkflow: + @workflow.run + async def run(self, data: PydanticData) -> PydanticData: + return data + + +async def test_pydantic_converter_with_context(client: Client): + wf_id = str(uuid.uuid4()) + task_queue = str(uuid.uuid4()) + + client_config = client.config() + client_config["data_converter"] = dataclasses.replace( + DataConverter.default, + payload_converter_class=PydanticConverterWithContext, + ) + client = Client(**client_config) + + async with Worker( + client, + task_queue=task_queue, + workflows=[PydanticContextWorkflow], + ): + result = await client.execute_workflow( + PydanticContextWorkflow.run, + PydanticData(value="test"), + id=wf_id, + task_queue=task_queue, + ) + assert f"wf_{wf_id}" in result.trace + + +# Test customized DefaultPayloadConverter + +# The SDK's CompositePayloadConverter comes with a with_context implementation that ensures that its +# component EncodingPayloadConverters will be replaced with the results of calling with_context() on +# them, if they support with_context (this happens when we call data_converter._with_context). In +# this test, the user has subclassed CompositePayloadConverter. The test confirms that the +# CompositePayloadConverter's with_context yields an instance of the user's subclass. 
+ + +class UserMethodCalledError(Exception): + pass + + +class CustomEncodingPayloadConverter( + JSONPlainPayloadConverter, WithSerializationContext +): + @property + def encoding(self) -> str: + return "custom-encoding-that-does-not-clash-with-default-converters" + + def __init__(self): + super().__init__() + self.context: Optional[SerializationContext] = None + + def with_context( + self, context: Optional[SerializationContext] + ) -> CustomEncodingPayloadConverter: + converter = CustomEncodingPayloadConverter() + converter.context = context + return converter + + +class CustomPayloadConverter(CompositePayloadConverter): + def __init__(self): + # Add a context-aware EncodingPayloadConverter so that + # CompositePayloadConverter.with_context is forced to construct and return a new instance. + super().__init__( + CustomEncodingPayloadConverter(), + *DefaultPayloadConverter.default_encoding_payload_converters, + ) + + def to_payloads( + self, values: Sequence[Any] + ) -> List[temporalio.api.common.v1.Payload]: + raise UserMethodCalledError + + def from_payloads( + self, + payloads: Sequence[temporalio.api.common.v1.Payload], + type_hints: Optional[List[Type]] = None, + ) -> List[Any]: + raise NotImplementedError + + +async def test_user_customization_of_default_payload_converter( + client: Client, +): + wf_id = str(uuid.uuid4()) + task_queue = str(uuid.uuid4()) + + client_config = client.config() + client_config["data_converter"] = dataclasses.replace( + DataConverter.default, + payload_converter_class=CustomPayloadConverter, + ) + client = Client(**client_config) + + async with Worker( + client, + task_queue=task_queue, + workflows=[EchoWorkflow], + ): + with pytest.raises(UserMethodCalledError): + await client.execute_workflow( + EchoWorkflow.run, + TraceData(), + id=wf_id, + task_queue=task_queue, + ) diff --git a/tests/worker/test_command_aware_visitor.py b/tests/worker/test_command_aware_visitor.py new file mode 100644 index 000000000..92c67b218 --- /dev/null +++ b/tests/worker/test_command_aware_visitor.py @@ -0,0 +1,89 @@ +"""Test that CommandAwarePayloadVisitor handles all commands with seq fields that have payloads.""" + +from typing import Any, Iterator, Type + +from temporalio.bridge._visitor import PayloadVisitor +from temporalio.bridge.proto.workflow_activation import workflow_activation_pb2 +from temporalio.bridge.proto.workflow_commands import workflow_commands_pb2 +from temporalio.worker._command_aware_visitor import CommandAwarePayloadVisitor + + +def test_command_aware_visitor_has_methods_for_all_seq_protos_with_payloads(): + """Verify CommandAwarePayloadVisitor has methods for all protos with seq fields that have payloads. + + We only override methods when the base class has a visitor method (i.e., there are payloads to visit). + Commands without payloads don't need overrides since there's nothing to visit. 
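+ This acts as a guard rail: when Core adds a new command or job that has both a
+ seq field and payloads, this test fails until a matching override is written.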
+ """ + visitor = CommandAwarePayloadVisitor() + + # Find all protos with seq + command_protos = list(_get_workflow_command_protos_with_seq()) + job_protos = list(_get_workflow_activation_job_protos_with_seq()) + assert command_protos, "Should find workflow commands with seq" + assert job_protos, "Should find workflow activation jobs with seq" + + # Check workflow commands - only ones with payloads need overrides + commands_missing = [] + commands_with_payloads = [] + for proto_class in command_protos: + method_name = f"_visit_coresdk_workflow_commands_{proto_class.__name__}" + # Only check if base class has this visitor (meaning there are payloads) + if hasattr(PayloadVisitor, method_name): + commands_with_payloads.append(proto_class.__name__) + # Check if CommandAwarePayloadVisitor has its own override (not just inherited) + if method_name not in CommandAwarePayloadVisitor.__dict__: + commands_missing.append(proto_class.__name__) + + # Check workflow activation jobs - only ones with payloads need overrides + jobs_missing = [] + jobs_with_payloads = [] + for proto_class in job_protos: + method_name = f"_visit_coresdk_workflow_activation_{proto_class.__name__}" + # Only check if base class has this visitor (meaning there are payloads) + if hasattr(PayloadVisitor, method_name): + jobs_with_payloads.append(proto_class.__name__) + # Check if CommandAwarePayloadVisitor has its own override (not just inherited) + if method_name not in CommandAwarePayloadVisitor.__dict__: + jobs_missing.append(proto_class.__name__) + + errors = [] + if commands_missing: + errors.append( + f"Missing visitor methods for commands with seq and payloads: {commands_missing}\n" + f"Add methods to CommandAwarePayloadVisitor for these commands." + ) + if jobs_missing: + errors.append( + f"Missing visitor methods for activation jobs with seq and payloads: {jobs_missing}\n" + f"Add methods to CommandAwarePayloadVisitor for these jobs." 
+ ) + + assert not errors, "\n".join(errors) + + # Verify we found the expected commands/jobs with payloads + assert len(commands_with_payloads) > 0, "Should find commands with payloads" + assert len(jobs_with_payloads) > 0, "Should find activation jobs with payloads" + + # Sanity check: we should have fewer overrides than total protos with seq + # (because some don't have payloads) + assert len(commands_with_payloads) < len( + command_protos + ), "Should have some commands without payloads" + # All activation jobs except FireTimer have payloads + assert ( + len(jobs_with_payloads) == len(job_protos) - 1 + ), "Should have exactly one activation job without payloads (FireTimer)" + + +def _get_workflow_command_protos_with_seq() -> Iterator[Type[Any]]: + """Get concrete classes of all workflow command protos with a seq field.""" + for descriptor in workflow_commands_pb2.DESCRIPTOR.message_types_by_name.values(): + if "seq" in descriptor.fields_by_name: + yield descriptor._concrete_class + + +def _get_workflow_activation_job_protos_with_seq() -> Iterator[Type[Any]]: + """Get concrete classes of all workflow activation job protos with a seq field.""" + for descriptor in workflow_activation_pb2.DESCRIPTOR.message_types_by_name.values(): + if "seq" in descriptor.fields_by_name: + yield descriptor._concrete_class diff --git a/tests/worker/test_workflow.py b/tests/worker/test_workflow.py index 9661ad7cc..3ecd6c63b 100644 --- a/tests/worker/test_workflow.py +++ b/tests/worker/test_workflow.py @@ -46,6 +46,7 @@ import temporalio.client import temporalio.converter import temporalio.worker +import temporalio.worker._command_aware_visitor import temporalio.workflow from temporalio import activity, workflow from temporalio.api.common.v1 import Payload, Payloads, WorkflowExecution @@ -1610,6 +1611,12 @@ def activate(self, act: WorkflowActivation) -> WorkflowActivationCompletion: self._runner._pairs.append((act, comp)) return comp + def get_serialization_context( + self, + command_info: Optional[temporalio.worker._command_aware_visitor.CommandInfo], + ) -> Optional[temporalio.converter.SerializationContext]: + return self._unsandboxed.get_serialization_context(command_info) + async def test_workflow_with_custom_runner(client: Client): runner = CustomWorkflowRunner()