@@ -370,8 +370,9 @@ async def _run_activities(self) -> None:
370370
371371 if task .HasField ("start" ):
372372 # Cancelled event and sync field will be updated inside
373- # _run_activity when the activity function is obtained
374- activity = _RunningActivity ()
373+ # _run_activity when the activity function is obtained. Max
374+ # size of 1000 should be plenty for the heartbeat queue.
375+ activity = _RunningActivity (pending_heartbeats = asyncio .Queue (1000 ))
375376 activity .task = asyncio .create_task (
376377 self ._run_activity (task .task_token , task .start , activity )
377378 )
@@ -409,22 +410,27 @@ def _heartbeat_activity(self, task_token: bytes, *details: Any) -> None:
409410 logger = temporalio .activity .logger
410411 activity = self ._running_activities .get (task_token )
411412 if activity and not activity .done :
412- # Just set as next pending if one is already running
413- coro = self ._heartbeat_activity_async (
414- logger , activity , task_token , * details
413+ # Put on queue and schedule a task. We will let the queue-full error
414+ # be thrown here
415+ activity .pending_heartbeats .put_nowait (details )
416+ activity .last_heartbeat_task = asyncio .create_task (
417+ self ._heartbeat_activity_async (logger , activity , task_token )
415418 )
416- if activity .current_heartbeat_task :
417- activity .pending_heartbeat = coro
418- else :
419- activity .current_heartbeat_task = asyncio .create_task (coro )
420419
421420 async def _heartbeat_activity_async (
422421 self ,
423422 logger : logging .LoggerAdapter ,
424423 activity : _RunningActivity ,
425424 task_token : bytes ,
426- * details : Any ,
427425 ) -> None :
426+ # Drain the queue, only taking the last value to actually heartbeat
427+ details : Optional [Iterable [Any ]] = None
428+ while not activity .pending_heartbeats .empty ():
429+ details = activity .pending_heartbeats .get_nowait ()
430+ if details is None :
431+ return
432+
433+ # Perform the heartbeat
428434 try :
429435 heartbeat = temporalio .bridge .proto .ActivityHeartbeat (task_token = task_token )
430436 if details :
@@ -437,16 +443,7 @@ async def _heartbeat_activity_async(
437443 )
438444 logger .debug ("Recording heartbeat with details %s" , details )
439445 self ._bridge_worker .record_activity_heartbeat (heartbeat )
440- # If there is one pending, schedule it
441- if activity .pending_heartbeat :
442- activity .current_heartbeat_task = asyncio .create_task (
443- activity .pending_heartbeat
444- )
445- activity .pending_heartbeat = None
446- else :
447- activity .current_heartbeat_task = None
448446 except Exception as err :
449- activity .current_heartbeat_task = None
450447 # If the activity is done, nothing we can do but log
451448 if activity .done :
452449 logger .exception (
@@ -696,12 +693,12 @@ async def _run_activity(
696693
697694 # Do final completion
698695 try :
699- # We mark the activity as done and let the currently running (and next
700- # pending) heartbeat task finish
696+ # We mark the activity as done and let the currently running
697+ # heartbeat task finish
701698 running_activity .done = True
702- while running_activity .current_heartbeat_task :
699+ if running_activity .last_heartbeat_task :
703700 try :
704- await running_activity .current_heartbeat_task
701+ await running_activity .last_heartbeat_task
705702 except :
706703 # Should never happen because it's trapped in-task
707704 temporalio .activity .logger .exception (
@@ -749,12 +746,12 @@ class _ActivityDefinition:
749746
750747@dataclass
751748class _RunningActivity :
749+ pending_heartbeats : asyncio .Queue [Iterable [Any ]]
752750 # Most of these optional values are set before use
753751 info : Optional [temporalio .activity .Info ] = None
754752 task : Optional [asyncio .Task ] = None
755753 cancelled_event : Optional [temporalio .activity ._CompositeEvent ] = None
756- pending_heartbeat : Optional [Coroutine ] = None
757- current_heartbeat_task : Optional [asyncio .Task ] = None
754+ last_heartbeat_task : Optional [asyncio .Task ] = None
758755 sync : bool = False
759756 done : bool = False
760757 cancelled_by_request : bool = False
@@ -895,19 +892,16 @@ async def execute_activity(self, input: ExecuteActivityInput) -> Any:
895892 # loop (even though it's sync). So we need a call that puts the
896893 # context back on the activity and calls heartbeat, then another
897894 # call schedules it.
898- def heartbeat_with_context (* details : Any ) -> None :
895+ async def heartbeat_with_context (* details : Any ) -> None :
899896 temporalio .activity ._Context .set (ctx )
900897 assert orig_heartbeat
901898 orig_heartbeat (* details )
902899
903- def thread_safe_heartbeat (* details : Any ) -> None :
904- # TODO(cretz): Final heartbeat can be flaky if we don't wait on
905- # result here, but waiting on result of
906- # asyncio.run_coroutine_threadsafe times out in rare cases.
907- # Need more investigation: https://github.com/temporalio/sdk-python/issues/12
908- loop .call_soon_threadsafe (heartbeat_with_context , * details )
909-
910- ctx .heartbeat = thread_safe_heartbeat
900+ # Invoke the async heartbeat waiting a max of 10 seconds for
901+ # accepting
902+ ctx .heartbeat = lambda * details : asyncio .run_coroutine_threadsafe (
903+ heartbeat_with_context (* details ), loop
904+ ).result (10 )
911905
912906 # For heartbeats, we use the existing heartbeat callable for thread
913907 # pool executors or a multiprocessing queue for others
@@ -917,7 +911,7 @@ def thread_safe_heartbeat(*details: Any) -> None:
917911 # Should always be present in worker, pre-checked on init
918912 shared_manager = input ._worker ._config ["shared_state_manager" ]
919913 assert shared_manager
920- heartbeat = shared_manager .register_heartbeater (
914+ heartbeat = await shared_manager .register_heartbeater (
921915 info .task_token , ctx .heartbeat
922916 )
923917
@@ -935,7 +929,7 @@ def thread_safe_heartbeat(*details: Any) -> None:
935929 )
936930 finally :
937931 if shared_manager :
938- shared_manager .unregister_heartbeater (info .task_token )
932+ await shared_manager .unregister_heartbeater (info .task_token )
939933
940934 # Otherwise for async activity, just run
941935 return await input .fn (* input .args )
@@ -1032,7 +1026,7 @@ def new_event(self) -> threading.Event:
10321026 raise NotImplementedError
10331027
10341028 @abstractmethod
1035- def register_heartbeater (
1029+ async def register_heartbeater (
10361030 self , task_token : bytes , heartbeat : Callable [..., None ]
10371031 ) -> SharedHeartbeatSender :
10381032 """Register a heartbeat function.
@@ -1048,7 +1042,7 @@ def register_heartbeater(
10481042 raise NotImplementedError
10491043
10501044 @abstractmethod
1051- def unregister_heartbeater (self , task_token : bytes ) -> None :
1045+ async def unregister_heartbeater (self , task_token : bytes ) -> None :
10521046 """Unregisters a previously registered heartbeater for the task
10531047 token. This should also flush any pending heartbeats.
10541048 """
@@ -1084,12 +1078,12 @@ def __init__(
10841078 1000
10851079 )
10861080 self ._heartbeats : Dict [bytes , Callable [..., None ]] = {}
1087- self ._heartbeat_completions : Dict [bytes , Callable [[], None ] ] = {}
1081+ self ._heartbeat_completions : Dict [bytes , Callable ] = {}
10881082
10891083 def new_event (self ) -> threading .Event :
10901084 return self ._mgr .Event ()
10911085
1092- def register_heartbeater (
1086+ async def register_heartbeater (
10931087 self , task_token : bytes , heartbeat : Callable [..., None ]
10941088 ) -> SharedHeartbeatSender :
10951089 self ._heartbeats [task_token ] = heartbeat
@@ -1098,17 +1092,19 @@ def register_heartbeater(
10981092 self ._queue_poller_executor .submit (self ._heartbeat_processor )
10991093 return _MultiprocessingSharedHeartbeatSender (self ._heartbeat_queue )
11001094
1101- def unregister_heartbeater (self , task_token : bytes ) -> None :
1102- # Put a completion on the queue and wait for it to happen
1103- flush_complete = threading .Event ()
1104- self ._heartbeat_completions [task_token ] = flush_complete .set
1095+ async def unregister_heartbeater (self , task_token : bytes ) -> None :
1096+ # Put a callback on the queue and wait for it to happen
1097+ loop = asyncio .get_running_loop ()
1098+ finish_event = asyncio .Event ()
1099+ self ._heartbeat_completions [task_token ] = lambda : loop .call_soon_threadsafe (
1100+ finish_event .set
1101+ )
11051102 try :
1106- # 30 seconds to put complete, 30 to get notified should be plenty
1103+ # We only give the queue a few seconds to have enough room
11071104 self ._heartbeat_queue .put (
1108- (task_token , _multiprocess_heartbeat_complete ), True , 30
1105+ (task_token , _multiprocess_heartbeat_complete ), True , 5
11091106 )
1110- if not flush_complete .wait (30 ):
1111- raise RuntimeError ("Timeout waiting for heartbeat flush" )
1107+ await finish_event .wait ()
11121108 finally :
11131109 del self ._heartbeat_completions [task_token ]
11141110
0 commit comments