Skip to content

Commit fb0784c

Browse files
committed
Very early WIP
1 parent 9372d47 commit fb0784c

File tree

4 files changed

+502
-8
lines changed

4 files changed

+502
-8
lines changed

temporalio/bridge/worker.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -384,6 +384,7 @@ async def decode_activation(
384384
codec: temporalio.converter.PayloadCodec,
385385
decode_headers: bool,
386386
) -> None:
387+
print("Decoding activation")
387388
"""Decode the given activation with the codec."""
388389
for job in act.jobs:
389390
if job.HasField("query_workflow"):
@@ -462,6 +463,7 @@ async def encode_completion(
462463
codec: temporalio.converter.PayloadCodec,
463464
encode_headers: bool,
464465
) -> None:
466+
print("Encoding completion")
465467
"""Recursively encode the given completion with the codec."""
466468
if comp.HasField("failed"):
467469
await codec.encode_failure(comp.failed.failure)

temporalio/converter.py

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import uuid
1313
import warnings
1414
from abc import ABC, abstractmethod
15+
from copy import copy
1516
from dataclasses import dataclass
1617
from datetime import datetime
1718
from enum import IntEnum
@@ -28,11 +29,14 @@
2829
Mapping,
2930
NewType,
3031
Optional,
32+
Protocol,
33+
Self,
3134
Sequence,
3235
Tuple,
3336
Type,
3437
TypeVar,
3538
Union,
39+
cast,
3640
get_type_hints,
3741
overload,
3842
)
@@ -65,6 +69,74 @@
6569
logger = getLogger(__name__)
6670

6771

72+
class SerializationContext(ABC):
73+
"""Base serialization context.
74+
75+
This provides contextual information during serialization and deserialization
76+
operations. Different contexts (activity, workflow, etc.) can provide
77+
specialized information.
78+
"""
79+
80+
pass
81+
82+
83+
@dataclass(frozen=True)
84+
class ActivitySerializationContext(SerializationContext):
85+
"""Serialization context for activities.
86+
87+
Attributes:
88+
activity_id: The ID of the activity.
89+
activity_type: The type/name of the activity.
90+
attempt: The current attempt number (starting from 1).
91+
is_local: Whether this is a local activity.
92+
"""
93+
94+
namespace: str
95+
workflow_id: str
96+
workflow_type: str
97+
activity_type: str
98+
activity_task_queue: Optional[str]
99+
is_local: bool
100+
101+
102+
@dataclass(frozen=True)
103+
class WorkflowSerializationContext(SerializationContext):
104+
"""Serialization context for workflows.
105+
106+
Attributes:
107+
workflow_id: The workflow ID.
108+
run_id: The workflow run ID.
109+
workflow_type: The type/name of the workflow.
110+
task_queue: The task queue the workflow is running on.
111+
namespace: The namespace the workflow is running in.
112+
attempt: The current workflow task attempt number (starting from 1).
113+
"""
114+
115+
namespace: str
116+
workflow_id: str
117+
118+
119+
class WithSerializationContext(ABC):
120+
"""Protocol for objects that can use serialization context.
121+
122+
This is similar to the .NET IWithSerializationContext<T> interface.
123+
Objects implementing this protocol can receive contextual information
124+
during serialization and deserialization.
125+
"""
126+
127+
@abstractmethod
128+
def with_context(self, context: Optional[SerializationContext]) -> Self:
129+
"""Return a copy of this object configured to use the given context.
130+
131+
Args:
132+
context: The serialization context to use, or None for no context.
133+
134+
Returns:
135+
A new instance configured with the context.
136+
"""
137+
raise NotImplementedError()
138+
139+
68140
class PayloadConverter(ABC):
69141
"""Base payload converter to/from multiple payloads/values."""
70142

