Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 28 additions & 29 deletions temporalio/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,51 +93,50 @@ class SerializationContext(ABC):


@dataclass(frozen=True)
class BaseWorkflowSerializationContext(SerializationContext):
"""Base serialization context shared by workflow and activity serialization contexts."""

namespace: str
workflow_id: str


@dataclass(frozen=True)
class WorkflowSerializationContext(BaseWorkflowSerializationContext):
class WorkflowSerializationContext(SerializationContext):
"""Serialization context for workflows.

See :py:class:`SerializationContext` for more details.

Attributes:
namespace: The namespace the workflow is running in.
workflow_id: The ID of the workflow. Note that this is the ID of the workflow of which the
payload being operated on is an input or output. Note also that when creating/describing
schedules, this may be the workflow ID prefix as configured, not the final workflow ID
when the workflow is created by the schedule.
"""

pass
namespace: str
"""Namespace used by the worker executing the workflow."""

workflow_id: Optional[str]
"""Workflow ID.

Note that this is the ID of the workflow of which the payload being operated on is an input or
output. When creating/describing schedules, this may be the workflow ID prefix as configured,
not the final workflow ID when the workflow is created by the schedule."""


@dataclass(frozen=True)
class ActivitySerializationContext(BaseWorkflowSerializationContext):
class ActivitySerializationContext(SerializationContext):
"""Serialization context for activities.

See :py:class:`SerializationContext` for more details.

Attributes:
namespace: Workflow/activity namespace.
workflow_id: Workflow ID. Note, when creating/describing schedules,
this may be the workflow ID prefix as configured, not the final workflow ID when the
workflow is created by the schedule.
workflow_type: Workflow Type.
activity_type: Activity Type.
activity_task_queue: Activity task queue.
is_local: Whether the activity is a local activity.
"""

workflow_type: str
namespace: str
"""Namespace used by the worker executing the activity."""

activity_id: Optional[str]
"""Activity ID."""

activity_type: str
"""Activity type."""

activity_task_queue: str
"""Activity task queue."""

workflow_id: Optional[str]
"""Workflow ID."""

workflow_type: Optional[str]
"""Workflow type."""

is_local: bool
"""Whether the activity is a local activity."""


