Skip to content

Commit 86c9ea5

Browse files
committed
Clean up pydantic test
1 parent 73c7f47 commit 86c9ea5

File tree

1 file changed

+13
-30
lines changed

1 file changed

+13
-30
lines changed

tests/test_serialization_context.py

Lines changed: 13 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1220,7 +1220,7 @@ class PydanticData(BaseModel):
12201220
trace: List[str] = []
12211221

12221222

1223-
class ContextPydanticJSONConverter(
1223+
class PydanticJSONConverterWithContext(
12241224
PydanticJSONPlainPayloadConverter, WithSerializationContext
12251225
):
12261226
def __init__(self):
@@ -1229,8 +1229,8 @@ def __init__(self):
12291229

12301230
def with_context(
12311231
self, context: Optional[SerializationContext]
1232-
) -> "ContextPydanticJSONConverter":
1233-
converter = ContextPydanticJSONConverter()
1232+
) -> "PydanticJSONConverterWithContext":
1233+
converter = PydanticJSONConverterWithContext()
12341234
converter.context = context
12351235
return converter
12361236

@@ -1241,9 +1241,9 @@ def to_payload(self, value: Any) -> Optional[Payload]:
12411241
return super().to_payload(value)
12421242

12431243

1244-
class ContextPydanticConverter(CompositePayloadConverter, WithSerializationContext):
1244+
class PydanticConverterWithContext(CompositePayloadConverter, WithSerializationContext):
12451245
def __init__(self):
1246-
self.json_converter = ContextPydanticJSONConverter()
1246+
self.json_converter = PydanticJSONConverterWithContext()
12471247
super().__init__(
12481248
*(
12491249
c
@@ -1254,53 +1254,36 @@ def __init__(self):
12541254
)
12551255
self.context: Optional[SerializationContext] = None
12561256

1257-
def with_context(
1258-
self, context: Optional[SerializationContext]
1259-
) -> "ContextPydanticConverter":
1260-
converter = ContextPydanticConverter()
1261-
converter.context = context
1262-
# Also set context on all sub-converters
1263-
converters: list[EncodingPayloadConverter] = []
1264-
for c in self.converters.values():
1265-
if isinstance(c, WithSerializationContext):
1266-
converters.append(c.with_context(context))
1267-
else:
1268-
converters.append(c)
1269-
CompositePayloadConverter.__init__(converter, *converters)
1270-
return converter
1271-
12721257

12731258
@workflow.defn
12741259
class PydanticContextWorkflow:
12751260
@workflow.run
12761261
async def run(self, data: PydanticData) -> PydanticData:
1277-
data.value += "_processed"
12781262
return data
12791263

12801264

12811265
async def test_pydantic_converter_with_context(client: Client):
12821266
wf_id = str(uuid.uuid4())
12831267
task_queue = str(uuid.uuid4())
12841268

1285-
test_client = Client(
1286-
client.service_client,
1287-
namespace=client.namespace,
1288-
data_converter=DataConverter(
1289-
payload_converter_class=ContextPydanticConverter,
1290-
),
1269+
client_config = client.config()
1270+
client_config["data_converter"] = dataclasses.replace(
1271+
DataConverter.default,
1272+
payload_converter_class=PydanticConverterWithContext,
12911273
)
1274+
client = Client(**client_config)
1275+
12921276
async with Worker(
1293-
test_client,
1277+
client,
12941278
task_queue=task_queue,
12951279
workflows=[PydanticContextWorkflow],
12961280
):
1297-
result = await test_client.execute_workflow(
1281+
result = await client.execute_workflow(
12981282
PydanticContextWorkflow.run,
12991283
PydanticData(value="test"),
13001284
id=wf_id,
13011285
task_queue=task_queue,
13021286
)
1303-
assert result.value == "test_processed"
13041287
assert f"wf_{wf_id}" in result.trace
13051288

13061289

0 commit comments

Comments
 (0)