Skip to content

Commit e510a29

Browse files
committed
Full pass over client and workflow converter call sites
1 parent d7e13af commit e510a29

File tree

4 files changed

+229
-291
lines changed

4 files changed

+229
-291
lines changed

temporalio/client.py

Lines changed: 102 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -3200,7 +3200,15 @@ async def fetch_next_page(self, *, page_size: Optional[int] = None) -> None:
32003200
timeout=self._input.rpc_timeout,
32013201
)
32023202
self._current_page = [
3203-
WorkflowExecution._from_raw_info(v, self._client.data_converter)
3203+
WorkflowExecution._from_raw_info(
3204+
v,
3205+
self._client.data_converter._with_context(
3206+
WorkflowSerializationContext(
3207+
namespace=self._client.namespace,
3208+
workflow_id=v.execution.workflow_id,
3209+
)
3210+
),
3211+
)
32043212
for v in resp.executions
32053213
]
32063214
self._current_page_index = 0
@@ -4158,37 +4166,47 @@ async def _to_proto(
41584166
priority: Optional[temporalio.api.common.v1.Priority] = None
41594167
if self.priority:
41604168
priority = self.priority._to_proto()
4169+
data_converter = client.data_converter._with_context(
4170+
WorkflowSerializationContext(
4171+
namespace=client.namespace,
4172+
workflow_id=self.id,
4173+
)
4174+
)
41614175
action = temporalio.api.schedule.v1.ScheduleAction(
41624176
start_workflow=temporalio.api.workflow.v1.NewWorkflowExecutionInfo(
41634177
workflow_id=self.id,
41644178
workflow_type=temporalio.api.common.v1.WorkflowType(name=self.workflow),
41654179
task_queue=temporalio.api.taskqueue.v1.TaskQueue(name=self.task_queue),
4166-
input=None
4167-
if not self.args
4168-
else temporalio.api.common.v1.Payloads(
4169-
payloads=[
4170-
a
4171-
if isinstance(a, temporalio.api.common.v1.Payload)
4172-
else (await client.data_converter.encode([a]))[0]
4173-
for a in self.args
4174-
]
4180+
input=(
4181+
temporalio.api.common.v1.Payloads(
4182+
payloads=[
4183+
a
4184+
if isinstance(a, temporalio.api.common.v1.Payload)
4185+
else (await data_converter.encode([a]))[0]
4186+
for a in self.args
4187+
]
4188+
)
4189+
if self.args
4190+
else None
41754191
),
41764192
workflow_execution_timeout=execution_timeout,
41774193
workflow_run_timeout=run_timeout,
41784194
workflow_task_timeout=task_timeout,
41794195
retry_policy=retry_policy,
4180-
memo=None
4181-
if not self.memo
4182-
else temporalio.api.common.v1.Memo(
4183-
fields={
4184-
k: v
4185-
if isinstance(v, temporalio.api.common.v1.Payload)
4186-
else (await client.data_converter.encode([v]))[0]
4187-
for k, v in self.memo.items()
4188-
},
4196+
memo=(
4197+
temporalio.api.common.v1.Memo(
4198+
fields={
4199+
k: v
4200+
if isinstance(v, temporalio.api.common.v1.Payload)
4201+
else (await data_converter.encode([v]))[0]
4202+
for k, v in self.memo.items()
4203+
},
4204+
)
4205+
if self.memo
4206+
else None
41894207
),
41904208
user_metadata=await _encode_user_metadata(
4191-
client.data_converter, self.static_summary, self.static_details
4209+
data_converter, self.static_summary, self.static_details
41924210
),
41934211
priority=priority,
41944212
),
@@ -5914,12 +5932,18 @@ async def _build_signal_with_start_workflow_execution_request(
59145932
self, input: StartWorkflowInput
59155933
) -> temporalio.api.workflowservice.v1.SignalWithStartWorkflowExecutionRequest:
59165934
assert input.start_signal
5935+
data_converter = self._client.data_converter._with_context(
5936+
WorkflowSerializationContext(
5937+
namespace=self._client.namespace,
5938+
workflow_id=input.id,
5939+
)
5940+
)
59175941
req = temporalio.api.workflowservice.v1.SignalWithStartWorkflowExecutionRequest(
59185942
signal_name=input.start_signal
59195943
)
59205944
if input.start_signal_args:
59215945
req.signal_input.payloads.extend(
5922-
await self._client.data_converter.encode(input.start_signal_args)
5946+
await data_converter.encode(input.start_signal_args)
59235947
)
59245948
await self._populate_start_workflow_execution_request(req, input)
59255949
return req
@@ -5939,19 +5963,18 @@ async def _populate_start_workflow_execution_request(
59395963
],
59405964
input: Union[StartWorkflowInput, UpdateWithStartStartWorkflowInput],
59415965
) -> None:
5966+
data_converter = self._client.data_converter._with_context(
5967+
WorkflowSerializationContext(
5968+
namespace=self._client.namespace,
5969+
workflow_id=input.id,
5970+
)
5971+
)
59425972
req.namespace = self._client.namespace
59435973
req.workflow_id = input.id
59445974
req.workflow_type.name = input.workflow
59455975
req.task_queue.name = input.task_queue
59465976
if input.args:
5947-
context = temporalio.converter.WorkflowSerializationContext(
5948-
namespace=self._client.namespace, workflow_id=input.id
5949-
)
5950-
req.input.payloads.extend(
5951-
await self._client.data_converter._with_context(context).encode(
5952-
input.args
5953-
)
5954-
)
5977+
req.input.payloads.extend(await data_converter.encode(input.args))
59555978
if input.execution_timeout is not None:
59565979
req.workflow_execution_timeout.FromTimedelta(input.execution_timeout)
59575980
if input.run_timeout is not None:
@@ -5974,15 +5997,13 @@ async def _populate_start_workflow_execution_request(
59745997
req.cron_schedule = input.cron_schedule
59755998
if input.memo is not None:
59765999
for k, v in input.memo.items():
5977-
req.memo.fields[k].CopyFrom(
5978-
(await self._client.data_converter.encode([v]))[0]
5979-
)
6000+
req.memo.fields[k].CopyFrom((await data_converter.encode([v]))[0])
59806001
if input.search_attributes is not None:
59816002
temporalio.converter.encode_search_attributes(
59826003
input.search_attributes, req.search_attributes
59836004
)
59846005
metadata = await _encode_user_metadata(
5985-
self._client.data_converter, input.static_summary, input.static_details
6006+
data_converter, input.static_summary, input.static_details
59866007
)
59876008
if metadata is not None:
59886009
req.user_metadata.CopyFrom(metadata)
@@ -6028,7 +6049,12 @@ async def describe_workflow(
60286049
metadata=input.rpc_metadata,
60296050
timeout=input.rpc_timeout,
60306051
),
6031-
self._client.data_converter,
6052+
self._client.data_converter._with_context(
6053+
WorkflowSerializationContext(
6054+
namespace=self._client.namespace,
6055+
workflow_id=input.id,
6056+
)
6057+
),
60326058
)
60336059

60346060
def fetch_workflow_history_events(
@@ -6057,6 +6083,12 @@ async def count_workflows(
60576083
)
60586084

60596085
async def query_workflow(self, input: QueryWorkflowInput) -> Any:
6086+
data_converter = self._client.data_converter._with_context(
6087+
WorkflowSerializationContext(
6088+
namespace=self._client.namespace,
6089+
workflow_id=input.id,
6090+
)
6091+
)
60606092
req = temporalio.api.workflowservice.v1.QueryWorkflowRequest(
60616093
namespace=self._client.namespace,
60626094
execution=temporalio.api.common.v1.WorkflowExecution(
@@ -6071,14 +6103,8 @@ async def query_workflow(self, input: QueryWorkflowInput) -> Any:
60716103
)
60726104
req.query.query_type = input.query
60736105
if input.args:
6074-
context = WorkflowSerializationContext(
6075-
namespace=self._client.namespace,
6076-
workflow_id=input.id,
6077-
)
60786106
req.query.query_args.payloads.extend(
6079-
await self._client.data_converter._with_context(context).encode(
6080-
input.args
6081-
)
6107+
await data_converter.encode(input.args)
60826108
)
60836109
if input.headers is not None:
60846110
await self._apply_headers(input.headers, req.query.header.fields)
@@ -6102,20 +6128,20 @@ async def query_workflow(self, input: QueryWorkflowInput) -> Any:
61026128
if not resp.query_result.payloads:
61036129
return None
61046130
type_hints = [input.ret_type] if input.ret_type else None
6105-
context = WorkflowSerializationContext(
6106-
namespace=self._client.namespace,
6107-
workflow_id=input.id,
6108-
)
6109-
results = await self._client.data_converter._with_context(context).decode(
6110-
resp.query_result.payloads, type_hints
6111-
)
6131+
results = await data_converter.decode(resp.query_result.payloads, type_hints)
61126132
if not results:
61136133
return None
61146134
elif len(results) > 1:
61156135
warnings.warn(f"Expected single query result, got {len(results)}")
61166136
return results[0]
61176137

61186138
async def signal_workflow(self, input: SignalWorkflowInput) -> None:
6139+
data_converter = self._client.data_converter._with_context(
6140+
WorkflowSerializationContext(
6141+
namespace=self._client.namespace,
6142+
workflow_id=input.id,
6143+
)
6144+
)
61196145
req = temporalio.api.workflowservice.v1.SignalWorkflowExecutionRequest(
61206146
namespace=self._client.namespace,
61216147
workflow_execution=temporalio.api.common.v1.WorkflowExecution(
@@ -6127,22 +6153,20 @@ async def signal_workflow(self, input: SignalWorkflowInput) -> None:
61276153
request_id=str(uuid.uuid4()),
61286154
)
61296155
if input.args:
6130-
context = temporalio.converter.WorkflowSerializationContext(
6131-
namespace=self._client.namespace,
6132-
workflow_id=input.id,
6133-
)
6134-
req.input.payloads.extend(
6135-
await self._client.data_converter._with_context(context).encode(
6136-
input.args
6137-
)
6138-
)
6156+
req.input.payloads.extend(await data_converter.encode(input.args))
61396157
if input.headers is not None:
61406158
await self._apply_headers(input.headers, req.header.fields)
61416159
await self._client.workflow_service.signal_workflow_execution(
61426160
req, retry=True, metadata=input.rpc_metadata, timeout=input.rpc_timeout
61436161
)
61446162

61456163
async def terminate_workflow(self, input: TerminateWorkflowInput) -> None:
6164+
data_converter = self._client.data_converter._with_context(
6165+
WorkflowSerializationContext(
6166+
namespace=self._client.namespace,
6167+
workflow_id=input.id,
6168+
)
6169+
)
61466170
req = temporalio.api.workflowservice.v1.TerminateWorkflowExecutionRequest(
61476171
namespace=self._client.namespace,
61486172
workflow_execution=temporalio.api.common.v1.WorkflowExecution(
@@ -6154,9 +6178,7 @@ async def terminate_workflow(self, input: TerminateWorkflowInput) -> None:
61546178
first_execution_run_id=input.first_execution_run_id or "",
61556179
)
61566180
if input.args:
6157-
req.details.payloads.extend(
6158-
await self._client.data_converter.encode(input.args)
6159-
)
6181+
req.details.payloads.extend(await data_converter.encode(input.args))
61606182
await self._client.workflow_service.terminate_workflow_execution(
61616183
req, retry=True, metadata=input.rpc_metadata, timeout=input.rpc_timeout
61626184
)
@@ -6213,6 +6235,12 @@ async def _build_update_workflow_execution_request(
62136235
input: Union[StartWorkflowUpdateInput, UpdateWithStartUpdateWorkflowInput],
62146236
workflow_id: str,
62156237
) -> temporalio.api.workflowservice.v1.UpdateWorkflowExecutionRequest:
6238+
data_converter = self._client.data_converter._with_context(
6239+
WorkflowSerializationContext(
6240+
namespace=self._client.namespace,
6241+
workflow_id=workflow_id,
6242+
)
6243+
)
62166244
run_id, first_execution_run_id = (
62176245
(
62186246
input.run_id,
@@ -6244,14 +6272,8 @@ async def _build_update_workflow_execution_request(
62446272
),
62456273
)
62466274
if input.args:
6247-
context = temporalio.converter.WorkflowSerializationContext(
6248-
namespace=self._client.namespace,
6249-
workflow_id=workflow_id,
6250-
)
62516275
req.request.input.args.payloads.extend(
6252-
await self._client.data_converter._with_context(context).encode(
6253-
input.args
6254-
)
6276+
await data_converter.encode(input.args)
62556277
)
62566278
if input.headers is not None:
62576279
await self._apply_headers(input.headers, req.request.input.header.fields)
@@ -6561,19 +6583,20 @@ async def report_cancellation_async_activity(
65616583
def _async_activity_data_converter(
65626584
self, id_or_token: Union[AsyncActivityIDReference, bytes]
65636585
) -> DataConverter:
6564-
context = ActivitySerializationContext(
6565-
namespace=self._client.namespace,
6566-
workflow_id=(
6567-
id_or_token.workflow_id
6568-
if isinstance(id_or_token, AsyncActivityIDReference)
6569-
else ""
6570-
),
6571-
workflow_type="",
6572-
activity_type="",
6573-
activity_task_queue="",
6574-
is_local=False,
6586+
return self._client.data_converter._with_context(
6587+
ActivitySerializationContext(
6588+
namespace=self._client.namespace,
6589+
workflow_id=(
6590+
id_or_token.workflow_id
6591+
if isinstance(id_or_token, AsyncActivityIDReference)
6592+
else ""
6593+
),
6594+
workflow_type="",
6595+
activity_type="",
6596+
activity_task_queue="",
6597+
is_local=False,
6598+
)
65756599
)
6576-
return self._client.data_converter._with_context(context)
65776600

65786601
### Schedule calls
65796602

temporalio/converter.py

Lines changed: 10 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -119,12 +119,6 @@ class ActivitySerializationContext(SerializationContext):
119119
is_local: bool
120120

121121

122-
@dataclass(frozen=True)
123-
class NexusOperationSerializationContext(SerializationContext):
124-
service: str
125-
operation: str
126-
127-
128122
class WithSerializationContext(ABC):
129123
"""Interface for objects that can use serialization context.
130124
@@ -396,6 +390,7 @@ def from_payloads(
396390
return values
397391

398392
def with_context(self, context: Optional[SerializationContext]) -> Self:
393+
"""Return a new instance with the given context."""
399394
instance = type(self).__new__(type(self))
400395
converters = [
401396
c.with_context(context) if isinstance(c, WithSerializationContext) else c
@@ -1302,22 +1297,16 @@ async def decode_failure(
13021297
return self.failure_converter.from_failure(failure, self.payload_converter)
13031298

13041299
def _with_context(self, context: Optional[SerializationContext]) -> Self:
1305-
payload_converter = (
1306-
self.payload_converter.with_context(context)
1307-
if isinstance(self.payload_converter, WithSerializationContext)
1308-
else self.payload_converter
1309-
)
1310-
payload_codec = (
1311-
self.payload_codec.with_context(context)
1312-
if isinstance(self.payload_codec, WithSerializationContext)
1313-
else self.payload_codec
1314-
)
1315-
failure_converter = (
1316-
self.failure_converter.with_context(context)
1317-
if isinstance(self.failure_converter, WithSerializationContext)
1318-
else self.failure_converter
1319-
)
13201300
cloned = dataclasses.replace(self)
1301+
payload_converter = self.payload_converter
1302+
payload_codec = self.payload_codec
1303+
failure_converter = self.failure_converter
1304+
if isinstance(payload_converter, WithSerializationContext):
1305+
payload_converter = payload_converter.with_context(context)
1306+
if isinstance(payload_codec, WithSerializationContext):
1307+
payload_codec = payload_codec.with_context(context)
1308+
if isinstance(failure_converter, WithSerializationContext):
1309+
failure_converter = failure_converter.with_context(context)
13211310
object.__setattr__(cloned, "payload_converter", payload_converter)
13221311
object.__setattr__(cloned, "payload_codec", payload_codec)
13231312
object.__setattr__(cloned, "failure_converter", failure_converter)

0 commit comments

Comments
 (0)