Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions temporalio/bridge/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,6 +384,7 @@ async def decode_activation(
codec: temporalio.converter.PayloadCodec,
decode_headers: bool,
) -> None:
print("Decoding activation")
"""Decode the given activation with the codec."""
for job in act.jobs:
if job.HasField("query_workflow"):
Expand Down Expand Up @@ -462,6 +463,7 @@ async def encode_completion(
codec: temporalio.converter.PayloadCodec,
encode_headers: bool,
) -> None:
print("Encoding completion")
"""Recursively encode the given completion with the codec."""
if comp.HasField("failed"):
await codec.encode_failure(comp.failed.failure)
Expand Down
98 changes: 98 additions & 0 deletions temporalio/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import uuid
import warnings
from abc import ABC, abstractmethod
from copy import copy
from dataclasses import dataclass
from datetime import datetime
from enum import IntEnum
Expand All @@ -28,11 +29,14 @@
Mapping,
NewType,
Optional,
Protocol,
Self,
Sequence,
Tuple,
Type,
TypeVar,
Union,
cast,
get_type_hints,
overload,
)
Expand Down Expand Up @@ -65,6 +69,74 @@
logger = getLogger(__name__)


class SerializationContext(ABC):
    """Base class for serialization context.

    Carries contextual information that payload converters, payload codecs,
    and failure converters may consult while serializing or deserializing
    values. Concrete subclasses (activity, workflow, etc.) supply the
    specific fields for each situation.
    """


@dataclass(frozen=True)
class ActivitySerializationContext(SerializationContext):
    """Serialization context for activities.

    Attributes:
        namespace: Namespace the workflow and activity run in.
        workflow_id: ID of the workflow that scheduled the activity.
        workflow_type: Type/name of the workflow that scheduled the activity.
        activity_type: The type/name of the activity.
        activity_task_queue: Task queue the activity is scheduled on, or
            None when no queue applies (callers pass None e.g. for local
            activities).
        is_local: Whether this is a local activity.
    """

    namespace: str
    workflow_id: str
    workflow_type: str
    activity_type: str
    activity_task_queue: Optional[str]
    is_local: bool


@dataclass(frozen=True)
class WorkflowSerializationContext(SerializationContext):
    """Serialization context for workflows.

    Attributes:
        namespace: The namespace the workflow is running in.
        workflow_id: The workflow ID.
    """

    namespace: str
    workflow_id: str


class WithSerializationContext(ABC):
    """Abstract base class for objects that can use serialization context.

    This mirrors the .NET IWithSerializationContext<T> interface. Objects
    implementing this interface can receive contextual information during
    serialization and deserialization. ``with_context`` is expected to
    return a new, context-bound copy rather than mutating the receiver.
    """

    @abstractmethod
    def with_context(self, context: Optional[SerializationContext]) -> Self:
        """Return a copy of this object configured to use the given context.

        Args:
            context: The serialization context to use, or None for no context.

        Returns:
            A new instance configured with the context.
        """
        raise NotImplementedError()


class PayloadConverter(ABC):
"""Base payload converter to/from multiple payloads/values."""

