55from dataclasses import dataclass , field
66from typing import Any , Optional , Type
77
8+ from typing_extensions import Self
9+
810from temporalio import workflow
911from temporalio .api .common .v1 import Payload
1012from temporalio .client import Client
@@ -43,8 +45,8 @@ async def run(self, input: TraceData) -> TraceData:
4345class SerializationContextTestEncodingPayloadConverter (
4446 EncodingPayloadConverter , WithSerializationContext
4547):
46- def __init__ (self , context : Optional [ SerializationContext ] ):
47- self .context = context
48+ def __init__ (self ):
49+ self .context : Optional [ SerializationContext ] = None
4850
4951 @property
5052 def encoding (self ) -> str :
@@ -56,7 +58,9 @@ def with_context(
5658 print (
5759 f"🌈 SerializationContextTestEncodingPayloadConverter.with_context({ context } )"
5860 )
59- return SerializationContextTestEncodingPayloadConverter (context )
61+ converter = SerializationContextTestEncodingPayloadConverter ()
62+ converter .context = context
63+ return converter
6064
6165 def to_payload (self , value : Any ) -> Optional [Payload ]:
6266 value .workflow_context .to_payload = self .context
@@ -67,15 +71,23 @@ def from_payload(self, payload: Payload, type_hint: Optional[Type] = None) -> An
6771 # return payload.data.decode()
6872
6973
70- class SerializationContextTestPayloadConverter (CompositePayloadConverter ):
71- def __init__ (self , * converters ):
72- # TODO: we cannot expect users to do this
73- if not converters :
74- converters = (
75- SerializationContextTestEncodingPayloadConverter (None ),
76- * DefaultPayloadConverter .default_encoding_payload_converters ,
77- )
78- super ().__init__ (* converters )
74+ class SerializationContextTestPayloadConverter (
75+ CompositePayloadConverter , WithSerializationContext
76+ ):
77+ def __init__ (self ):
78+ super ().__init__ (
79+ SerializationContextTestEncodingPayloadConverter (),
80+ * DefaultPayloadConverter .default_encoding_payload_converters ,
81+ )
82+
83+ def with_context (self , context : Optional [SerializationContext ]) -> Self :
84+ instance = type (self ).__new__ (type (self ))
85+ converters = [
86+ c .with_context (context ) if isinstance (c , WithSerializationContext ) else c
87+ for c in self .converters .values ()
88+ ]
89+ CompositePayloadConverter .__init__ (instance , * converters )
90+ return instance
7991
8092
8193data_converter = dataclasses .replace (
0 commit comments