Skip to content

Commit 2ea91d0

Browse files
committed
Fix test: return Self
1 parent 3a23ea4 commit 2ea91d0

File tree

1 file changed

+33
-11
lines changed

1 file changed

+33
-11
lines changed

temporalio/converter.py

Lines changed: 33 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -349,6 +349,9 @@ def __init__(self, *converters: EncodingPayloadConverter) -> None:
349349
Args:
350350
converters: Payload converters to delegate to, in order.
351351
"""
352+
self._set_converters(*converters)
353+
354+
def _set_converters(self, *converters: EncodingPayloadConverter) -> None:
352355
self.converters = {c.encoding.encode(): c for c in converters}
353356

354357
def to_payloads(
@@ -413,16 +416,26 @@ def from_payloads(
413416
) from err
414417
return values
415418

416-
def with_context(self, context: SerializationContext) -> CompositePayloadConverter:
417-
"""Return a new instance with context set on the component converters"""
418-
return CompositePayloadConverter(
419-
*(
420-
c.with_context(context)
421-
if isinstance(c, WithSerializationContext)
422-
else c
423-
for c in self.converters.values()
424-
)
425-
)
419+
def with_context(self, context: SerializationContext) -> Self:
420+
"""Return a new instance with context set on the component converters.
421+
422+
If none of the component converters support with_context, return self.
423+
"""
424+
converters: list[EncodingPayloadConverter] = []
425+
any_with_context = False
426+
for c in self.converters.values():
427+
if isinstance(c, WithSerializationContext):
428+
converters.append(c.with_context(context))
429+
any_with_context = True
430+
else:
431+
converters.append(c)
432+
433+
if not any_with_context:
434+
return self
435+
436+
new_instance = type(self)()
437+
new_instance._set_converters(*converters)
438+
return new_instance
426439

427440

428441
class DefaultPayloadConverter(CompositePayloadConverter):
@@ -1322,7 +1335,6 @@ async def decode_failure(
13221335
return self.failure_converter.from_failure(failure, self.payload_converter)
13231336

13241337
def _with_context(self, context: SerializationContext) -> Self:
1325-
cloned = dataclasses.replace(self)
13261338
payload_converter = self.payload_converter
13271339
payload_codec = self.payload_codec
13281340
failure_converter = self.failure_converter
@@ -1332,6 +1344,16 @@ def _with_context(self, context: SerializationContext) -> Self:
13321344
payload_codec = payload_codec.with_context(context)
13331345
if isinstance(failure_converter, WithSerializationContext):
13341346
failure_converter = failure_converter.with_context(context)
1347+
if all(
1348+
new == orig
1349+
for new, orig in [
1350+
(payload_converter, self.payload_converter),
1351+
(payload_codec, self.payload_codec),
1352+
(failure_converter, self.failure_converter),
1353+
]
1354+
):
1355+
return self
1356+
cloned = dataclasses.replace(self)
13351357
object.__setattr__(cloned, "payload_converter", payload_converter)
13361358
object.__setattr__(cloned, "payload_codec", payload_codec)
13371359
object.__setattr__(cloned, "failure_converter", failure_converter)

0 commit comments

Comments
 (0)