|
24 | 24 | import temporalio.api.common.v1 |
25 | 25 | import temporalio.api.failure.v1 |
26 | 26 | from temporalio import activity, workflow |
27 | | -from temporalio.client import Client, WorkflowFailureError, WorkflowUpdateFailedError |
| 27 | +from temporalio.client import ( |
| 28 | + AsyncActivityHandle, |
| 29 | + Client, |
| 30 | + WorkflowFailureError, |
| 31 | + WorkflowUpdateFailedError, |
| 32 | +) |
28 | 33 | from temporalio.common import RetryPolicy |
29 | 34 | from temporalio.contrib.pydantic import PydanticJSONPlainPayloadConverter |
30 | 35 | from temporalio.converter import ( |
@@ -610,6 +615,60 @@ async def test_async_activity_completion_payload_conversion( |
610 | 615 | ] |
611 | 616 |
|
612 | 617 |
|
| 618 | +class MyAsyncActivityHandle(AsyncActivityHandle): |
| 619 | + def my_method(self) -> None: |
| 620 | + pass |
| 621 | + |
| 622 | + |
| 623 | +class MyAsyncActivityHandleWithOverriddenConstructor(AsyncActivityHandle): |
| 624 | + def __init__(self, *args: Any, **kwargs: Any) -> None: |
| 625 | + super().__init__(*args, **kwargs) |
| 626 | + |
| 627 | + def my_method(self) -> None: |
| 628 | + pass |
| 629 | + |
| 630 | + |
| 631 | +def test_subclassed_async_activity_handle(client: Client): |
| 632 | + activity_context = ActivitySerializationContext( |
| 633 | + namespace="default", |
| 634 | + workflow_id="workflow-id", |
| 635 | + workflow_type="workflow-type", |
| 636 | + activity_type="activity-type", |
| 637 | + activity_task_queue="activity-task-queue", |
| 638 | + is_local=False, |
| 639 | + ) |
| 640 | + handle = MyAsyncActivityHandle(client=client, id_or_token=b"task-token") |
| 641 | + # This works because the data converter does not use context so AsyncActivityHandle.with_context |
| 642 | + # returns self |
| 643 | + assert isinstance(handle.with_context(activity_context), MyAsyncActivityHandle) |
| 644 | + |
| 645 | + # This time the data converter uses context so AsyncActivityHandle.with_context attempts to |
| 646 | + # return a new instance of the user's subclass. It works, because they have not overridden the |
| 647 | + # constructor. |
| 648 | + client_config = client.config() |
| 649 | + client_config["data_converter"] = dataclasses.replace( |
| 650 | + DataConverter.default, |
| 651 | + payload_converter_class=SerializationContextCompositePayloadConverter, |
| 652 | + ) |
| 653 | + client = Client(**client_config) |
| 654 | + handle = MyAsyncActivityHandle(client=client, id_or_token=b"task-token") |
| 655 | + assert isinstance(handle.with_context(activity_context), MyAsyncActivityHandle) |
| 656 | + |
| 657 | + # Finally, a user attempts the same but having overridden the constructor. This fails: |
| 658 | + # AsyncActivityHandle.with_context refuses to attempt to create an instance of their subclass. |
| 659 | + handle2 = MyAsyncActivityHandleWithOverriddenConstructor( |
| 660 | + client=client, id_or_token=b"task-token" |
| 661 | + ) |
| 662 | + with pytest.raises( |
| 663 | + TypeError, |
| 664 | + match="you must override with_context to return an instance of your class", |
| 665 | + ): |
| 666 | + assert isinstance( |
| 667 | + handle2.with_context(activity_context), |
| 668 | + MyAsyncActivityHandleWithOverriddenConstructor, |
| 669 | + ) |
| 670 | + |
| 671 | + |
613 | 672 | # Signal test |
614 | 673 |
|
615 | 674 |
|
|
0 commit comments