@@ -208,6 +208,7 @@ def __init__(self, det: WorkflowInstanceDetails) -> None:
208208 self ._worker_level_failure_exception_types = (
209209 det .worker_level_failure_exception_types
210210 )
211+ self ._primary_task_initter : Optional [Callable [[], None ]] = None
211212 self ._primary_task : Optional [asyncio .Task [None ]] = None
212213 self ._time_ns = 0
213214 self ._cancel_requested = False
@@ -356,39 +357,24 @@ def activate(
356357 self ._current_thread_id = threading .get_ident ()
357358 activation_err : Optional [Exception ] = None
358359 try :
359- # Split into job sets with patches, then signals + updates, then
360- # non-queries, then queries
361- start_job = None
362- job_sets : List [
363- List [temporalio .bridge .proto .workflow_activation .WorkflowActivationJob ]
364- ] = [[], [], [], []]
360+ # Apply every job, running the loop afterward
361+ is_query = False
365362 for job in act .jobs :
366- if job .HasField ("notify_has_patch" ):
367- job_sets [0 ].append (job )
368- elif job .HasField ("signal_workflow" ) or job .HasField ("do_update" ):
369- job_sets [1 ].append (job )
370- elif not job .HasField ("query_workflow" ):
371- if job .HasField ("initialize_workflow" ):
372- start_job = job .initialize_workflow
373- job_sets [2 ].append (job )
374- else :
375- job_sets [3 ].append (job )
376-
377- if start_job :
378- self ._workflow_input = self ._make_workflow_input (start_job )
379-
380- # Apply every job set, running after each set
381- for index , job_set in enumerate (job_sets ):
382- if not job_set :
383- continue
384- for job in job_set :
385- # Let errors bubble out of these to the caller to fail the task
386- self ._apply (job )
387-
388- # Run one iteration of the loop. We do not allow conditions to
389- # be checked in patch jobs (first index) or query jobs (last
390- # index).
391- self ._run_once (check_conditions = index == 1 or index == 2 )
363+ if job .HasField ("initialize_workflow" ):
364+ self ._workflow_input = self ._make_workflow_input (
365+ job .initialize_workflow
366+ )
367+ # Let errors bubble out of these to the caller to fail the task
368+ self ._apply (job )
369+ if job .HasField ("query_workflow" ):
370+ is_query = True
371+
372+ # Ensure the main loop is called, and called last, if needed
373+ if self ._primary_task_initter is not None and self ._primary_task is None :
374+ self ._primary_task_initter ()
375+ # Conditions are not checked on query activations. Query activations always come without
376+ # any other jobs.
377+ self ._run_once (check_conditions = not is_query )
392378 except Exception as err :
393379 # We want some errors during activation, like those that can happen
394380 # during payload conversion, to be able to fail the workflow not the
@@ -508,6 +494,17 @@ def _apply_cancel_workflow(
508494 # workflow the ability to receive the cancellation, so we must defer
509495 # this cancellation to the next iteration of the event loop.
510496 self .call_soon (self ._primary_task .cancel )
497+ elif self ._primary_task_initter :
498+ # If we're being cancelled before ever being started, we need to run the cancel
499+ # after initialization
500+ old_initter = self ._primary_task_initter
501+
502+ def init_then_cancel ():
503+ old_initter ()
504+ if self ._primary_task :
505+ self .call_soon (self ._primary_task .cancel )
506+
507+ self ._primary_task_initter = init_then_cancel
511508
512509 def _apply_do_update (
513510 self , job : temporalio .bridge .proto .workflow_activation .DoUpdate
@@ -885,14 +882,19 @@ async def run_workflow(input: ExecuteWorkflowInput) -> None:
885882 return
886883 raise
887884
888- if not self ._workflow_input :
889- raise RuntimeError (
890- "Expected workflow input to be set. This is an SDK Python bug."
885+ def primary_initter ():
886+ if not self ._workflow_input :
887+ raise RuntimeError (
888+ "Expected workflow input to be set. This is an SDK Python bug."
889+ )
890+ self ._primary_task = self .create_task (
891+ self ._run_top_level_workflow_function (
892+ run_workflow (self ._workflow_input )
893+ ),
894+ name = "run" ,
891895 )
892- self ._primary_task = self .create_task (
893- self ._run_top_level_workflow_function (run_workflow (self ._workflow_input )),
894- name = "run" ,
895- )
896+
897+ self ._primary_task_initter = primary_initter
896898
897899 def _apply_update_random_seed (
898900 self , job : temporalio .bridge .proto .workflow_activation .UpdateRandomSeed
0 commit comments