diff --git a/temporalio/converter.py b/temporalio/converter.py
index 29eb35566..769aa2eac 100644
--- a/temporalio/converter.py
+++ b/temporalio/converter.py
@@ -93,51 +93,50 @@ class SerializationContext(ABC):
 
 
 @dataclass(frozen=True)
-class BaseWorkflowSerializationContext(SerializationContext):
-    """Base serialization context shared by workflow and activity serialization contexts."""
-
-    namespace: str
-    workflow_id: str
-
-
-@dataclass(frozen=True)
-class WorkflowSerializationContext(BaseWorkflowSerializationContext):
+class WorkflowSerializationContext(SerializationContext):
     """Serialization context for workflows.
 
     See :py:class:`SerializationContext` for more details.
-
-    Attributes:
-        namespace: The namespace the workflow is running in.
-        workflow_id: The ID of the workflow. Note that this is the ID of the workflow of which the
-            payload being operated on is an input or output. Note also that when creating/describing
-            schedules, this may be the workflow ID prefix as configured, not the final workflow ID
-            when the workflow is created by the schedule.
     """
 
-    pass
+    namespace: str
+    """Namespace used by the worker executing the workflow."""
+
+    workflow_id: Optional[str]
+    """Workflow ID.
+
+    Note that this is the ID of the workflow of which the payload being operated on is an input or
+    output. When creating/describing schedules, this may be the workflow ID prefix as configured,
+    not the final workflow ID when the workflow is created by the schedule."""
 
 
 @dataclass(frozen=True)
