Skip to content

Commit 135fecb

Browse files
committed
wire async activity completion context
1 parent a204dae commit 135fecb

File tree

3 files changed

+34
-10
lines changed

3 files changed

+34
-10
lines changed

temporalio/client.py

Lines changed: 32 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,11 @@
6262
import temporalio.service
6363
import temporalio.workflow
6464
from temporalio.activity import ActivityCancellationDetails
65-
from temporalio.converter import WorkflowSerializationContext
65+
from temporalio.converter import (
66+
ActivitySerializationContext,
67+
DataConverter,
68+
WorkflowSerializationContext,
69+
)
6670
from temporalio.service import (
6771
HttpConnectProxyConfig,
6872
KeepAliveConfig,
@@ -6391,10 +6395,11 @@ async def _start_workflow_update_with_start(
63916395
async def heartbeat_async_activity(
63926396
self, input: HeartbeatAsyncActivityInput
63936397
) -> None:
6398+
data_converter = self._async_activity_data_converter(input.id_or_token)
63946399
details = (
63956400
None
63966401
if not input.details
6397-
else await self._client.data_converter.encode_wrapper(input.details)
6402+
else await data_converter.encode_wrapper(input.details)
63986403
)
63996404
if isinstance(input.id_or_token, AsyncActivityIDReference):
64006405
resp_by_id = await self._client.workflow_service.record_activity_task_heartbeat_by_id(
@@ -6445,10 +6450,11 @@ async def heartbeat_async_activity(
64456450
)
64466451

64476452
async def complete_async_activity(self, input: CompleteAsyncActivityInput) -> None:
6453+
data_converter = self._async_activity_data_converter(input.id_or_token)
64486454
result = (
64496455
None
64506456
if input.result is temporalio.common._arg_unset
6451-
else await self._client.data_converter.encode_wrapper([input.result])
6457+
else await data_converter.encode_wrapper([input.result])
64526458
)
64536459
if isinstance(input.id_or_token, AsyncActivityIDReference):
64546460
await self._client.workflow_service.respond_activity_task_completed_by_id(
@@ -6478,14 +6484,14 @@ async def complete_async_activity(self, input: CompleteAsyncActivityInput) -> No
64786484
)
64796485

64806486
async def fail_async_activity(self, input: FailAsyncActivityInput) -> None:
6487+
data_converter = self._async_activity_data_converter(input.id_or_token)
6488+
64816489
failure = temporalio.api.failure.v1.Failure()
6482-
await self._client.data_converter.encode_failure(input.error, failure)
6490+
await data_converter.encode_failure(input.error, failure)
64836491
last_heartbeat_details = (
64846492
None
64856493
if not input.last_heartbeat_details
6486-
else await self._client.data_converter.encode_wrapper(
6487-
input.last_heartbeat_details
6488-
)
6494+
else await data_converter.encode_wrapper(input.last_heartbeat_details)
64896495
)
64906496
if isinstance(input.id_or_token, AsyncActivityIDReference):
64916497
await self._client.workflow_service.respond_activity_task_failed_by_id(
@@ -6519,10 +6525,11 @@ async def fail_async_activity(self, input: FailAsyncActivityInput) -> None:
65196525
async def report_cancellation_async_activity(
65206526
self, input: ReportCancellationAsyncActivityInput
65216527
) -> None:
6528+
data_converter = self._async_activity_data_converter(input.id_or_token)
65226529
details = (
65236530
None
65246531
if not input.details
6525-
else await self._client.data_converter.encode_wrapper(input.details)
6532+
else await data_converter.encode_wrapper(input.details)
65266533
)
65276534
if isinstance(input.id_or_token, AsyncActivityIDReference):
65286535
await self._client.workflow_service.respond_activity_task_canceled_by_id(
@@ -6551,6 +6558,23 @@ async def report_cancellation_async_activity(
65516558
timeout=input.rpc_timeout,
65526559
)
65536560

6561+
def _async_activity_data_converter(
6562+
self, id_or_token: Union[AsyncActivityIDReference, bytes]
6563+
) -> DataConverter:
6564+
context = ActivitySerializationContext(
6565+
namespace=self._client.namespace,
6566+
workflow_id=(
6567+
id_or_token.workflow_id
6568+
if isinstance(id_or_token, AsyncActivityIDReference)
6569+
else ""
6570+
),
6571+
workflow_type="",
6572+
activity_type="",
6573+
activity_task_queue="",
6574+
is_local=False,
6575+
)
6576+
return self._client.data_converter._with_context(context)
6577+
65546578
### Schedule calls
65556579

65566580
async def create_schedule(self, input: CreateScheduleInput) -> ScheduleHandle:

temporalio/converter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828
Mapping,
2929
NewType,
3030
Optional,
31-
Self,
3231
Sequence,
3332
Tuple,
3433
Type,
@@ -44,6 +43,7 @@
4443
import google.protobuf.symbol_database
4544
import nexusrpc
4645
import typing_extensions
46+
from typing_extensions import Self
4747

4848
import temporalio.api.common.v1
4949
import temporalio.api.enums.v1

temporalio/worker/_workflow_instance.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3187,7 +3187,7 @@ def _resolve_failure(self, err: BaseException) -> None:
31873187
self._result_fut.set_result(None)
31883188

31893189
def _apply_schedule_command(self) -> None:
3190-
payload = self._payload_converter.to_payload(self._input.input)
3190+
payload = self._payload_converter.to_payload(self._input.input) # type: ignore TODO
31913191
command = self._instance._add_command()
31923192
v = command.schedule_nexus_operation
31933193
v.seq = self._seq

0 commit comments

Comments
 (0)