Skip to content

Commit 8af80df

Browse files
committed
Refactor
1 parent cdce3b8 commit 8af80df

File tree

1 file changed

+22
-14
lines changed

1 file changed

+22
-14
lines changed

tests/test_serialization_context.py

Lines changed: 22 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,13 @@
33
import dataclasses
44
import uuid
55
from dataclasses import dataclass, field
6-
from typing import Any, Optional, Type
6+
from typing import Any, Literal, Optional, Type
77

88
from temporalio import workflow
99
from temporalio.api.common.v1 import Payload
1010
from temporalio.client import Client
1111
from temporalio.converter import (
12+
ActivitySerializationContext,
1213
CompositePayloadConverter,
1314
DataConverter,
1415
DefaultPayloadConverter,
@@ -22,20 +23,19 @@
2223

2324

2425
@dataclass
25-
class PayloadConverterTraceData:
26-
to_payload: list[WorkflowSerializationContext] = field(default_factory=list)
27-
from_payload: list[WorkflowSerializationContext] = field(default_factory=list)
26+
class TraceItem:
27+
context_type: Literal["workflow", "activity"]
28+
method: Literal["to_payload", "from_payload"]
29+
context: WorkflowSerializationContext | ActivitySerializationContext
2830

2931

3032
@dataclass
3133
class TraceData:
32-
workflow_context: PayloadConverterTraceData = field(
33-
default_factory=PayloadConverterTraceData
34-
)
34+
items: list[TraceItem] = field(default_factory=list)
3535

3636

3737
@workflow.defn(sandboxed=False) # we want to use isinstance
38-
class SerializationContextTestWorkflow:
38+
class PassThroughWorkflow:
3939
@workflow.run
4040
async def run(self, input: TraceData) -> TraceData:
4141
return input
@@ -64,14 +64,22 @@ def with_context(
6464
def to_payload(self, value: Any) -> Optional[Payload]:
6565
assert isinstance(value, TraceData)
6666
assert isinstance(self.context, WorkflowSerializationContext)
67-
value.workflow_context.to_payload.append(self.context)
67+
value.items.append(
68+
TraceItem(
69+
context_type="workflow", method="to_payload", context=self.context
70+
)
71+
)
6872
return None
6973

7074
def from_payload(self, payload: Payload, type_hint: Optional[Type] = None) -> Any:
7175
value = JSONPlainPayloadConverter().from_payload(payload, type_hint)
7276
assert isinstance(value, TraceData)
7377
assert isinstance(self.context, WorkflowSerializationContext)
74-
value.workflow_context.from_payload.append(self.context)
78+
value.items.append(
79+
TraceItem(
80+
context_type="workflow", method="from_payload", context=self.context
81+
)
82+
)
7583
return value
7684

7785

@@ -104,11 +112,11 @@ async def test_workflow_payload_conversion_can_be_given_access_to_serialization_
104112
async with Worker(
105113
client,
106114
task_queue=task_queue,
107-
workflows=[SerializationContextTestWorkflow],
115+
workflows=[PassThroughWorkflow],
108116
activities=[],
109117
):
110118
result = await client.execute_workflow(
111-
SerializationContextTestWorkflow.run,
119+
PassThroughWorkflow.run,
112120
TraceData(),
113121
id=workflow_id,
114122
task_queue=task_queue,
@@ -118,5 +126,5 @@ async def test_workflow_payload_conversion_can_be_given_access_to_serialization_
118126
namespace="default",
119127
workflow_id=workflow_id,
120128
)
121-
assert result.workflow_context.to_payload == [workflow_context] * 2
122-
# assert result.workflow_context.from_payload == [workflow_context] * 2
129+
for item in result.items:
130+
print(item)

0 commit comments

Comments
 (0)