-class ActivitySerializationContext(BaseWorkflowSerializationContext):
+class ActivitySerializationContext(SerializationContext):
     """Serialization context for activities.
 
     See :py:class:`SerializationContext` for more details.
-
-    Attributes:
-        namespace: Workflow/activity namespace.
-        workflow_id: Workflow ID. Note, when creating/describing schedules,
-            this may be the workflow ID prefix as configured, not the final workflow ID when the
-            workflow is created by the schedule.
-        workflow_type: Workflow Type.
-        activity_type: Activity Type.
-        activity_task_queue: Activity task queue.
-        is_local: Whether the activity is a local activity.
""" - workflow_type: str + namespace: str + """Namespace used by the worker executing the activity.""" + + activity_id: Optional[str] + """Activity ID.""" + activity_type: str + """Activity type.""" + activity_task_queue: str + """Activity task queue.""" + + workflow_id: Optional[str] + """Workflow ID.""" + + workflow_type: Optional[str] + """Workflow type.""" + is_local: bool + """Whether the activity is a local activity.""" class WithSerializationContext(ABC): diff --git a/temporalio/worker/_activity.py b/temporalio/worker/_activity.py index 44bfb6910..313b92193 100644 --- a/temporalio/worker/_activity.py +++ b/temporalio/worker/_activity.py @@ -256,11 +256,12 @@ async def _heartbeat_async( if activity.info: context = temporalio.converter.ActivitySerializationContext( namespace=activity.info.workflow_namespace, - workflow_id=activity.info.workflow_id, - workflow_type=activity.info.workflow_type, + activity_id=activity.info.activity_id, activity_type=activity.info.activity_type, activity_task_queue=self._task_queue, is_local=activity.info.is_local, + workflow_id=activity.info.workflow_id, + workflow_type=activity.info.workflow_type, ) data_converter = data_converter.with_context(context) @@ -308,11 +309,12 @@ async def _handle_start_activity_task( # Create serialization context for the activity context = temporalio.converter.ActivitySerializationContext( namespace=start.workflow_namespace, - workflow_id=start.workflow_execution.workflow_id, - workflow_type=start.workflow_type, + activity_id=start.activity_id, activity_type=start.activity_type, activity_task_queue=self._task_queue, is_local=start.is_local, + workflow_id=start.workflow_execution.workflow_id, + workflow_type=start.workflow_type, ) data_converter = self._data_converter.with_context(context) try: diff --git a/temporalio/worker/_workflow_instance.py b/temporalio/worker/_workflow_instance.py index 44eb443ff..b18e1b208 100644 --- a/temporalio/worker/_workflow_instance.py +++ b/temporalio/worker/_workflow_instance.py @@ -788,8 +788,7 @@ def _apply_resolve_activity( raise RuntimeError(f"Failed finding activity handle for sequence {job.seq}") activity_context = temporalio.converter.ActivitySerializationContext( namespace=self._info.namespace, - workflow_id=self._info.workflow_id, - workflow_type=self._info.workflow_type, + activity_id=handle._input.activity_id, activity_type=handle._input.activity, activity_task_queue=( handle._input.task_queue or self._info.task_queue @@ -797,6 +796,8 @@ def _apply_resolve_activity( else self._info.task_queue ), is_local=isinstance(handle._input, StartLocalActivityInput), + workflow_id=self._info.workflow_id, + workflow_type=self._info.workflow_type, ) payload_converter = self._payload_converter_with_context(activity_context) failure_converter = self._failure_converter_with_context(activity_context) @@ -2127,8 +2128,7 @@ def get_serialization_context( activity_handle = self._pending_activities[command_info.command_seq] return temporalio.converter.ActivitySerializationContext( namespace=self._info.namespace, - workflow_id=self._info.workflow_id, - workflow_type=self._info.workflow_type, + activity_id=activity_handle._input.activity_id, activity_type=activity_handle._input.activity, activity_task_queue=( activity_handle._input.task_queue @@ -2137,6 +2137,8 @@ def get_serialization_context( else self._info.task_queue ), is_local=isinstance(activity_handle._input, StartLocalActivityInput), + workflow_id=self._info.workflow_id, + workflow_type=self._info.workflow_type, ) elif ( @@ -2921,8 +2923,7 @@ def 
__init__( self._payload_converter = self._instance._payload_converter_with_context( temporalio.converter.ActivitySerializationContext( namespace=self._instance._info.namespace, - workflow_id=self._instance._info.workflow_id, - workflow_type=self._instance._info.workflow_type, + activity_id=self._input.activity_id, activity_type=self._input.activity, activity_task_queue=( self._input.task_queue or self._instance._info.task_queue @@ -2930,6 +2931,8 @@ def __init__( else self._instance._info.task_queue ), is_local=isinstance(self._input, StartLocalActivityInput), + workflow_id=self._instance._info.workflow_id, + workflow_type=self._instance._info.workflow_type, ) ) diff --git a/tests/test_serialization_context.py b/tests/test_serialization_context.py index ee7be8684..02a697c71 100644 --- a/tests/test_serialization_context.py +++ b/tests/test_serialization_context.py @@ -179,6 +179,7 @@ async def run(self, data: TraceData) -> TraceData: data, start_to_close_timeout=timedelta(seconds=10), heartbeat_timeout=timedelta(seconds=2), + activity_id="activity-id", ) data = await workflow.execute_child_workflow( EchoWorkflow.run, data, id=f"{workflow.info().workflow_id}_child" @@ -231,6 +232,7 @@ async def test_payload_conversion_calls_follow_expected_sequence_and_contexts( workflow_id=workflow_id, workflow_type=PayloadConversionWorkflow.__name__, activity_type=passthrough_activity.__name__, + activity_id="activity-id", activity_task_queue=task_queue, is_local=False, ) @@ -328,6 +330,7 @@ async def run(self) -> TraceData: initial_interval=timedelta(milliseconds=100), maximum_attempts=2, ), + activity_id="activity-id", ) @@ -370,6 +373,7 @@ async def test_heartbeat_details_payload_conversion(client: Client): workflow_id=workflow_id, workflow_type=HeartbeatDetailsSerializationContextTestWorkflow.__name__, activity_type=activity_with_heartbeat_details.__name__, + activity_id="activity-id", activity_task_queue=task_queue, is_local=False, ) @@ -419,6 +423,7 @@ async def run(self, data: TraceData) -> TraceData: local_activity, data, start_to_close_timeout=timedelta(seconds=10), + activity_id="activity-id", ) @@ -459,6 +464,7 @@ async def test_local_activity_payload_conversion(client: Client): workflow_id=workflow_id, workflow_type=LocalActivityWorkflow.__name__, activity_type=local_activity.__name__, + activity_id="activity-id", activity_task_queue=task_queue, is_local=True, ) @@ -504,7 +510,7 @@ async def test_local_activity_payload_conversion(client: Client): @workflow.defn -class EventWorkflow: +class WaitForSignalWorkflow: # Like a global asyncio.Event() def __init__(self) -> None: @@ -521,10 +527,11 @@ def signal(self) -> None: @activity.defn async def async_activity() -> TraceData: + # Notify test that the activity has started and is ready to be completed manually await ( activity.client() .get_workflow_handle("activity-started-wf-id") - .signal(EventWorkflow.signal) + .signal(WaitForSignalWorkflow.signal) ) activity.raise_complete_async() @@ -558,7 +565,7 @@ async def test_async_activity_completion_payload_conversion( task_queue=task_queue, workflows=[ AsyncActivityCompletionSerializationContextTestWorkflow, - EventWorkflow, + WaitForSignalWorkflow, ], activities=[async_activity], workflow_runner=UnsandboxedWorkflowRunner(), # so that we can use isinstance @@ -572,12 +579,13 @@ async def test_async_activity_completion_payload_conversion( workflow_id=workflow_id, workflow_type=AsyncActivityCompletionSerializationContextTestWorkflow.__name__, activity_type=async_activity.__name__, + 
activity_id="async-activity-id", activity_task_queue=task_queue, is_local=False, ) act_started_wf_handle = await client.start_workflow( - EventWorkflow.run, + WaitForSignalWorkflow.run, id="activity-started-wf-id", task_queue=task_queue, ) @@ -644,6 +652,7 @@ def test_subclassed_async_activity_handle(client: Client): workflow_id="workflow-id", workflow_type="workflow-type", activity_type="activity-type", + activity_id="activity-id", activity_task_queue="activity-task-queue", is_local=False, ) @@ -1058,11 +1067,12 @@ async def run(self) -> Never: failing_activity, start_to_close_timeout=timedelta(seconds=10), retry_policy=RetryPolicy(maximum_attempts=1), + activity_id="activity-id", ) raise Exception("Unreachable") -test_traces: dict[str, list[TraceItem]] = defaultdict(list) +test_traces: dict[Optional[str], list[TraceItem]] = defaultdict(list) class FailureConverterWithContext(DefaultFailureConverter, WithSerializationContext): @@ -1154,6 +1164,7 @@ async def test_failure_converter_with_context(client: Client): workflow_id=workflow_id, workflow_type=FailureConverterTestWorkflow.__name__, activity_type=failing_activity.__name__, + activity_id="activity-id", activity_task_queue=task_queue, is_local=False, ) @@ -1322,6 +1333,7 @@ async def run(self, data: str) -> str: codec_test_local_activity, data, start_to_close_timeout=timedelta(seconds=10), + activity_id="activity-id", ) @@ -1360,6 +1372,7 @@ async def test_local_activity_codec_with_context(client: Client): workflow_id=workflow_id, workflow_type=LocalActivityCodecTestWorkflow.__name__, activity_type=codec_test_local_activity.__name__, + activity_id="activity-id", activity_task_queue=task_queue, is_local=True, ) @@ -1593,6 +1606,7 @@ async def run(self, data: str) -> str: payload_encryption_activity, "outbound", start_to_close_timeout=timedelta(seconds=10), + activity_id="activity-id", ), workflow.execute_child_workflow( PayloadEncryptionChildWorkflow.run,