@@ -148,6 +148,12 @@ class Graph(Generic[StateT, DepsT, InputT, OutputT]):
     parent_forks: dict[JoinID, ParentFork[NodeID]]
     """Parent fork information for each join node."""
 
+    intermediate_join_nodes: dict[JoinID, set[JoinID]]
+    """For each join, the set of other joins that appear between it and its parent fork.
+
+    Used to determine which joins are "final" (do not appear as an intermediate of any other join)
+    and which joins should preserve their fork stacks when proceeding downstream."""
+
     def get_parent_fork(self, join_id: JoinID) -> ParentFork[NodeID]:
         """Get the parent fork information for a join node.
 
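To make the new field concrete: for a hypothetical topology where a fork `F` dominates two joins and the branches pass through `J1` before reaching `J2`, the mapping would look like this (a sketch with made-up IDs, not taken from the patch):

```python
# Assumed topology: fork F -> ... -> J1 -> ... -> J2, with F the parent fork of both.
# JoinID is treated as a plain string here for illustration.
intermediate_join_nodes: dict[str, set[str]] = {
    'J1': set(),   # no other join lies between F and J1
    'J2': {'J1'},  # J1 lies between F and J2, so it is an intermediate of J2
}
```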
@@ -165,6 +171,24 @@ def get_parent_fork(self, join_id: JoinID) -> ParentFork[NodeID]:
             raise RuntimeError(f'Node {join_id} is not a join node or did not have a dominating fork (this is a bug)')
         return result
 
+    def is_final_join(self, join_id: JoinID) -> bool:
+        """Check if a join is 'final' (has no downstream joins with the same parent fork).
+
+        A join is non-final if it appears as an intermediate node for another join
+        with the same parent fork.
+
+        Args:
+            join_id: The ID of the join node
+
+        Returns:
+            True if the join is final, False if it's non-final
+        """
+        # Check if this join appears in any other join's intermediate_join_nodes
+        for intermediate_joins in self.intermediate_join_nodes.values():
+            if join_id in intermediate_joins:
+                return False
+        return True
+
     async def run(
         self,
         *,
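Given that mapping, `is_final_join` is a membership scan over all recorded intermediates. The same logic as a free function, for illustration only (assuming string join IDs):

```python
def is_final_join(intermediate_join_nodes: dict[str, set[str]], join_id: str) -> bool:
    # A join is non-final exactly when some other join lists it as an intermediate.
    return not any(join_id in joins for joins in intermediate_join_nodes.values())

mapping = {'J1': set(), 'J2': {'J1'}}
assert is_final_join(mapping, 'J2')      # no join downstream of J2 -> final
assert not is_final_join(mapping, 'J1')  # J1 feeds into J2 -> non-final
```

Note the scan touches every entry on each call; if that ever shows up in profiles, the inverse relation ("is an intermediate of") could be precomputed alongside `intermediate_join_nodes`.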
@@ -517,7 +541,14 @@ async def iter_graph(  # noqa C901
                    parent_fork_id = self.graph.get_parent_fork(result.join_id).fork_id
                    for i, x in enumerate(result.fork_stack[::-1]):
                        if x.fork_id == parent_fork_id:
-                            downstream_fork_stack = result.fork_stack[: len(result.fork_stack) - i]
+                            # For non-final joins (those that are intermediate nodes of other joins),
+                            # preserve the fork stack so downstream joins can still associate with the same fork run
+                            if self.graph.is_final_join(result.join_id):
+                                # Final join: remove the parent fork from the stack
+                                downstream_fork_stack = result.fork_stack[: len(result.fork_stack) - i]
+                            else:
+                                # Non-final join: preserve the fork stack
+                                downstream_fork_stack = result.fork_stack
                            fork_run_id = x.node_run_id
                            break
                    else:  # pragma: no cover
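The branch's effect on the stack, sketched with bare tuples (hypothetical entries; the real items carry `fork_id` and `node_run_id` attributes):

```python
# Innermost fork last. Suppose a result reaches the join carrying this stack:
fork_stack = (('F_outer', 1), ('F', 7), ('F_inner', 9))
parent_fork_id = 'F'

for i, (fork_id, node_run_id) in enumerate(fork_stack[::-1]):
    if fork_id == parent_fork_id:
        # Final join: the slice keeps entries up to and including the parent
        # fork, dropping anything nested inside that fork run.
        trimmed = fork_stack[: len(fork_stack) - i]
        break

assert trimmed == (('F_outer', 1), ('F', 7))
# A non-final join would keep fork_stack intact, so a downstream join with the
# same parent fork can still find ('F', 7) on its stack.
```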
@@ -535,13 +566,9 @@ async def iter_graph(  # noqa C901
                    join_state.current = join_node.reduce(context, join_state.current, result.inputs)
                    if join_state.cancelled_sibling_tasks:
                        await self._cancel_sibling_tasks(parent_fork_id, fork_run_id)
-                    if task_result.source_is_finished:  # pragma: no branch
-                        await self._finish_task(task_result.source.task_id)
                else:
                    for new_task in maybe_overridden_result:
                        self.active_tasks[new_task.task_id] = new_task
-                    if task_result.source_is_finished:
-                        await self._finish_task(task_result.source.task_id)
 
                tasks_by_id_values = list(self.active_tasks.values())
                join_tasks: list[GraphTask] = []
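The two deleted `source_is_finished` blocks are not dropped: the next hunk hoists them to a single call site after the if/else. A minimal sketch of the shape of that refactor, with placeholder branch bodies:

```python
def handle_result(is_join: bool, source_is_finished: bool) -> list[str]:
    events: list[str] = []
    if is_join:
        events.append('reduced inputs into join state')
    else:
        events.append('scheduled downstream tasks')
    # One hoisted check replaces the copy that used to live in each branch.
    if source_is_finished:
        events.append('finished source task')
    return events

assert handle_result(True, True)[-1] == 'finished source task'
assert handle_result(False, False) == ['scheduled downstream tasks']
```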
@@ -566,28 +593,61 @@ async def iter_graph(  # noqa C901
                    new_task_ids = {t.task_id for t in maybe_overridden_result}
                    for t in task_result.result:
                        if t.task_id not in new_task_ids:
-                            await self._finish_task(t.task_id)
+                            await self._finish_task(t.task_id, t.node_id)
                    self._handle_execution_request(maybe_overridden_result)
 
+                if task_result.source_is_finished:
+                    await self._finish_task(task_result.source.task_id, task_result.source.node_id)
+
                if not self.active_tasks:
                    # if there are no active tasks, we'll be waiting forever for the next result..
                    break
 
                if self.active_reducers:  # pragma: no branch
-                    # In this case, there are no pending tasks. We can therefore finalize all active reducers whose
-                    # downstream fork stacks are not a strict "prefix" of another active reducer. (If it was, finalizing the
-                    # deeper reducer could produce new tasks in the "prefix" reducer.)
-                    active_fork_stacks = [
-                        join_state.downstream_fork_stack for join_state in self.active_reducers.values()
-                    ]
+                    # In this case, there are no pending tasks. We can therefore finalize all active reducers
+                    # that don't have intermediate joins which are also active reducers. If a join J2 has an
+                    # intermediate join J1 that shares the same parent fork run, we must finalize J1 first
+                    # because it might produce items that feed into J2.
                    for (join_id, fork_run_id), join_state in list(self.active_reducers.items()):
-                        fork_stack = join_state.downstream_fork_stack
-                        if any(
-                            len(afs) > len(fork_stack) and fork_stack == afs[: len(fork_stack)]
-                            for afs in active_fork_stacks
-                        ):
-                            # this join_state is a strict prefix for one of the other active join_states
-                            continue  # pragma: no cover  # It's difficult to cover this
+                        # Check if this join has any intermediate joins that are also active reducers
+                        should_skip = False
+                        intermediate_joins = self.graph.intermediate_join_nodes.get(join_id, set())
+
+                        # Get the parent fork for this join to use for comparison
+                        join_parent_fork = self.graph.get_parent_fork(join_id)
+
+                        for intermediate_join_id in intermediate_joins:
+                            # Check if the intermediate join is also an active reducer with a matching fork run
+                            for (other_join_id, _), other_join_state in self.active_reducers.items():
+                                if other_join_id == intermediate_join_id:
+                                    # Check if they share the same fork run for this join's parent fork
+                                    # by finding the parent fork's node_run_id in both fork stacks
+                                    join_parent_fork_run_id = None
+                                    other_parent_fork_run_id = None
+
+                                    for fsi in join_state.downstream_fork_stack:  # pragma: no branch
+                                        if fsi.fork_id == join_parent_fork.fork_id:
+                                            join_parent_fork_run_id = fsi.node_run_id
+                                            break
+
+                                    for fsi in other_join_state.downstream_fork_stack:  # pragma: no branch
+                                        if fsi.fork_id == join_parent_fork.fork_id:
+                                            other_parent_fork_run_id = fsi.node_run_id
+                                            break
+
+                                    if (
+                                        join_parent_fork_run_id
+                                        and other_parent_fork_run_id
+                                        and join_parent_fork_run_id == other_parent_fork_run_id
+                                    ):  # pragma: no branch
+                                        should_skip = True
+                                        break
+                            if should_skip:
+                                break
+
+                        if should_skip:
+                            continue
+
                        self.active_reducers.pop(
                            (join_id, fork_run_id)
                        )  # we're handling it now, so we can pop it
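The ordering rule in the new comment (finalize `J1` before `J2` when `J1` is an intermediate of `J2` on the same fork run) can be exercised in miniature. Everything below is hypothetical scaffolding, not the runner's real API:

```python
intermediate_join_nodes = {'J1': set(), 'J2': {'J1'}}
parent_fork_of = {'J1': 'F', 'J2': 'F'}

# (join_id, fork_run_id) -> downstream fork stack of (fork_id, node_run_id) pairs
active_reducers = {
    ('J1', 7): (('F', 7),),
    ('J2', 7): (('F', 7),),
}

def run_id_for(stack, fork_id):
    return next((run for f, run in stack if f == fork_id), None)

def can_finalize(join_id, stack):
    fork_id = parent_fork_of[join_id]
    mine = run_id_for(stack, fork_id)
    for mid in intermediate_join_nodes.get(join_id, set()):
        for (other_id, _), other_stack in active_reducers.items():
            if other_id == mid and run_id_for(other_stack, fork_id) == mine:
                return False  # the intermediate join must flush first
    return True

assert can_finalize('J1', active_reducers[('J1', 7)])      # J1 finalizes first
assert not can_finalize('J2', active_reducers[('J2', 7)])  # J2 waits on J1
```

This also shows why the run-ID comparison matters: two reducers for the same pair of joins but different `('F', ...)` runs would not block each other.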
@@ -610,7 +670,7 @@ async def iter_graph(  # noqa C901
                        # Same note as above about how this is theoretically reachable but we should
                        # just get coverage by unifying the code paths
                        if t.task_id not in new_task_ids:  # pragma: no cover
-                            await self._finish_task(t.task_id)
+                            await self._finish_task(t.task_id, t.node_id)
                        self._handle_execution_request(maybe_overridden_result)
        except GeneratorExit:
            self._task_group.cancel_scope.cancel()
@@ -620,7 +680,8 @@ async def iter_graph(  # noqa C901
                'Graph run completed, but no result was produced. This is either a bug in the graph or a bug in the graph runner.'
            )
 
-    async def _finish_task(self, task_id: TaskID) -> None:
+    async def _finish_task(self, task_id: TaskID, node_id: str) -> None:
+        # node_id is just included for debugging right now
        scope = self.cancel_scopes.pop(task_id, None)
        if scope is not None:
            scope.cancel()
@@ -837,7 +898,7 @@ async def _cancel_sibling_tasks(self, parent_fork_id: ForkID, node_run_id: NodeR
        else:
            pass
        for task_id in task_ids_to_cancel:
-            await self._finish_task(task_id)
+            await self._finish_task(task_id, 'sibling')
 
 
def _is_any_iterable(x: Any) -> TypeGuard[Iterable[Any]]: