@@ -349,6 +349,9 @@ def __init__(self, *converters: EncodingPayloadConverter) -> None:
349349 Args:
350350 converters: Payload converters to delegate to, in order.
351351 """
352+ self ._set_converters (* converters )
353+
354+ def _set_converters (self , * converters : EncodingPayloadConverter ) -> None :
352355 self .converters = {c .encoding .encode (): c for c in converters }
353356
354357 def to_payloads (
@@ -413,16 +416,26 @@ def from_payloads(
413416 ) from err
414417 return values
415418
416- def with_context (self , context : SerializationContext ) -> CompositePayloadConverter :
417- """Return a new instance with context set on the component converters"""
418- return CompositePayloadConverter (
419- * (
420- c .with_context (context )
421- if isinstance (c , WithSerializationContext )
422- else c
423- for c in self .converters .values ()
424- )
425- )
419+ def with_context (self , context : SerializationContext ) -> Self :
420+ """Return a new instance with context set on the component converters.
421+
422+ If none of the component converters support with_context, return self.
423+ """
424+ converters : list [EncodingPayloadConverter ] = []
425+ any_with_context = False
426+ for c in self .converters .values ():
427+ if isinstance (c , WithSerializationContext ):
428+ converters .append (c .with_context (context ))
429+ any_with_context = True
430+ else :
431+ converters .append (c )
432+
433+ if not any_with_context :
434+ return self
435+
436+ new_instance = type (self )()
437+ new_instance ._set_converters (* converters )
438+ return new_instance
426439
427440
428441class DefaultPayloadConverter (CompositePayloadConverter ):
@@ -1322,7 +1335,6 @@ async def decode_failure(
13221335 return self .failure_converter .from_failure (failure , self .payload_converter )
13231336
13241337 def _with_context (self , context : SerializationContext ) -> Self :
1325- cloned = dataclasses .replace (self )
13261338 payload_converter = self .payload_converter
13271339 payload_codec = self .payload_codec
13281340 failure_converter = self .failure_converter
@@ -1332,6 +1344,16 @@ def _with_context(self, context: SerializationContext) -> Self:
13321344 payload_codec = payload_codec .with_context (context )
13331345 if isinstance (failure_converter , WithSerializationContext ):
13341346 failure_converter = failure_converter .with_context (context )
1347+ if all (
1348+ new == orig
1349+ for new , orig in [
1350+ (payload_converter , self .payload_converter ),
1351+ (payload_codec , self .payload_codec ),
1352+ (failure_converter , self .failure_converter ),
1353+ ]
1354+ ):
1355+ return self
1356+ cloned = dataclasses .replace (self )
13351357 object .__setattr__ (cloned , "payload_converter" , payload_converter )
13361358 object .__setattr__ (cloned , "payload_codec" , payload_codec )
13371359 object .__setattr__ (cloned , "failure_converter" , failure_converter )
0 commit comments