1515import warnings
1616from abc import ABC , abstractmethod
1717from contextlib import contextmanager
18- from dataclasses import dataclass
18+ from dataclasses import dataclass , field
1919from datetime import datetime , timedelta , timezone
2020from typing import (
2121 Any ,
@@ -216,7 +216,13 @@ def _cancel(
216216 warnings .warn (f"Cannot find activity to cancel for token { task_token !r} " )
217217 return
218218 logger .debug ("Cancelling activity %s, reason: %s" , task_token , cancel .reason )
219- activity .cancel (cancelled_by_request = True )
219+ activity .cancellation_details .details = (
220+ temporalio .activity .ActivityCancellationDetails ._from_proto (cancel .details )
221+ )
222+ activity .cancel (
223+ cancelled_by_request = cancel .details .is_cancelled
224+ or cancel .details .is_worker_shutdown
225+ )
220226
221227 def _heartbeat (self , task_token : bytes , * details : Any ) -> None :
222228 # We intentionally make heartbeating non-async, but since the data
@@ -303,6 +309,24 @@ async def _run_activity(
303309 await self ._data_converter .encode_failure (
304310 err , completion .result .failed .failure
305311 )
312+ elif (
313+ isinstance (
314+ err ,
315+ (asyncio .CancelledError , temporalio .exceptions .CancelledError ),
316+ )
317+ and running_activity .cancellation_details .details
318+ and running_activity .cancellation_details .details .paused
319+ ):
320+ temporalio .activity .logger .warning (
321+ f"Completing as failure due to unhandled cancel error produced by activity pause" ,
322+ )
323+ await self ._data_converter .encode_failure (
324+ temporalio .exceptions .ApplicationError (
325+ type = "ActivityPause" ,
326+ message = "Unhandled activity cancel error produced by activity pause" ,
327+ ),
328+ completion .result .failed .failure ,
329+ )
306330 elif (
307331 isinstance (
308332 err ,
@@ -336,7 +360,6 @@ async def _run_activity(
336360 await self ._data_converter .encode_failure (
337361 err , completion .result .failed .failure
338362 )
339-
340363 # For broken executors, we have to fail the entire worker
341364 if isinstance (err , concurrent .futures .BrokenExecutor ):
342365 self ._fail_worker_exception_queue .put_nowait (err )
@@ -524,6 +547,7 @@ async def _execute_activity(
524547 else running_activity .cancel_thread_raiser .shielded ,
525548 payload_converter_class_or_instance = self ._data_converter .payload_converter ,
526549 runtime_metric_meter = None if sync_non_threaded else self ._metric_meter ,
550+ cancellation_details = running_activity .cancellation_details ,
527551 )
528552 )
529553 temporalio .activity .logger .debug ("Starting activity" )
@@ -570,6 +594,9 @@ class _RunningActivity:
570594 done : bool = False
571595 cancelled_by_request : bool = False
572596 cancelled_due_to_heartbeat_error : Optional [Exception ] = None
597+ cancellation_details : temporalio .activity ._ActivityCancellationDetailsHolder = (
598+ field (default_factory = temporalio .activity ._ActivityCancellationDetailsHolder )
599+ )
573600
574601 def cancel (
575602 self ,
@@ -659,6 +686,7 @@ async def execute_activity(self, input: ExecuteActivityInput) -> Any:
659686 # can set the initializer on the executor).
660687 ctx = temporalio .activity ._Context .current ()
661688 info = ctx .info ()
689+ cancellation_details = ctx .cancellation_details
662690
663691 # Heartbeat calls internally use a data converter which is async so
664692 # they need to be called on the event loop
@@ -717,6 +745,7 @@ async def heartbeat_with_context(*details: Any) -> None:
717745 worker_shutdown_event .thread_event ,
718746 payload_converter_class_or_instance ,
719747 ctx .runtime_metric_meter ,
748+ cancellation_details ,
720749 input .fn ,
721750 * input .args ,
722751 ]
@@ -732,7 +761,6 @@ async def heartbeat_with_context(*details: Any) -> None:
732761 finally :
733762 if shared_manager :
734763 await shared_manager .unregister_heartbeater (info .task_token )
735-
736764 # Otherwise for async activity, just run
737765 return await input .fn (* input .args )
738766
@@ -764,6 +792,7 @@ def _execute_sync_activity(
764792 temporalio .converter .PayloadConverter ,
765793 ],
766794 runtime_metric_meter : Optional [temporalio .common .MetricMeter ],
795+ cancellation_details : temporalio .activity ._ActivityCancellationDetailsHolder ,
767796 fn : Callable [..., Any ],
768797 * args : Any ,
769798) -> Any :
@@ -795,6 +824,7 @@ def _execute_sync_activity(
795824 else cancel_thread_raiser .shielded ,
796825 payload_converter_class_or_instance = payload_converter_class_or_instance ,
797826 runtime_metric_meter = runtime_metric_meter ,
827+ cancellation_details = cancellation_details ,
798828 )
799829 )
800830 return fn (* args )
0 commit comments