Expand Down Expand Up @@ -1206,6 +1278,32 @@ async def decode_failure(
await self.payload_codec.decode_failure(failure)
return self.failure_converter.from_failure(failure, self.payload_converter)

def _with_context(self, context: Optional[SerializationContext]) -> Self:
    """Return a copy of this converter with context applied to its components.

    Each of ``payload_converter``, ``payload_codec``, and
    ``failure_converter`` that implements ``WithSerializationContext`` is
    replaced by its context-bound copy; components that do not implement it
    are carried over unchanged.

    Args:
        context: The serialization context to apply, or None for no context.

    Returns:
        A new instance of the same type; ``self`` is not mutated.
    """

    def _bind(component):
        # Only context-aware components get rebound; others pass through as-is.
        return (
            component.with_context(context)
            if isinstance(component, WithSerializationContext)
            else component
        )

    # Shallow-copy self (rather than bare type(self).__new__) so any
    # additional instance attributes beyond the three replaced below are
    # preserved on the new instance.
    new_self = copy(self)
    # NOTE(review): setattr assumes the class does not forbid attribute
    # assignment (e.g. not a frozen dataclass) — same assumption as before.
    setattr(new_self, "payload_converter", _bind(self.payload_converter))
    setattr(new_self, "payload_codec", _bind(self.payload_codec))
    setattr(new_self, "failure_converter", _bind(self.failure_converter))
    return new_self


DefaultPayloadConverter.default_encoding_payload_converters = (
BinaryNullPayloadConverter(),
Expand Down
76 changes: 68 additions & 8 deletions temporalio/worker/_workflow_instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,11 @@
from temporalio.service import __version__

from ..api.failure.v1.message_pb2 import Failure
from ..converter import (
ActivitySerializationContext,
WithSerializationContext,
WorkflowSerializationContext,
)
from ._interceptor import (
ContinueAsNewInput,
ExecuteWorkflowInput,
Expand Down Expand Up @@ -208,6 +213,19 @@ def __init__(self, det: WorkflowInstanceDetails) -> None:
WorkflowInstance.__init__(self)
temporalio.workflow._Runtime.__init__(self)
self._payload_converter = det.payload_converter_class()

# Apply serialization context to payload converter
self._payload_converter = (
self._payload_converter.with_context(
WorkflowSerializationContext(
namespace=det.info.namespace,
workflow_id=det.info.workflow_id,
)
)
if isinstance(self._payload_converter, WithSerializationContext)
else self._payload_converter
)

self._failure_converter = det.failure_converter_class()
self._defn = det.defn
self._workflow_input: Optional[ExecuteWorkflowInput] = None
Expand Down Expand Up @@ -1017,6 +1035,7 @@ def _apply_update_random_seed(
def _make_workflow_input(
self, init_job: temporalio.bridge.proto.workflow_activation.InitializeWorkflow
) -> ExecuteWorkflowInput:
print("Making workflow input")
# Set arg types, using raw values for dynamic
arg_types = self._defn.arg_types
if not self._defn.name:
Expand Down Expand Up @@ -1987,6 +2006,7 @@ def _convert_payloads(
if types and len(types) != len(payloads):
types = None
try:
print(f"Converting payloads with {self._payload_converter}.")
return self._payload_converter.from_payloads(
payloads,
type_hints=types,
Expand Down Expand Up @@ -2769,9 +2789,27 @@ def _apply_schedule_command(
temporalio.bridge.proto.activity_result.DoBackoff
] = None,
) -> None:
# Set up serialization context
payload_converter = (
self._instance._payload_converter.with_context(
ActivitySerializationContext(
namespace=self._instance.workflow_info().namespace,
workflow_id=self._instance.workflow_info().workflow_id,
workflow_type=self._instance.workflow_info().workflow_type,
activity_type=self._input.activity,
activity_task_queue=self._input.task_queue
if isinstance(self._input, StartActivityInput)
else None,
is_local=isinstance(self._input, StartLocalActivityInput),
)
)
if isinstance(self._instance._payload_converter, WithSerializationContext)
else self._instance._payload_converter
)

# Convert arguments before creating command in case it raises error
payloads = (
self._instance._payload_converter.to_payloads(self._input.args)
payload_converter.to_payloads(self._input.args)
if self._input.args
else None
)
Expand Down Expand Up @@ -2807,7 +2845,7 @@ def _apply_schedule_command(
self._input.retry_policy.apply_to_proto(v.retry_policy)
if self._input.summary:
command.user_metadata.summary.CopyFrom(
self._instance._payload_converter.to_payload(self._input.summary)
payload_converter.to_payload(self._input.summary)
)
v.cancellation_type = cast(
temporalio.bridge.proto.workflow_commands.ActivityCancellationType.ValueType,
Expand Down Expand Up @@ -2919,9 +2957,21 @@ def _resolve_failure(self, err: BaseException) -> None:
self._result_fut.set_result(None)

def _apply_start_command(self) -> None:
# Set up serialization context
payload_converter = (
self._instance._payload_converter.with_context(
WorkflowSerializationContext(
namespace=self._instance.workflow_info().namespace,
workflow_id=self._instance.workflow_info().workflow_id,
)
)
if isinstance(self._instance._payload_converter, WithSerializationContext)
else self._instance._payload_converter
)

# Convert arguments before creating command in case it raises error
payloads = (
self._instance._payload_converter.to_payloads(self._input.args)
payload_converter.to_payloads(self._input.args)
if self._input.args
else None
)
Expand Down Expand Up @@ -2956,9 +3006,7 @@ def _apply_start_command(self) -> None:
temporalio.common._apply_headers(self._input.headers, v.headers)
if self._input.memo:
for k, val in self._input.memo.items():
v.memo[k].CopyFrom(
self._instance._payload_converter.to_payloads([val])[0]
)
v.memo[k].CopyFrom(payload_converter.to_payloads([val])[0])
if self._input.search_attributes:
_encode_search_attributes(
self._input.search_attributes, v.search_attributes
Expand Down Expand Up @@ -3126,15 +3174,27 @@ def __init__(
self._input = input

def _apply_command(self) -> None:
# Set up serialization context
payload_converter = (
self._instance._payload_converter.with_context(
WorkflowSerializationContext(
namespace=self._instance.workflow_info().namespace,
workflow_id=self._instance.workflow_info().workflow_id,
)
)
if isinstance(self._instance._payload_converter, WithSerializationContext)
else self._instance._payload_converter
)

# Convert arguments before creating command in case it raises error
payloads = (
self._instance._payload_converter.to_payloads(self._input.args)
payload_converter.to_payloads(self._input.args)
if self._input.args
else None
)
memo_payloads = (
{
k: self._instance._payload_converter.to_payloads([val])[0]
k: payload_converter.to_payloads([val])[0]
for k, val in self._input.memo.items()
}
if self._input.memo
Expand Down
Loading
Loading