Skip to content

Commit 0201fba

Browse files
committed
Return self in AsyncActivityHandle.with_handle
1 parent b72f2b7 commit 0201fba

File tree

2 files changed

+72
-6
lines changed

2 files changed

+72
-6
lines changed

temporalio/client.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141
import google.protobuf.json_format
4242
import google.protobuf.timestamp_pb2
4343
from google.protobuf.internal.containers import MessageMap
44-
from typing_extensions import Concatenate, Required, TypedDict
44+
from typing_extensions import Concatenate, Required, Self, TypedDict
4545

4646
import temporalio.api.common.v1
4747
import temporalio.api.enums.v1
@@ -2852,17 +2852,24 @@ async def report_cancellation(
28522852
),
28532853
)
28542854

2855-
# TODO(dan): should this return Self (requiring that the user's subclass has the same
2856-
# constructor signature)? CompositePayloadConverter.with_context does.
2857-
def with_context(self, context: SerializationContext) -> AsyncActivityHandle:
2855+
def with_context(self, context: SerializationContext) -> Self:
28582856
"""Create a new AsyncActivityHandle with a different serialization context.
28592857
28602858
Payloads received by the activity will be decoded and deserialized using a data converter
28612859
with :py:class:`ActivitySerializationContext` set as context. If you are using a custom data
28622860
converter that makes use of this context then you can use this method to supply matching
28632861
context data to the data converter used to serialize and encode the outbound payloads.
28642862
"""
2865-
return AsyncActivityHandle(
2863+
data_converter = self._client.data_converter.with_context(context)
2864+
if data_converter == self._client.data_converter:
2865+
return self
2866+
cls = type(self)
2867+
if cls.__init__ is not AsyncActivityHandle.__init__:
2868+
raise TypeError(
2869+
"If you have subclassed AsyncActivityHandle and overridden the __init__ method "
2870+
"then you must override with_context to return an instance of your class."
2871+
)
2872+
return cls(
28662873
self._client,
28672874
self._id_or_token,
28682875
self._client.data_converter.with_context(context),

tests/test_serialization_context.py

Lines changed: 60 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,12 @@
2424
import temporalio.api.common.v1
2525
import temporalio.api.failure.v1
2626
from temporalio import activity, workflow
27-
from temporalio.client import Client, WorkflowFailureError, WorkflowUpdateFailedError
27+
from temporalio.client import (
28+
AsyncActivityHandle,
29+
Client,
30+
WorkflowFailureError,
31+
WorkflowUpdateFailedError,
32+
)
2833
from temporalio.common import RetryPolicy
2934
from temporalio.contrib.pydantic import PydanticJSONPlainPayloadConverter
3035
from temporalio.converter import (
@@ -610,6 +615,60 @@ async def test_async_activity_completion_payload_conversion(
610615
]
611616

612617

618+
class MyAsyncActivityHandle(AsyncActivityHandle):
619+
def my_method(self) -> None:
620+
pass
621+
622+
623+
class MyAsyncActivityHandleWithOverriddenConstructor(AsyncActivityHandle):
624+
def __init__(self, *args: Any, **kwargs: Any) -> None:
625+
super().__init__(*args, **kwargs)
626+
627+
def my_method(self) -> None:
628+
pass
629+
630+
631+
def test_subclassed_async_activity_handle(client: Client):
632+
activity_context = ActivitySerializationContext(
633+
namespace="default",
634+
workflow_id="workflow-id",
635+
workflow_type="workflow-type",
636+
activity_type="activity-type",
637+
activity_task_queue="activity-task-queue",
638+
is_local=False,
639+
)
640+
handle = MyAsyncActivityHandle(client=client, id_or_token=b"task-token")
641+
# This works because the data converter does not use context so AsyncActivityHandle.with_context
642+
# returns self
643+
assert isinstance(handle.with_context(activity_context), MyAsyncActivityHandle)
644+
645+
# This time the data converter uses context so AsyncActivityHandle.with_context attempts to
646+
# return a new instance of the user's subclass. It works, because they have not overridden the
647+
# constructor.
648+
client_config = client.config()
649+
client_config["data_converter"] = dataclasses.replace(
650+
DataConverter.default,
651+
payload_converter_class=SerializationContextCompositePayloadConverter,
652+
)
653+
client = Client(**client_config)
654+
handle = MyAsyncActivityHandle(client=client, id_or_token=b"task-token")
655+
assert isinstance(handle.with_context(activity_context), MyAsyncActivityHandle)
656+
657+
# Finally, a user attempts the same but having overridden the constructor. This fails:
658+
# AsyncActivityHandle.with_context refuses to attempt to create an instance of their subclass.
659+
handle2 = MyAsyncActivityHandleWithOverriddenConstructor(
660+
client=client, id_or_token=b"task-token"
661+
)
662+
with pytest.raises(
663+
TypeError,
664+
match="you must override with_context to return an instance of your class",
665+
):
666+
assert isinstance(
667+
handle2.with_context(activity_context),
668+
MyAsyncActivityHandleWithOverriddenConstructor,
669+
)
670+
671+
613672
# Signal test
614673

615674

0 commit comments

Comments
 (0)