Skip to content

Commit cbfd2f3

Browse files
committed
Make it work
1 parent ba267bb commit cbfd2f3

File tree

4 files changed

+58
-42
lines changed

4 files changed

+58
-42
lines changed

temporalio/client.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5930,8 +5930,13 @@ async def _populate_start_workflow_execution_request(
59305930
req.workflow_type.name = input.workflow
59315931
req.task_queue.name = input.task_queue
59325932
if input.args:
5933+
context = temporalio.converter.WorkflowSerializationContext(
5934+
namespace=self._client.namespace, workflow_id=input.id
5935+
)
59335936
req.input.payloads.extend(
5934-
await self._client.data_converter.encode(input.args)
5937+
await self._client.data_converter._with_context(context).encode(
5938+
input.args
5939+
)
59355940
)
59365941
if input.execution_timeout is not None:
59375942
req.workflow_execution_timeout.FromTimedelta(input.execution_timeout)

temporalio/converter.py

Lines changed: 17 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1311,35 +1311,26 @@ async def decode_failure(
13111311
return self.failure_converter.from_failure(failure, self.payload_converter)
13121312

13131313
def _with_context(self, context: Optional[SerializationContext]) -> Self:
1314-
new_self = type(self).__new__(type(self))
1315-
setattr(
1316-
new_self,
1317-
"payload_converter",
1318-
(
1319-
self.payload_converter.with_context(context)
1320-
if isinstance(self.payload_converter, WithSerializationContext)
1321-
else self.payload_converter
1322-
),
1314+
payload_converter = (
1315+
self.payload_converter.with_context(context)
1316+
if isinstance(self.payload_converter, WithSerializationContext)
1317+
else self.payload_converter
13231318
)
1324-
setattr(
1325-
new_self,
1326-
"payload_codec",
1327-
(
1328-
self.payload_codec.with_context(context)
1329-
if isinstance(self.payload_codec, WithSerializationContext)
1330-
else self.payload_codec
1331-
),
1319+
payload_codec = (
1320+
self.payload_codec.with_context(context)
1321+
if isinstance(self.payload_codec, WithSerializationContext)
1322+
else self.payload_codec
13321323
)
1333-
setattr(
1334-
new_self,
1335-
"failure_converter",
1336-
(
1337-
self.failure_converter.with_context(context)
1338-
if isinstance(self.failure_converter, WithSerializationContext)
1339-
else self.failure_converter
1340-
),
1324+
failure_converter = (
1325+
self.failure_converter.with_context(context)
1326+
if isinstance(self.failure_converter, WithSerializationContext)
1327+
else self.failure_converter
13411328
)
1342-
return new_self
1329+
cloned = dataclasses.replace(self)
1330+
object.__setattr__(cloned, "payload_converter", payload_converter)
1331+
object.__setattr__(cloned, "payload_codec", payload_codec)
1332+
object.__setattr__(cloned, "failure_converter", failure_converter)
1333+
return cloned
13431334

13441335

13451336
DefaultPayloadConverter.default_encoding_payload_converters = (

temporalio/worker/_workflow_instance.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -984,7 +984,14 @@ def _apply_initialize_workflow(
984984
async def run_workflow(input: ExecuteWorkflowInput) -> None:
985985
try:
986986
result = await self._inbound.execute_workflow(input)
987-
result_payloads = self._payload_converter.to_payloads([result])
987+
converter = self._payload_converter
988+
if isinstance(converter, temporalio.converter.WithSerializationContext):
989+
context = temporalio.converter.WorkflowSerializationContext(
990+
namespace=self._info.namespace,
991+
workflow_id=self._info.workflow_id,
992+
)
993+
converter = converter.with_context(context)
994+
result_payloads = converter.to_payloads([result])
988995
if len(result_payloads) != 1:
989996
raise ValueError(
990997
f"Expected 1 result payload, got {len(result_payloads)}"

tests/test_serialization_context.py

Lines changed: 27 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import dataclasses
44
import uuid
5-
from dataclasses import dataclass
5+
from dataclasses import dataclass, field
66
from typing import Any, Optional, Type
77

88
from temporalio import workflow
@@ -21,14 +21,22 @@
2121

2222

2323
@dataclass
24-
class WorkflowData:
25-
workflow_context: Optional[WorkflowSerializationContext] = None
24+
class PayloadConverterTraceData:
25+
to_payload: Optional[WorkflowSerializationContext] = None
26+
from_payload: Optional[WorkflowSerializationContext] = None
27+
28+
29+
@dataclass
30+
class TraceData:
31+
workflow_context: PayloadConverterTraceData = field(
32+
default_factory=PayloadConverterTraceData
33+
)
2634

2735

2836
@workflow.defn
2937
class SerializationContextTestWorkflow:
3038
@workflow.run
31-
async def run(self, input: WorkflowData) -> WorkflowData:
39+
async def run(self, input: TraceData) -> TraceData:
3240
return input
3341

3442

@@ -51,9 +59,7 @@ def with_context(
5159
return SerializationContextTestEncodingPayloadConverter(context)
5260

5361
def to_payload(self, value: Any) -> Optional[Payload]:
54-
assert isinstance(value, WorkflowData)
55-
assert isinstance(self.context, WorkflowSerializationContext)
56-
value.workflow_context = self.context
62+
value.workflow_context.to_payload = self.context
5763
return None
5864

5965
def from_payload(self, payload: Payload, type_hint: Optional[Type] = None) -> Any:
@@ -62,11 +68,14 @@ def from_payload(self, payload: Payload, type_hint: Optional[Type] = None) -> An
6268

6369

6470
class SerializationContextTestPayloadConverter(CompositePayloadConverter):
65-
def __init__(self):
66-
super().__init__(
67-
SerializationContextTestEncodingPayloadConverter(None),
68-
*DefaultPayloadConverter.default_encoding_payload_converters,
69-
)
71+
def __init__(self, *converters):
72+
# TODO: we cannot expect users to do this
73+
if not converters:
74+
converters = (
75+
SerializationContextTestEncodingPayloadConverter(None),
76+
*DefaultPayloadConverter.default_encoding_payload_converters,
77+
)
78+
super().__init__(*converters)
7079

7180

7281
data_converter = dataclasses.replace(
@@ -81,6 +90,10 @@ async def test_workflow_payload_conversion_can_be_given_access_to_serialization_
8190
workflow_id = str(uuid.uuid4())
8291
task_queue = str(uuid.uuid4())
8392

93+
config = client.config()
94+
config["data_converter"] = data_converter
95+
client = Client(**config)
96+
8497
async with Worker(
8598
client,
8699
task_queue=task_queue,
@@ -89,12 +102,12 @@ async def test_workflow_payload_conversion_can_be_given_access_to_serialization_
89102
):
90103
result = await client.execute_workflow(
91104
SerializationContextTestWorkflow.run,
92-
WorkflowData(),
105+
TraceData(),
93106
id=workflow_id,
94107
task_queue=task_queue,
95108
)
96109

97-
assert result.workflow_context == WorkflowSerializationContext(
110+
assert result.workflow_context.to_payload == WorkflowSerializationContext(
98111
namespace="default",
99112
workflow_id=workflow_id,
100113
)

0 commit comments

Comments
 (0)