Skip to content

Commit feb90f3

Browse files
committed
Update test
1 parent 586960e commit feb90f3

File tree

1 file changed

+24
-12
lines changed

1 file changed

+24
-12
lines changed

tests/test_serialization_context.py

Lines changed: 24 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
from dataclasses import dataclass, field
66
from typing import Any, Optional, Type
77

8+
from typing_extensions import Self
9+
810
from temporalio import workflow
911
from temporalio.api.common.v1 import Payload
1012
from temporalio.client import Client
@@ -43,8 +45,8 @@ async def run(self, input: TraceData) -> TraceData:
4345
class SerializationContextTestEncodingPayloadConverter(
4446
EncodingPayloadConverter, WithSerializationContext
4547
):
46-
def __init__(self, context: Optional[SerializationContext]):
47-
self.context = context
48+
def __init__(self):
49+
self.context: Optional[SerializationContext] = None
4850

4951
@property
5052
def encoding(self) -> str:
@@ -56,7 +58,9 @@ def with_context(
5658
print(
5759
f"🌈 SerializationContextTestEncodingPayloadConverter.with_context({context})"
5860
)
59-
return SerializationContextTestEncodingPayloadConverter(context)
61+
converter = SerializationContextTestEncodingPayloadConverter()
62+
converter.context = context
63+
return converter
6064

6165
def to_payload(self, value: Any) -> Optional[Payload]:
6266
value.workflow_context.to_payload = self.context
@@ -67,15 +71,23 @@ def from_payload(self, payload: Payload, type_hint: Optional[Type] = None) -> An
6771
# return payload.data.decode()
6872

6973

70-
class SerializationContextTestPayloadConverter(CompositePayloadConverter):
71-
def __init__(self, *converters):
72-
# TODO: we cannot expect users to do this
73-
if not converters:
74-
converters = (
75-
SerializationContextTestEncodingPayloadConverter(None),
76-
*DefaultPayloadConverter.default_encoding_payload_converters,
77-
)
78-
super().__init__(*converters)
74+
class SerializationContextTestPayloadConverter(
75+
CompositePayloadConverter, WithSerializationContext
76+
):
77+
def __init__(self):
78+
super().__init__(
79+
SerializationContextTestEncodingPayloadConverter(),
80+
*DefaultPayloadConverter.default_encoding_payload_converters,
81+
)
82+
83+
def with_context(self, context: Optional[SerializationContext]) -> Self:
84+
instance = type(self).__new__(type(self))
85+
converters = [
86+
c.with_context(context) if isinstance(c, WithSerializationContext) else c
87+
for c in self.converters.values()
88+
]
89+
CompositePayloadConverter.__init__(instance, *converters)
90+
return instance
7991

8092

8193
data_converter = dataclasses.replace(

0 commit comments

Comments
 (0)