@@ -1206,6 +1278,32 @@ async def decode_failure(
12061278
await self.payload_codec.decode_failure(failure)
12071279
return self.failure_converter.from_failure(failure, self.payload_converter)
12081280

1281+
def _with_context(self, context: Optional[SerializationContext]) -> Self:
1282+
new_self = type(self).__new__(type(self))
1283+
setattr(
1284+
new_self,
1285+
"payload_converter",
1286+
self.payload_converter.with_context(context)
1287+
if isinstance(self.payload_converter, WithSerializationContext)
1288+
else self.payload_converter,
1289+
)
1290+
codec = self.payload_codec
1291+
setattr(
1292+
new_self,
1293+
"payload_codec",
1294+
cast(WithSerializationContext, codec).with_context(context)
1295+
if isinstance(codec, WithSerializationContext)
1296+
else codec,
1297+
)
1298+
setattr(
1299+
new_self,
1300+
"failure_converter",
1301+
self.failure_converter.with_context(context)
1302+
if isinstance(self.failure_converter, WithSerializationContext)
1303+
else self.failure_converter,
1304+
)
1305+
return new_self
1306+
12091307

12101308
DefaultPayloadConverter.default_encoding_payload_converters = (
12111309
BinaryNullPayloadConverter(),

temporalio/worker/_workflow_instance.py

Lines changed: 68 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,11 @@
6565
from temporalio.service import __version__
6666

6767
from ..api.failure.v1.message_pb2 import Failure
68+
from ..converter import (
69+
ActivitySerializationContext,
70+
WithSerializationContext,
71+
WorkflowSerializationContext,
72+
)
6873
from ._interceptor import (
6974
ContinueAsNewInput,
7075
ExecuteWorkflowInput,
@@ -208,6 +213,19 @@ def __init__(self, det: WorkflowInstanceDetails) -> None:
208213
WorkflowInstance.__init__(self)
209214
temporalio.workflow._Runtime.__init__(self)
210215
self._payload_converter = det.payload_converter_class()
216+
217+
# Apply serialization context to payload converter
218+
self._payload_converter = (
219+
self._payload_converter.with_context(
220+
WorkflowSerializationContext(
221+
namespace=det.info.namespace,
222+
workflow_id=det.info.workflow_id,
223+
)
224+
)
225+
if isinstance(self._payload_converter, WithSerializationContext)
226+
else self._payload_converter
227+
)
228+
211229
self._failure_converter = det.failure_converter_class()
212230
self._defn = det.defn
213231
self._workflow_input: Optional[ExecuteWorkflowInput] = None
@@ -1017,6 +1035,7 @@ def _apply_update_random_seed(
10171035
def _make_workflow_input(
10181036
self, init_job: temporalio.bridge.proto.workflow_activation.InitializeWorkflow
10191037
) -> ExecuteWorkflowInput:
1038+
print("Making workflow input")
10201039
# Set arg types, using raw values for dynamic
10211040
arg_types = self._defn.arg_types
10221041
if not self._defn.name:
@@ -1987,6 +2006,7 @@ def _convert_payloads(
19872006
if types and len(types) != len(payloads):
19882007
types = None
19892008
try:
2009+
print(f"Converting payloads with {self._payload_converter}.")
19902010
return self._payload_converter.from_payloads(
19912011
payloads,
19922012
type_hints=types,
@@ -2769,9 +2789,27 @@ def _apply_schedule_command(
27692789
temporalio.bridge.proto.activity_result.DoBackoff
27702790
] = None,
27712791
) -> None:
2792+
# Set up serialization context
2793+
payload_converter = (
2794+
self._instance._payload_converter.with_context(
2795+
ActivitySerializationContext(
2796+
namespace=self._instance.workflow_info().namespace,
2797+
workflow_id=self._instance.workflow_info().workflow_id,
2798+
workflow_type=self._instance.workflow_info().workflow_type,
2799+
activity_type=self._input.activity,
2800+
activity_task_queue=self._input.task_queue
2801+
if isinstance(self._input, StartActivityInput)
2802+
else None,
2803+
is_local=isinstance(self._input, StartLocalActivityInput),
2804+
)
2805+
)
2806+
if isinstance(self._instance._payload_converter, WithSerializationContext)
2807+
else self._instance._payload_converter
2808+
)
2809+
27722810
# Convert arguments before creating command in case it raises error
27732811
payloads = (
2774-
self._instance._payload_converter.to_payloads(self._input.args)
2812+
payload_converter.to_payloads(self._input.args)
27752813
if self._input.args
27762814
else None
27772815
)
@@ -2807,7 +2845,7 @@ def _apply_schedule_command(
28072845
self._input.retry_policy.apply_to_proto(v.retry_policy)
28082846
if self._input.summary:
28092847
command.user_metadata.summary.CopyFrom(
2810-
self._instance._payload_converter.to_payload(self._input.summary)
2848+
payload_converter.to_payload(self._input.summary)
28112849
)
28122850
v.cancellation_type = cast(
28132851
temporalio.bridge.proto.workflow_commands.ActivityCancellationType.ValueType,
@@ -2919,9 +2957,21 @@ def _resolve_failure(self, err: BaseException) -> None:
29192957
self._result_fut.set_result(None)
29202958

29212959
def _apply_start_command(self) -> None:
2960+
# Set up serialization context
2961+
payload_converter = (
2962+
self._instance._payload_converter.with_context(
2963+
WorkflowSerializationContext(
2964+
namespace=self._instance.workflow_info().namespace,
2965+
workflow_id=self._instance.workflow_info().workflow_id,
2966+
)
2967+
)
2968+
if isinstance(self._instance._payload_converter, WithSerializationContext)
2969+
else self._instance._payload_converter
2970+
)
2971+
29222972
# Convert arguments before creating command in case it raises error
29232973
payloads = (
2924-
self._instance._payload_converter.to_payloads(self._input.args)
2974+
payload_converter.to_payloads(self._input.args)
29252975
if self._input.args
29262976
else None
29272977
)
@@ -2956,9 +3006,7 @@ def _apply_start_command(self) -> None:
29563006
temporalio.common._apply_headers(self._input.headers, v.headers)
29573007
if self._input.memo:
29583008
for k, val in self._input.memo.items():
2959-
v.memo[k].CopyFrom(
2960-
self._instance._payload_converter.to_payloads([val])[0]
2961-
)
3009+
v.memo[k].CopyFrom(payload_converter.to_payloads([val])[0])
29623010
if self._input.search_attributes:
29633011
_encode_search_attributes(
29643012
self._input.search_attributes, v.search_attributes
@@ -3126,15 +3174,27 @@ def __init__(
31263174
self._input = input
31273175

31283176
def _apply_command(self) -> None:
3177+
# Set up serialization context
3178+
payload_converter = (
3179+
self._instance._payload_converter.with_context(
3180+
WorkflowSerializationContext(
3181+
namespace=self._instance.workflow_info().namespace,
3182+
workflow_id=self._instance.workflow_info().workflow_id,
3183+
)
3184+
)
3185+
if isinstance(self._instance._payload_converter, WithSerializationContext)
3186+
else self._instance._payload_converter
3187+
)
3188+
31293189
# Convert arguments before creating command in case it raises error
31303190
payloads = (
3131-
self._instance._payload_converter.to_payloads(self._input.args)
3191+
payload_converter.to_payloads(self._input.args)
31323192
if self._input.args
31333193
else None
31343194
)
31353195
memo_payloads = (
31363196
{
3137-
k: self._instance._payload_converter.to_payloads([val])[0]
3197+
k: payload_converter.to_payloads([val])[0]
31383198
for k, val in self._input.memo.items()
31393199
}
31403200
if self._input.memo

0 commit comments

Comments
 (0)