diff --git a/temporalio/converter.py b/temporalio/converter.py
index 190fda0e6..f00163957 100644
--- a/temporalio/converter.py
+++ b/temporalio/converter.py
@@ -12,6 +12,7 @@
 import uuid
 import warnings
 from abc import ABC, abstractmethod
+from copy import copy
 from dataclasses import dataclass
 from datetime import datetime
 from enum import IntEnum
@@ -65,6 +66,92 @@
 logger = getLogger(__name__)
 
 
+class SerializationContext(ABC):
+    """Base serialization context.
+
+    Provides contextual information to payload converters, failure converters,
+    and payload codecs during serialization and deserialization. Specialized
+    contexts (workflow, activity, etc.) carry the relevant details.
+    """
+
+    pass
+
+
+@dataclass(frozen=True)
+class ActivitySerializationContext(SerializationContext):
+    """Serialization context for activity inputs and outputs.
+
+    Attributes:
+        namespace: Namespace the calling workflow is running in.
+        workflow_id: ID of the workflow that scheduled the activity.
+        workflow_type: Type/name of the workflow that scheduled the activity.
+        activity_type: Type/name of the activity.
+        activity_task_queue: Task queue the activity is scheduled on, or None
+            for local activities.
+        is_local: Whether this is a local activity.
+    """
+
+    namespace: str
+    workflow_id: str
+    workflow_type: str
+    activity_type: str
+    activity_task_queue: Optional[str]
+    is_local: bool
+
+
+@dataclass(frozen=True)
+class WorkflowSerializationContext(SerializationContext):
+    """Serialization context for workflow inputs and outputs.
+
+    Attributes:
+        namespace: Namespace the workflow is running in.
+        workflow_id: The workflow ID.
+    """
+
+    namespace: str
+    workflow_id: str
+
+
+class WithSerializationContext(ABC):
+    """Base class for converters and codecs that accept serialization context.
+
+    This mirrors the .NET SDK's IWithSerializationContext interface. Objects
+    implementing this class receive contextual information during
+    serialization and deserialization via :py:meth:`with_context`.
+    """
+
+    @abstractmethod
+    def with_context(
+        self, context: Optional[SerializationContext]
+    ) -> "WithSerializationContext":
+        """Return a copy of this object configured to use the given context.
+
+        Args:
+            context: The serialization context to use, or None for no context.
+
+        Returns:
+            A new instance of the same type configured with the context.
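+
+        A minimal sketch of a converter opting in (illustrative only; a real
+        implementation should also carry over any other state it holds)::
+
+            class MyPayloadConverter(CompositePayloadConverter, WithSerializationContext):
+                def __init__(
+                    self, context: Optional[SerializationContext] = None
+                ) -> None:
+                    super().__init__(
+                        *DefaultPayloadConverter.default_encoding_payload_converters
+                    )
+                    self.context = context
+
+                def with_context(
+                    self, context: Optional[SerializationContext]
+                ) -> "MyPayloadConverter":
+                    return MyPayloadConverter(context)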
+ """ + raise NotImplementedError() + + class PayloadConverter(ABC): """Base payload converter to/from multiple payloads/values.""" @@ -1206,6 +1278,32 @@ async def decode_failure( await self.payload_codec.decode_failure(failure) return self.failure_converter.from_failure(failure, self.payload_converter) + def _with_context(self, context: Optional[SerializationContext]) -> Self: + new_self = type(self).__new__(type(self)) + setattr( + new_self, + "payload_converter", + self.payload_converter.with_context(context) + if isinstance(self.payload_converter, WithSerializationContext) + else self.payload_converter, + ) + codec = self.payload_codec + setattr( + new_self, + "payload_codec", + cast(WithSerializationContext, codec).with_context(context) + if isinstance(codec, WithSerializationContext) + else codec, + ) + setattr( + new_self, + "failure_converter", + self.failure_converter.with_context(context) + if isinstance(self.failure_converter, WithSerializationContext) + else self.failure_converter, + ) + return new_self + DefaultPayloadConverter.default_encoding_payload_converters = ( BinaryNullPayloadConverter(), diff --git a/temporalio/worker/_workflow_instance.py b/temporalio/worker/_workflow_instance.py index 118966b34..7ebe47351 100644 --- a/temporalio/worker/_workflow_instance.py +++ b/temporalio/worker/_workflow_instance.py @@ -65,6 +65,11 @@ from temporalio.service import __version__ from ..api.failure.v1.message_pb2 import Failure +from ..converter import ( + ActivitySerializationContext, + WithSerializationContext, + WorkflowSerializationContext, +) from ._interceptor import ( ContinueAsNewInput, ExecuteWorkflowInput, @@ -208,6 +213,19 @@ def __init__(self, det: WorkflowInstanceDetails) -> None: WorkflowInstance.__init__(self) temporalio.workflow._Runtime.__init__(self) self._payload_converter = det.payload_converter_class() + + # Apply serialization context to payload converter + self._payload_converter = ( + self._payload_converter.with_context( + WorkflowSerializationContext( + namespace=det.info.namespace, + workflow_id=det.info.workflow_id, + ) + ) + if isinstance(self._payload_converter, WithSerializationContext) + else self._payload_converter + ) + self._failure_converter = det.failure_converter_class() self._defn = det.defn self._workflow_input: Optional[ExecuteWorkflowInput] = None @@ -1017,6 +1035,7 @@ def _apply_update_random_seed( def _make_workflow_input( self, init_job: temporalio.bridge.proto.workflow_activation.InitializeWorkflow ) -> ExecuteWorkflowInput: + print("Making workflow input") # Set arg types, using raw values for dynamic arg_types = self._defn.arg_types if not self._defn.name: @@ -1987,6 +2006,7 @@ def _convert_payloads( if types and len(types) != len(payloads): types = None try: + print(f"Converting payloads with {self._payload_converter}.") return self._payload_converter.from_payloads( payloads, type_hints=types, @@ -2769,9 +2789,27 @@ def _apply_schedule_command( temporalio.bridge.proto.activity_result.DoBackoff ] = None, ) -> None: + # Set up serialization context + payload_converter = ( + self._instance._payload_converter.with_context( + ActivitySerializationContext( + namespace=self._instance.workflow_info().namespace, + workflow_id=self._instance.workflow_info().workflow_id, + workflow_type=self._instance.workflow_info().workflow_type, + activity_type=self._input.activity, + activity_task_queue=self._input.task_queue + if isinstance(self._input, StartActivityInput) + else None, + is_local=isinstance(self._input, StartLocalActivityInput), + ) + ) + 
+        info = self._instance.workflow_info()
+        payload_converter = (
+            self._instance._payload_converter.with_context(
+                ActivitySerializationContext(
+                    namespace=info.namespace,
+                    workflow_id=info.workflow_id,
+                    workflow_type=info.workflow_type,
+                    activity_type=self._input.activity,
+                    activity_task_queue=self._input.task_queue
+                    if isinstance(self._input, StartActivityInput)
+                    else None,
+                    is_local=isinstance(self._input, StartLocalActivityInput),
+                )
+            )
+            if isinstance(self._instance._payload_converter, WithSerializationContext)
+            else self._instance._payload_converter
+        )
+
         # Convert arguments before creating command in case it raises error
         payloads = (
-            self._instance._payload_converter.to_payloads(self._input.args)
+            payload_converter.to_payloads(self._input.args)
             if self._input.args
             else None
         )
@@ -2807,7 +2844,7 @@
             self._input.retry_policy.apply_to_proto(v.retry_policy)
         if self._input.summary:
             command.user_metadata.summary.CopyFrom(
-                self._instance._payload_converter.to_payload(self._input.summary)
+                payload_converter.to_payload(self._input.summary)
             )
         v.cancellation_type = cast(
             temporalio.bridge.proto.workflow_commands.ActivityCancellationType.ValueType,
@@ -2919,9 +2956,22 @@ def _resolve_failure(self, err: BaseException) -> None:
             self._result_fut.set_result(None)
 
     def _apply_start_command(self) -> None:
+        # Bind the child workflow's serialization context; payloads destined
+        # for the child are serialized under the child's own workflow ID
+        payload_converter = (
+            self._instance._payload_converter.with_context(
+                WorkflowSerializationContext(
+                    namespace=self._instance.workflow_info().namespace,
+                    workflow_id=self._input.id,
+                )
+            )
+            if isinstance(self._instance._payload_converter, WithSerializationContext)
+            else self._instance._payload_converter
+        )
+
         # Convert arguments before creating command in case it raises error
         payloads = (
-            self._instance._payload_converter.to_payloads(self._input.args)
+            payload_converter.to_payloads(self._input.args)
             if self._input.args
             else None
         )
@@ -2956,9 +3006,7 @@
             temporalio.common._apply_headers(self._input.headers, v.headers)
         if self._input.memo:
             for k, val in self._input.memo.items():
-                v.memo[k].CopyFrom(
-                    self._instance._payload_converter.to_payloads([val])[0]
-                )
+                v.memo[k].CopyFrom(payload_converter.to_payloads([val])[0])
         if self._input.search_attributes:
             _encode_search_attributes(
                 self._input.search_attributes, v.search_attributes
@@ -3126,15 +3174,28 @@ def __init__(
         self._input = input
 
     def _apply_command(self) -> None:
+        # Continue-as-new keeps the same workflow ID, so reuse this
+        # workflow's own serialization context
+        payload_converter = (
+            self._instance._payload_converter.with_context(
+                WorkflowSerializationContext(
+                    namespace=self._instance.workflow_info().namespace,
+                    workflow_id=self._instance.workflow_info().workflow_id,
+                )
+            )
+            if isinstance(self._instance._payload_converter, WithSerializationContext)
+            else self._instance._payload_converter
+        )
+
         # Convert arguments before creating command in case it raises error
         payloads = (
-            self._instance._payload_converter.to_payloads(self._input.args)
+            payload_converter.to_payloads(self._input.args)
             if self._input.args
             else None
         )
         memo_payloads = (
             {
-                k: self._instance._payload_converter.to_payloads([val])[0]
+                k: payload_converter.to_payloads([val])[0]
                 for k, val in self._input.memo.items()
             }
             if self._input.memo
diff --git a/tests/worker/test_workflow.py b/tests/worker/test_workflow.py
index d13debf12..443c759d8 100644
--- a/tests/worker/test_workflow.py
+++ b/tests/worker/test_workflow.py
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import asyncio
+import base64
 import concurrent.futures
 import dataclasses
 import json
@@ -95,6 +96,12 @@
+    ActivitySerializationContext,
+    CompositePayloadConverter,
     DefaultPayloadConverter,
+    JSONPlainPayloadConverter,
     PayloadCodec,
     PayloadConverter,
+    SerializationContext,
+    WithSerializationContext,
+    WorkflowSerializationContext,
 )
 from temporalio.exceptions import (
     ActivityError,
@@ -8367,3 +8374,358 @@
     )
     result = await handle.result()
     assert result == "Done"
+
+
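+# The helpers below exercise serialization-context propagation end to end:
+# a payload converter, failure converter, and payload codec that each record
+# the context they were constructed with, plus a value type that asserts on
+# what they saw.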
+@dataclass
+class ContextInfo:
+    """Information extracted from a serialization context."""
+
+    activity: bool = False
+    workflow: bool = False
+    workflow_id: str = ""
+
+    @classmethod
+    def create(cls, context: Optional[SerializationContext]) -> ContextInfo:
+        """Create ContextInfo from a serialization context."""
+        if isinstance(context, ActivitySerializationContext):
+            return cls(activity=True, workflow_id=context.workflow_id)
+        if isinstance(context, WorkflowSerializationContext):
+            return cls(workflow=True, workflow_id=context.workflow_id)
+        return cls()
+
+
+@dataclass
+class ContextEvent:
+    """Records when context was used during (de)serialization."""
+
+    outbound: bool  # True for serialization, False for deserialization
+    info: ContextInfo
+
+
+@dataclass
+class ContextValue:
+    """Test value that accumulates context events as it is (de)serialized."""
+
+    name: str
+    events: List[ContextEvent] = dataclasses.field(default_factory=list)
+
+    def assert_workflow_equal(
+        self, expected_name: str, workflow_id: str, activity: bool = False
+    ) -> None:
+        """Assert this value round-tripped under the expected context."""
+        assert self.name == expected_name, (
+            f"Expected name {expected_name}, got {self.name}"
+        )
+
+        # Expect at least one serialization and one deserialization event
+        outbound_events = [e for e in self.events if e.outbound]
+        inbound_events = [e for e in self.events if not e.outbound]
+        assert outbound_events, "Should have at least one outbound event"
+        assert inbound_events, "Should have at least one inbound event"
+
+        # Every event must carry the expected context info
+        for event in self.events:
+            assert event.info.workflow_id == workflow_id
+            assert event.info.activity == activity
+
+
+class ContextJsonConverter(JSONPlainPayloadConverter):
+    """Test JSON converter that records context events on ContextValue."""
+
+    def __init__(self, context_info: Optional[ContextInfo] = None):
+        super().__init__()
+        self.context_info = context_info or ContextInfo()
+
+    def to_payload(self, value: Any) -> Optional[Payload]:
+        # Record the active context before serializing so the event is
+        # embedded in the payload itself
+        if isinstance(value, ContextValue):
+            value.events.append(ContextEvent(outbound=True, info=self.context_info))
+        return super().to_payload(value)
+
+    def from_payload(self, payload: Payload, type_hint: Optional[Type] = None) -> Any:
+        value = super().from_payload(payload, type_hint)
+        # Record the active context after deserializing
+        if isinstance(value, ContextValue):
+            value.events.append(ContextEvent(outbound=False, info=self.context_info))
+        return value
+
+
+class ContextPayloadConverter(CompositePayloadConverter, WithSerializationContext):
+    """Test payload converter that rebuilds itself with the given context."""
+
+    def __init__(self, context_info: Optional[ContextInfo] = None):
+        super().__init__(ContextJsonConverter(context_info))
+
+    def with_context(
+        self, context: Optional[SerializationContext]
+    ) -> ContextPayloadConverter:
+        """Return a copy configured with the given context."""
+        return ContextPayloadConverter(ContextInfo.create(context))
+
+
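+# Note: the with_context implementations here return a fresh instance rather
+# than mutating in place, so a context bound for one workflow or activity
+# cannot leak into payloads serialized for another.
+
+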
+class ContextFailureConverter(DefaultFailureConverter, WithSerializationContext):
+    """Test failure converter that appends context information to failure messages."""
+
+    def __init__(self, context_info: Optional[ContextInfo] = None):
+        super().__init__()
+        self.context_info = context_info or ContextInfo()
+
+    def to_failure(
+        self,
+        exception: BaseException,
+        payload_converter: PayloadConverter,
+        failure: Failure,
+    ) -> None:
+        super().to_failure(exception, payload_converter, failure)
+        # Tag application failures with the active context (use HasField;
+        # proto message fields are never None)
+        if failure.HasField("application_failure_info") and (
+            "[activity:" not in failure.message
+        ):
+            activity_str = "true" if self.context_info.activity else "false"
+            failure.message += (
+                f" [activity: {activity_str},"
+                f" workflow-id: {self.context_info.workflow_id}]"
+            )
+
+    def with_context(
+        self, context: Optional[SerializationContext]
+    ) -> ContextFailureConverter:
+        """Return a copy configured with the given context."""
+        return ContextFailureConverter(ContextInfo.create(context))
+
+
+class ContextPayloadCodec(PayloadCodec, WithSerializationContext):
+    """Test codec that stamps context into payload metadata and validates it."""
+
+    ENCODING_NAME = "context-encoding"
+
+    def __init__(self, context_info: Optional[ContextInfo] = None):
+        self.context_info = context_info or ContextInfo()
+
+    async def encode(self, payloads: Sequence[Payload]) -> List[Payload]:
+        """Encode payloads, recording the active context in metadata."""
+        encoded_payloads = []
+        for payload in payloads:
+            # Base64 the original serialized payload
+            encoded_data = base64.b64encode(payload.SerializeToString())
+            # Wrap it in a new payload carrying the context as metadata
+            encoded_payloads.append(
+                Payload(
+                    data=encoded_data,
+                    metadata={
+                        "encoding": self.ENCODING_NAME.encode(),
+                        "activity": str(self.context_info.activity).lower().encode(),
+                        "workflow-id": self.context_info.workflow_id.encode(),
+                    },
+                )
+            )
+        return encoded_payloads
+
+    async def decode(self, payloads: Sequence[Payload]) -> List[Payload]:
+        """Decode payloads, asserting metadata matches the active context."""
+        decoded_payloads = []
+        for payload in payloads:
+            assert payload.metadata.get("encoding", b"").decode() == self.ENCODING_NAME
+            assert (
+                payload.metadata.get("activity", b"").decode()
+                == str(self.context_info.activity).lower()
+            )
+            assert (
+                payload.metadata.get("workflow-id", b"").decode()
+                == self.context_info.workflow_id
+            )
+            # Restore the original payload
+            original_payload = Payload()
+            original_payload.ParseFromString(base64.b64decode(payload.data))
+            decoded_payloads.append(original_payload)
+        return decoded_payloads
+
+    def with_context(
+        self, context: Optional[SerializationContext]
+    ) -> ContextPayloadCodec:
+        """Return a copy configured with the given context."""
+        return ContextPayloadCodec(ContextInfo.create(context))
+
+
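+# Each codec-encoded payload wraps the original serialized payload in base64
+# and stamps the active context into metadata; decode asserts the metadata
+# matches the context under which it is decoded.
+
+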
value.assert_workflow_equal("update-input", workflow.info().workflow_id) + return ContextValue("update-result") + + @workflow.update + async def signal_external(self, workflow_id: str): + handle = workflow.get_external_workflow_handle(workflow_id) + await handle.signal(ConverterContextWorkflow.some_signal, ContextValue("signal-input")) + + @workflow.update + async def do_child(self): + # Start child + handle = await workflow.start_child_workflow( + ConverterContextWorkflow.run, + ContextValue("workflow-input"), + id=f"child-{uuid.uuid4()}" + ) + await handle.signal(ConverterContextWorkflow.some_signal, ContextValue("signal-input")) + await handle.signal(ConverterContextWorkflow.complete) + res = await handle + res.assert_workflow_equal("workflow-result", handle.id) + + @workflow.update + async def do_activity(self): + # Regular activity + res = await workflow.execute_activity( + some_activity, + ContextValue("activity-input"), + start_to_close_timeout=timedelta(seconds=30) + ) + res.assert_workflow_equal("activity-result", workflow.info().workflow_id, activity=True) + + # Local activity + res = await workflow.execute_local_activity( + some_activity, + ContextValue("activity-input"), + start_to_close_timeout=timedelta(seconds=30) + ) + res.assert_workflow_equal("activity-result", workflow.info().workflow_id, activity=True) + + @workflow.update + async def do_update_and_activity_failure(self): + try: + await workflow.execute_activity( + some_failing_activity, + start_to_close_timeout=timedelta(seconds=10) + ) + raise RuntimeError("Expected failure") + except ActivityError as e: + if e.cause and "Intentional activity failure" in str(e.cause): + raise ApplicationError( + "Intentional update failure", + ContextValue("update-failure"), + cause=e + ) + raise + + @workflow.update + async def do_async_activity_completion(self): + res = await workflow.execute_activity( + some_async_completing_activity, + start_to_close_timeout=timedelta(seconds=30) + ) + res.assert_workflow_equal("activity-async-result", workflow.info().workflow_id, activity=True) + + +@activity.defn +async def some_activity(value: ContextValue) -> ContextValue: + """Test activity that validates serialization context.""" + value.assert_workflow_equal( + "activity-input", + activity.info().workflow_id, + activity=True + ) + return ContextValue("activity-result") + + +@activity.defn +async def some_failing_activity(): + """Test activity that intentionally fails.""" + raise ApplicationError( + "Intentional activity failure", + ContextValue("activity-failure"), + non_retryable=True + ) + + +@activity.defn +async def some_async_completing_activity() -> ContextValue: + """Test activity that completes asynchronously.""" + # Get async completion handle + handle = activity.async_completion_handle() + + # Schedule completion in background + async def complete_later(): + await asyncio.sleep(0.3) + await handle.complete(ContextValue("activity-async-result")) + + asyncio.create_task(complete_later()) + activity.raise_complete_async() + +async def test_serialization_context(client: Client): + config = client.config() + config["data_converter"] = DataConverter( + payload_converter_class=ContextJsonConverter, payload_codec=ContextPayloadCodec(), failure_converter_class=ContextFailureConverter + ) + client = Client(**config) + + async with new_worker( + client, ConverterContextWorkflow, activities=[some_activity, some_failing_activity, some_async_completing_activity] + ) as worker: + handle = await client.start_workflow( + 
+async def test_serialization_context(client: Client):
+    config = client.config()
+    config["data_converter"] = DataConverter(
+        payload_converter_class=ContextPayloadConverter,
+        payload_codec=ContextPayloadCodec(),
+        failure_converter_class=ContextFailureConverter,
+    )
+    client = Client(**config)
+
+    async with new_worker(
+        client,
+        ConverterContextWorkflow,
+        activities=[
+            some_activity,
+            some_failing_activity,
+            some_async_completing_activity,
+        ],
+    ) as worker:
+        handle = await client.start_workflow(
+            ConverterContextWorkflow.run,
+            ContextValue("workflow-input"),
+            id=f"serialization-context-{uuid.uuid4()}",
+            task_queue=worker.task_queue,
+        )
+        await handle.execute_update(ConverterContextWorkflow.do_activity)
+        await handle.signal(ConverterContextWorkflow.complete)
+        result = await handle.result()
+        result.assert_workflow_equal("workflow-result", handle.id)