Skip to content

Commit 14996ca

Browse files
committed
Implement WithSerializationContext on AsyncActivityHandle
1 parent 9c7052e commit 14996ca

File tree

1 file changed

+34
-25
lines changed

1 file changed

+34
-25
lines changed

temporalio/client.py

Lines changed: 34 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,9 @@
6363
import temporalio.workflow
6464
from temporalio.activity import ActivityCancellationDetails
6565
from temporalio.converter import (
66-
ActivitySerializationContext,
6766
DataConverter,
67+
SerializationContext,
68+
WithSerializationContext,
6869
WorkflowSerializationContext,
6970
)
7071
from temporalio.service import (
@@ -2732,15 +2733,19 @@ class AsyncActivityIDReference:
27322733
activity_id: str
27332734

27342735

2735-
class AsyncActivityHandle:
2736+
class AsyncActivityHandle(WithSerializationContext):
27362737
"""Handle representing an external activity for completion and heartbeat."""
27372738

27382739
def __init__(
2739-
self, client: Client, id_or_token: Union[AsyncActivityIDReference, bytes]
2740+
self,
2741+
client: Client,
2742+
id_or_token: Union[AsyncActivityIDReference, bytes],
2743+
data_converter_override: Optional[DataConverter] = None,
27402744
) -> None:
27412745
"""Create an async activity handle."""
27422746
self._client = client
27432747
self._id_or_token = id_or_token
2748+
self._data_converter_override = data_converter_override
27442749

27452750
async def heartbeat(
27462751
self,
@@ -2762,6 +2767,7 @@ async def heartbeat(
27622767
details=details,
27632768
rpc_metadata=rpc_metadata,
27642769
rpc_timeout=rpc_timeout,
2770+
data_converter_override=self._data_converter_override,
27652771
),
27662772
)
27672773

@@ -2786,6 +2792,7 @@ async def complete(
27862792
result=result,
27872793
rpc_metadata=rpc_metadata,
27882794
rpc_timeout=rpc_timeout,
2795+
data_converter_override=self._data_converter_override,
27892796
),
27902797
)
27912798

@@ -2813,6 +2820,7 @@ async def fail(
28132820
last_heartbeat_details=last_heartbeat_details,
28142821
rpc_metadata=rpc_metadata,
28152822
rpc_timeout=rpc_timeout,
2823+
data_converter_override=self._data_converter_override,
28162824
),
28172825
)
28182826

@@ -2836,9 +2844,24 @@ async def report_cancellation(
28362844
details=details,
28372845
rpc_metadata=rpc_metadata,
28382846
rpc_timeout=rpc_timeout,
2847+
data_converter_override=self._data_converter_override,
28392848
),
28402849
)
28412850

2851+
def with_context(self, context: SerializationContext) -> AsyncActivityHandle:
2852+
"""Create a new AsyncActivityHandle with a different serialization context.
2853+
2854+
Payloads received by the activity will be decoded and deserialized using a data converter
2855+
with :py:class:`ActivitySerializationContext` set as context. If you are using a custom data
2856+
converter that makes use of this context then you can use this method to supply matching
2857+
context data to the data converter used to serialize and encode the outbound payloads.
2858+
"""
2859+
return AsyncActivityHandle(
2860+
self._client,
2861+
self._id_or_token,
2862+
self._client.data_converter._with_context(context),
2863+
)
2864+
28422865

28432866
@dataclass
28442867
class WorkflowExecution:
@@ -5486,6 +5509,7 @@ class HeartbeatAsyncActivityInput:
54865509
details: Sequence[Any]
54875510
rpc_metadata: Mapping[str, Union[str, bytes]]
54885511
rpc_timeout: Optional[timedelta]
5512+
data_converter_override: Optional[DataConverter] = None
54895513

54905514

54915515
@dataclass
@@ -5496,6 +5520,7 @@ class CompleteAsyncActivityInput:
54965520
result: Optional[Any]
54975521
rpc_metadata: Mapping[str, Union[str, bytes]]
54985522
rpc_timeout: Optional[timedelta]
5523+
data_converter_override: Optional[DataConverter] = None
54995524

55005525

55015526
@dataclass
@@ -5507,6 +5532,7 @@ class FailAsyncActivityInput:
55075532
last_heartbeat_details: Sequence[Any]
55085533
rpc_metadata: Mapping[str, Union[str, bytes]]
55095534
rpc_timeout: Optional[timedelta]
5535+
data_converter_override: Optional[DataConverter] = None
55105536

55115537

55125538
@dataclass
@@ -5517,6 +5543,7 @@ class ReportCancellationAsyncActivityInput:
55175543
details: Sequence[Any]
55185544
rpc_metadata: Mapping[str, Union[str, bytes]]
55195545
rpc_timeout: Optional[timedelta]
5546+
data_converter_override: Optional[DataConverter] = None
55205547

55215548

55225549
@dataclass
@@ -6418,7 +6445,7 @@ async def _start_workflow_update_with_start(
64186445
async def heartbeat_async_activity(
64196446
self, input: HeartbeatAsyncActivityInput
64206447
) -> None:
6421-
data_converter = self._async_activity_data_converter(input.id_or_token)
6448+
data_converter = input.data_converter_override or self._client.data_converter
64226449
details = (
64236450
None
64246451
if not input.details
@@ -6473,7 +6500,7 @@ async def heartbeat_async_activity(
64736500
)
64746501

64756502
async def complete_async_activity(self, input: CompleteAsyncActivityInput) -> None:
6476-
data_converter = self._async_activity_data_converter(input.id_or_token)
6503+
data_converter = input.data_converter_override or self._client.data_converter
64776504
result = (
64786505
None
64796506
if input.result is temporalio.common._arg_unset
@@ -6507,7 +6534,7 @@ async def complete_async_activity(self, input: CompleteAsyncActivityInput) -> No
65076534
)
65086535

65096536
async def fail_async_activity(self, input: FailAsyncActivityInput) -> None:
6510-
data_converter = self._async_activity_data_converter(input.id_or_token)
6537+
data_converter = input.data_converter_override or self._client.data_converter
65116538

65126539
failure = temporalio.api.failure.v1.Failure()
65136540
await data_converter.encode_failure(input.error, failure)
@@ -6548,7 +6575,7 @@ async def fail_async_activity(self, input: FailAsyncActivityInput) -> None:
65486575
async def report_cancellation_async_activity(
65496576
self, input: ReportCancellationAsyncActivityInput
65506577
) -> None:
6551-
data_converter = self._async_activity_data_converter(input.id_or_token)
6578+
data_converter = input.data_converter_override or self._client.data_converter
65526579
details = (
65536580
None
65546581
if not input.details
@@ -6581,24 +6608,6 @@ async def report_cancellation_async_activity(
65816608
timeout=input.rpc_timeout,
65826609
)
65836610

6584-
def _async_activity_data_converter(
6585-
self, id_or_token: Union[AsyncActivityIDReference, bytes]
6586-
) -> DataConverter:
6587-
return self._client.data_converter._with_context(
6588-
ActivitySerializationContext(
6589-
namespace=self._client.namespace,
6590-
workflow_id=(
6591-
id_or_token.workflow_id
6592-
if isinstance(id_or_token, AsyncActivityIDReference)
6593-
else ""
6594-
),
6595-
workflow_type="",
6596-
activity_type="",
6597-
activity_task_queue="",
6598-
is_local=False,
6599-
)
6600-
)
6601-
66026611
### Schedule calls
66036612

66046613
async def create_schedule(self, input: CreateScheduleInput) -> ScheduleHandle:

0 commit comments

Comments
 (0)