@@ -1743,3 +1743,84 @@ async def test_pydantic_converter_with_context(client: Client):
17431743 task_queue = task_queue ,
17441744 )
17451745 assert f"wf_{ wf_id } " in result .trace
1746+
1747+
1748+ # Test customized DefaultPayloadConverter
1749+
1750+ # The SDK's CompositePayloadConverter comes with a with_context implementation that ensures that its
1751+ # component EncodingPayloadConverters will be replaced with the results of calling with_context() on
1752+ # them, if they support with_context (this happens when we call data_converter._with_context). In
1753+ # this test, the user has subclassed CompositePayloadConverter. The test confirms that the
1754+ # CompositePayloadConverter's with_context yields an instance of the user's subclass.
1755+
1756+
1757+ class UserMethodCalledError (Exception ):
1758+ pass
1759+
1760+
1761+ class CustomEncodingPayloadConverter (
1762+ JSONPlainPayloadConverter , WithSerializationContext
1763+ ):
1764+ @property
1765+ def encoding (self ) -> str :
1766+ return "custom-encoding-that-does-not-clash-with-default-converters"
1767+
1768+ def __init__ (self ):
1769+ super ().__init__ ()
1770+ self .context : Optional [SerializationContext ] = None
1771+
1772+ def with_context (
1773+ self , context : Optional [SerializationContext ]
1774+ ) -> CustomEncodingPayloadConverter :
1775+ converter = CustomEncodingPayloadConverter ()
1776+ converter .context = context
1777+ return converter
1778+
1779+
1780+ class CustomPayloadConverter (CompositePayloadConverter ):
1781+ def __init__ (self ):
1782+ # Add a context-aware EncodingPayloadConverter so that
1783+ # CompositePayloadConverter.with_context is forced to construct and return a new instance.
1784+ super ().__init__ (
1785+ CustomEncodingPayloadConverter (),
1786+ * DefaultPayloadConverter .default_encoding_payload_converters ,
1787+ )
1788+
1789+ def to_payloads (
1790+ self , values : Sequence [Any ]
1791+ ) -> List [temporalio .api .common .v1 .Payload ]:
1792+ raise UserMethodCalledError
1793+
1794+ def from_payloads (
1795+ self ,
1796+ payloads : Sequence [temporalio .api .common .v1 .Payload ],
1797+ type_hints : Optional [List [Type ]] = None ,
1798+ ) -> List [Any ]:
1799+ raise NotImplementedError
1800+
1801+
1802+ async def test_user_customization_of_default_payload_converter (
1803+ client : Client ,
1804+ ):
1805+ wf_id = str (uuid .uuid4 ())
1806+ task_queue = str (uuid .uuid4 ())
1807+
1808+ client_config = client .config ()
1809+ client_config ["data_converter" ] = dataclasses .replace (
1810+ DataConverter .default ,
1811+ payload_converter_class = CustomPayloadConverter ,
1812+ )
1813+ client = Client (** client_config )
1814+
1815+ async with Worker (
1816+ client ,
1817+ task_queue = task_queue ,
1818+ workflows = [EchoWorkflow ],
1819+ ):
1820+ with pytest .raises (UserMethodCalledError ):
1821+ await client .execute_workflow (
1822+ EchoWorkflow .run ,
1823+ TraceData (),
1824+ id = wf_id ,
1825+ task_queue = task_queue ,
1826+ )
0 commit comments