class WithSerializationContext(ABC):
Expand Down
10 changes: 6 additions & 4 deletions temporalio/worker/_activity.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,11 +256,12 @@ async def _heartbeat_async(
if activity.info:
context = temporalio.converter.ActivitySerializationContext(
namespace=activity.info.workflow_namespace,
workflow_id=activity.info.workflow_id,
workflow_type=activity.info.workflow_type,
activity_id=activity.info.activity_id,
activity_type=activity.info.activity_type,
activity_task_queue=self._task_queue,
is_local=activity.info.is_local,
workflow_id=activity.info.workflow_id,
workflow_type=activity.info.workflow_type,
)
data_converter = data_converter.with_context(context)

Expand Down Expand Up @@ -308,11 +309,12 @@ async def _handle_start_activity_task(
# Create serialization context for the activity
context = temporalio.converter.ActivitySerializationContext(
namespace=start.workflow_namespace,
workflow_id=start.workflow_execution.workflow_id,
workflow_type=start.workflow_type,
activity_id=start.activity_id,
activity_type=start.activity_type,
activity_task_queue=self._task_queue,
is_local=start.is_local,
workflow_id=start.workflow_execution.workflow_id,
workflow_type=start.workflow_type,
)
data_converter = self._data_converter.with_context(context)
try:
Expand Down
15 changes: 9 additions & 6 deletions temporalio/worker/_workflow_instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -788,15 +788,16 @@ def _apply_resolve_activity(
raise RuntimeError(f"Failed finding activity handle for sequence {job.seq}")
activity_context = temporalio.converter.ActivitySerializationContext(
namespace=self._info.namespace,
workflow_id=self._info.workflow_id,
workflow_type=self._info.workflow_type,
activity_id=handle._input.activity_id,
activity_type=handle._input.activity,
activity_task_queue=(
handle._input.task_queue or self._info.task_queue
if isinstance(handle._input, StartActivityInput)
else self._info.task_queue
),
is_local=isinstance(handle._input, StartLocalActivityInput),
workflow_id=self._info.workflow_id,
workflow_type=self._info.workflow_type,
)
payload_converter = self._payload_converter_with_context(activity_context)
failure_converter = self._failure_converter_with_context(activity_context)
Expand Down Expand Up @@ -2127,8 +2128,7 @@ def get_serialization_context(
activity_handle = self._pending_activities[command_info.command_seq]
return temporalio.converter.ActivitySerializationContext(
namespace=self._info.namespace,
workflow_id=self._info.workflow_id,
workflow_type=self._info.workflow_type,
activity_id=activity_handle._input.activity_id,
activity_type=activity_handle._input.activity,
activity_task_queue=(
activity_handle._input.task_queue
Expand All @@ -2137,6 +2137,8 @@ def get_serialization_context(
else self._info.task_queue
),
is_local=isinstance(activity_handle._input, StartLocalActivityInput),
workflow_id=self._info.workflow_id,
workflow_type=self._info.workflow_type,
)

elif (
Expand Down Expand Up @@ -2921,15 +2923,16 @@ def __init__(
self._payload_converter = self._instance._payload_converter_with_context(
temporalio.converter.ActivitySerializationContext(
namespace=self._instance._info.namespace,
workflow_id=self._instance._info.workflow_id,
workflow_type=self._instance._info.workflow_type,
activity_id=self._input.activity_id,
activity_type=self._input.activity,
activity_task_queue=(
self._input.task_queue or self._instance._info.task_queue
if isinstance(self._input, StartActivityInput)
else self._instance._info.task_queue
),
is_local=isinstance(self._input, StartLocalActivityInput),
workflow_id=self._instance._info.workflow_id,
workflow_type=self._instance._info.workflow_type,
)
)

Expand Down
24 changes: 19 additions & 5 deletions tests/test_serialization_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,7 @@ async def run(self, data: TraceData) -> TraceData:
data,
start_to_close_timeout=timedelta(seconds=10),
heartbeat_timeout=timedelta(seconds=2),
activity_id="activity-id",
)
data = await workflow.execute_child_workflow(
EchoWorkflow.run, data, id=f"{workflow.info().workflow_id}_child"
Expand Down Expand Up @@ -231,6 +232,7 @@ async def test_payload_conversion_calls_follow_expected_sequence_and_contexts(
workflow_id=workflow_id,
workflow_type=PayloadConversionWorkflow.__name__,
activity_type=passthrough_activity.__name__,
activity_id="activity-id",
activity_task_queue=task_queue,
is_local=False,
)
Expand Down Expand Up @@ -328,6 +330,7 @@ async def run(self) -> TraceData:
initial_interval=timedelta(milliseconds=100),
maximum_attempts=2,
),
activity_id="activity-id",
)


Expand Down Expand Up @@ -370,6 +373,7 @@ async def test_heartbeat_details_payload_conversion(client: Client):
workflow_id=workflow_id,
workflow_type=HeartbeatDetailsSerializationContextTestWorkflow.__name__,
activity_type=activity_with_heartbeat_details.__name__,
activity_id="activity-id",
activity_task_queue=task_queue,
is_local=False,
)
Expand Down Expand Up @@ -419,6 +423,7 @@ async def run(self, data: TraceData) -> TraceData:
local_activity,
data,
start_to_close_timeout=timedelta(seconds=10),
activity_id="activity-id",
)


Expand Down Expand Up @@ -459,6 +464,7 @@ async def test_local_activity_payload_conversion(client: Client):
workflow_id=workflow_id,
workflow_type=LocalActivityWorkflow.__name__,
activity_type=local_activity.__name__,
activity_id="activity-id",
activity_task_queue=task_queue,
is_local=True,
)
Expand Down Expand Up @@ -504,7 +510,7 @@ async def test_local_activity_payload_conversion(client: Client):


@workflow.defn
class EventWorkflow:
class WaitForSignalWorkflow:
# Like a global asyncio.Event()

def __init__(self) -> None:
Expand All @@ -521,10 +527,11 @@ def signal(self) -> None:

@activity.defn
async def async_activity() -> TraceData:
# Notify test that the activity has started and is ready to be completed manually
await (
activity.client()
.get_workflow_handle("activity-started-wf-id")
.signal(EventWorkflow.signal)
.signal(WaitForSignalWorkflow.signal)
)
activity.raise_complete_async()

Expand Down Expand Up @@ -558,7 +565,7 @@ async def test_async_activity_completion_payload_conversion(
task_queue=task_queue,
workflows=[
AsyncActivityCompletionSerializationContextTestWorkflow,
EventWorkflow,
WaitForSignalWorkflow,
],
activities=[async_activity],
workflow_runner=UnsandboxedWorkflowRunner(), # so that we can use isinstance
Expand All @@ -572,12 +579,13 @@ async def test_async_activity_completion_payload_conversion(
workflow_id=workflow_id,
workflow_type=AsyncActivityCompletionSerializationContextTestWorkflow.__name__,
activity_type=async_activity.__name__,
activity_id="async-activity-id",
activity_task_queue=task_queue,
is_local=False,
)

act_started_wf_handle = await client.start_workflow(
EventWorkflow.run,
WaitForSignalWorkflow.run,
id="activity-started-wf-id",
task_queue=task_queue,
)
Expand Down Expand Up @@ -644,6 +652,7 @@ def test_subclassed_async_activity_handle(client: Client):
workflow_id="workflow-id",
workflow_type="workflow-type",
activity_type="activity-type",
activity_id="activity-id",
activity_task_queue="activity-task-queue",
is_local=False,
)
Expand Down Expand Up @@ -1058,11 +1067,12 @@ async def run(self) -> Never:
failing_activity,
start_to_close_timeout=timedelta(seconds=10),
retry_policy=RetryPolicy(maximum_attempts=1),
activity_id="activity-id",
)
raise Exception("Unreachable")


test_traces: dict[str, list[TraceItem]] = defaultdict(list)
test_traces: dict[Optional[str], list[TraceItem]] = defaultdict(list)


class FailureConverterWithContext(DefaultFailureConverter, WithSerializationContext):
Expand Down Expand Up @@ -1154,6 +1164,7 @@ async def test_failure_converter_with_context(client: Client):
workflow_id=workflow_id,
workflow_type=FailureConverterTestWorkflow.__name__,
activity_type=failing_activity.__name__,
activity_id="activity-id",
activity_task_queue=task_queue,
is_local=False,
)
Expand Down Expand Up @@ -1322,6 +1333,7 @@ async def run(self, data: str) -> str:
codec_test_local_activity,
data,
start_to_close_timeout=timedelta(seconds=10),
activity_id="activity-id",
)


Expand Down Expand Up @@ -1360,6 +1372,7 @@ async def test_local_activity_codec_with_context(client: Client):
workflow_id=workflow_id,
workflow_type=LocalActivityCodecTestWorkflow.__name__,
activity_type=codec_test_local_activity.__name__,
activity_id="activity-id",
activity_task_queue=task_queue,
is_local=True,
)
Expand Down Expand Up @@ -1593,6 +1606,7 @@ async def run(self, data: str) -> str:
payload_encryption_activity,
"outbound",
start_to_close_timeout=timedelta(seconds=10),
activity_id="activity-id",
),
workflow.execute_child_workflow(
PayloadEncryptionChildWorkflow.run,
Expand Down
Loading