@@ -1220,7 +1220,7 @@ class PydanticData(BaseModel):
12201220 trace : List [str ] = []
12211221
12221222
1223- class ContextPydanticJSONConverter (
1223+ class PydanticJSONConverterWithContext (
12241224 PydanticJSONPlainPayloadConverter , WithSerializationContext
12251225):
12261226 def __init__ (self ):
@@ -1229,8 +1229,8 @@ def __init__(self):
12291229
12301230 def with_context (
12311231 self , context : Optional [SerializationContext ]
1232- ) -> "ContextPydanticJSONConverter " :
1233- converter = ContextPydanticJSONConverter ()
1232+ ) -> "PydanticJSONConverterWithContext " :
1233+ converter = PydanticJSONConverterWithContext ()
12341234 converter .context = context
12351235 return converter
12361236
@@ -1241,9 +1241,9 @@ def to_payload(self, value: Any) -> Optional[Payload]:
12411241 return super ().to_payload (value )
12421242
12431243
1244- class ContextPydanticConverter (CompositePayloadConverter , WithSerializationContext ):
1244+ class PydanticConverterWithContext (CompositePayloadConverter , WithSerializationContext ):
12451245 def __init__ (self ):
1246- self .json_converter = ContextPydanticJSONConverter ()
1246+ self .json_converter = PydanticJSONConverterWithContext ()
12471247 super ().__init__ (
12481248 * (
12491249 c
@@ -1254,53 +1254,36 @@ def __init__(self):
12541254 )
12551255 self .context : Optional [SerializationContext ] = None
12561256
1257- def with_context (
1258- self , context : Optional [SerializationContext ]
1259- ) -> "ContextPydanticConverter" :
1260- converter = ContextPydanticConverter ()
1261- converter .context = context
1262- # Also set context on all sub-converters
1263- converters : list [EncodingPayloadConverter ] = []
1264- for c in self .converters .values ():
1265- if isinstance (c , WithSerializationContext ):
1266- converters .append (c .with_context (context ))
1267- else :
1268- converters .append (c )
1269- CompositePayloadConverter .__init__ (converter , * converters )
1270- return converter
1271-
12721257
12731258@workflow .defn
12741259class PydanticContextWorkflow :
12751260 @workflow .run
12761261 async def run (self , data : PydanticData ) -> PydanticData :
1277- data .value += "_processed"
12781262 return data
12791263
12801264
12811265async def test_pydantic_converter_with_context (client : Client ):
12821266 wf_id = str (uuid .uuid4 ())
12831267 task_queue = str (uuid .uuid4 ())
12841268
1285- test_client = Client (
1286- client .service_client ,
1287- namespace = client .namespace ,
1288- data_converter = DataConverter (
1289- payload_converter_class = ContextPydanticConverter ,
1290- ),
1269+ client_config = client .config ()
1270+ client_config ["data_converter" ] = dataclasses .replace (
1271+ DataConverter .default ,
1272+ payload_converter_class = PydanticConverterWithContext ,
12911273 )
1274+ client = Client (** client_config )
1275+
12921276 async with Worker (
1293- test_client ,
1277+ client ,
12941278 task_queue = task_queue ,
12951279 workflows = [PydanticContextWorkflow ],
12961280 ):
1297- result = await test_client .execute_workflow (
1281+ result = await client .execute_workflow (
12981282 PydanticContextWorkflow .run ,
12991283 PydanticData (value = "test" ),
13001284 id = wf_id ,
13011285 task_queue = task_queue ,
13021286 )
1303- assert result .value == "test_processed"
13041287 assert f"wf_{ wf_id } " in result .trace
13051288
13061289
0 commit comments