Commit 63f6942

Fix graph execution bug with multiple joins downstream of same fork (pydantic#3337)
1 parent 9795d74

File tree

3 files changed: +157 -23 lines changed

pydantic_graph/pydantic_graph/beta/graph.py

Lines changed: 83 additions & 22 deletions

@@ -148,6 +148,12 @@ class Graph(Generic[StateT, DepsT, InputT, OutputT]):
     parent_forks: dict[JoinID, ParentFork[NodeID]]
     """Parent fork information for each join node."""
 
+    intermediate_join_nodes: dict[JoinID, set[JoinID]]
+    """For each join, the set of other joins that appear between it and its parent fork.
+
+    Used to determine which joins are "final" (have no other joins as intermediates) and
+    which joins should preserve fork stacks when proceeding downstream."""
+
     def get_parent_fork(self, join_id: JoinID) -> ParentFork[NodeID]:
         """Get the parent fork information for a join node.
 
@@ -165,6 +171,24 @@ def get_parent_fork(self, join_id: JoinID) -> ParentFork[NodeID]:
             raise RuntimeError(f'Node {join_id} is not a join node or did not have a dominating fork (this is a bug)')
         return result
 
+    def is_final_join(self, join_id: JoinID) -> bool:
+        """Check if a join is 'final' (has no downstream joins with the same parent fork).
+
+        A join is non-final if it appears as an intermediate node for another join
+        with the same parent fork.
+
+        Args:
+            join_id: The ID of the join node
+
+        Returns:
+            True if the join is final, False if it's non-final
+        """
+        # Check if this join appears in any other join's intermediate_join_nodes
+        for intermediate_joins in self.intermediate_join_nodes.values():
+            if join_id in intermediate_joins:
+                return False
+        return True
+
     async def run(
         self,
         *,
@@ -517,7 +541,14 @@ async def iter_graph( # noqa C901
                 parent_fork_id = self.graph.get_parent_fork(result.join_id).fork_id
                 for i, x in enumerate(result.fork_stack[::-1]):
                     if x.fork_id == parent_fork_id:
-                        downstream_fork_stack = result.fork_stack[: len(result.fork_stack) - i]
+                        # For non-final joins (those that are intermediate nodes of other joins),
+                        # preserve the fork stack so downstream joins can still associate with the same fork run
+                        if self.graph.is_final_join(result.join_id):
+                            # Final join: remove the parent fork from the stack
+                            downstream_fork_stack = result.fork_stack[: len(result.fork_stack) - i]
+                        else:
+                            # Non-final join: preserve the fork stack
+                            downstream_fork_stack = result.fork_stack
                         fork_run_id = x.node_run_id
                         break
                 else:  # pragma: no cover
@@ -535,13 +566,9 @@ async def iter_graph( # noqa C901
                     join_state.current = join_node.reduce(context, join_state.current, result.inputs)
                     if join_state.cancelled_sibling_tasks:
                         await self._cancel_sibling_tasks(parent_fork_id, fork_run_id)
-                    if task_result.source_is_finished:  # pragma: no branch
-                        await self._finish_task(task_result.source.task_id)
                 else:
                     for new_task in maybe_overridden_result:
                         self.active_tasks[new_task.task_id] = new_task
-                    if task_result.source_is_finished:
-                        await self._finish_task(task_result.source.task_id)
 
                 tasks_by_id_values = list(self.active_tasks.values())
                 join_tasks: list[GraphTask] = []
@@ -566,28 +593,61 @@ async def iter_graph( # noqa C901
                 new_task_ids = {t.task_id for t in maybe_overridden_result}
                 for t in task_result.result:
                     if t.task_id not in new_task_ids:
-                        await self._finish_task(t.task_id)
+                        await self._finish_task(t.task_id, t.node_id)
                 self._handle_execution_request(maybe_overridden_result)
 
+            if task_result.source_is_finished:
+                await self._finish_task(task_result.source.task_id, task_result.source.node_id)
+
             if not self.active_tasks:
                 # if there are no active tasks, we'll be waiting forever for the next result..
                 break
 
             if self.active_reducers:  # pragma: no branch
-                # In this case, there are no pending tasks. We can therefore finalize all active reducers whose
-                # downstream fork stacks are not a strict "prefix" of another active reducer. (If it was, finalizing the
-                # deeper reducer could produce new tasks in the "prefix" reducer.)
-                active_fork_stacks = [
-                    join_state.downstream_fork_stack for join_state in self.active_reducers.values()
-                ]
+                # In this case, there are no pending tasks. We can therefore finalize all active reducers
+                # that don't have intermediate joins which are also active reducers. If a join J2 has an
+                # intermediate join J1 that shares the same parent fork run, we must finalize J1 first
+                # because it might produce items that feed into J2.
                 for (join_id, fork_run_id), join_state in list(self.active_reducers.items()):
-                    fork_stack = join_state.downstream_fork_stack
-                    if any(
-                        len(afs) > len(fork_stack) and fork_stack == afs[: len(fork_stack)]
-                        for afs in active_fork_stacks
-                    ):
-                        # this join_state is a strict prefix for one of the other active join_states
-                        continue  # pragma: no cover  # It's difficult to cover this
+                    # Check if this join has any intermediate joins that are also active reducers
+                    should_skip = False
+                    intermediate_joins = self.graph.intermediate_join_nodes.get(join_id, set())
+
+                    # Get the parent fork for this join to use for comparison
+                    join_parent_fork = self.graph.get_parent_fork(join_id)
+
+                    for intermediate_join_id in intermediate_joins:
+                        # Check if the intermediate join is also an active reducer with matching fork run
+                        for (other_join_id, _), other_join_state in self.active_reducers.items():
+                            if other_join_id == intermediate_join_id:
+                                # Check if they share the same fork run for this join's parent fork
+                                # by finding the parent fork's node_run_id in both fork stacks
+                                join_parent_fork_run_id = None
+                                other_parent_fork_run_id = None
+
+                                for fsi in join_state.downstream_fork_stack:  # pragma: no branch
+                                    if fsi.fork_id == join_parent_fork.fork_id:
+                                        join_parent_fork_run_id = fsi.node_run_id
+                                        break
+
+                                for fsi in other_join_state.downstream_fork_stack:  # pragma: no branch
+                                    if fsi.fork_id == join_parent_fork.fork_id:
+                                        other_parent_fork_run_id = fsi.node_run_id
+                                        break
+
+                                if (
+                                    join_parent_fork_run_id
+                                    and other_parent_fork_run_id
+                                    and join_parent_fork_run_id == other_parent_fork_run_id
+                                ):  # pragma: no branch
+                                    should_skip = True
+                                    break
+                        if should_skip:
+                            break
+
+                    if should_skip:
+                        continue
+
                     self.active_reducers.pop(
                         (join_id, fork_run_id)
                     )  # we're handling it now, so we can pop it
@@ -610,7 +670,7 @@ async def iter_graph( # noqa C901
                         # Same note as above about how this is theoretically reachable but we should
                         # just get coverage by unifying the code paths
                         if t.task_id not in new_task_ids:  # pragma: no cover
-                            await self._finish_task(t.task_id)
+                            await self._finish_task(t.task_id, t.node_id)
                         self._handle_execution_request(maybe_overridden_result)
         except GeneratorExit:
            self._task_group.cancel_scope.cancel()
@@ -620,7 +680,8 @@ async def iter_graph( # noqa C901
                 'Graph run completed, but no result was produced. This is either a bug in the graph or a bug in the graph runner.'
             )
 
-    async def _finish_task(self, task_id: TaskID) -> None:
+    async def _finish_task(self, task_id: TaskID, node_id: str) -> None:
+        # node_id is just included for debugging right now
        scope = self.cancel_scopes.pop(task_id, None)
        if scope is not None:
            scope.cancel()
@@ -837,7 +898,7 @@ async def _cancel_sibling_tasks(self, parent_fork_id: ForkID, node_run_id: NodeR
        else:
            pass
        for task_id in task_ids_to_cancel:
-           await self._finish_task(task_id)
+           await self._finish_task(task_id, 'sibling')
 
 
 def _is_any_iterable(x: Any) -> TypeGuard[Iterable[Any]]:
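
Taken together, the runtime changes above hinge on a single query: does any other join list this join among its intermediate nodes? A minimal standalone sketch of that check, using plain strings in place of JoinID and a hypothetical mapping matching the new test's topology (the real method lives on Graph, as shown in the diff):

intermediate_join_nodes: dict[str, set[str]] = {
    'collect': {'mediator'},  # 'mediator' lies between 'collect' and its parent fork
    'mediator': set(),
}

def is_final_join(join_id: str) -> bool:
    # A join is non-final if any other join lists it as an intermediate
    return all(join_id not in joins for joins in intermediate_join_nodes.values())

assert is_final_join('collect') is True      # nothing downstream reduces its output again
assert is_final_join('mediator') is False    # its output still feeds 'collect'

Per the diff, a non-final join preserves result.fork_stack when proceeding downstream, so a later join like 'collect' can still locate the same parent fork run; a final join removes the parent fork from the stack, as before.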

pydantic_graph/pydantic_graph/beta/graph_builder.py

Lines changed: 36 additions & 0 deletions

@@ -658,6 +658,7 @@ def build(self, validate_graph_structure: bool = True) -> Graph[StateT, DepsT, G
         if validate_graph_structure:
             _validate_graph_structure(nodes, edges_by_source)
         parent_forks = _collect_dominating_forks(nodes, edges_by_source)
+        intermediate_join_nodes = _compute_intermediate_join_nodes(nodes, parent_forks)
 
         return Graph[StateT, DepsT, GraphInputT, GraphOutputT](
             name=self.name,
@@ -668,6 +669,7 @@ def build(self, validate_graph_structure: bool = True) -> Graph[StateT, DepsT, G
             nodes=nodes,
             edges_by_source=edges_by_source,
             parent_forks=parent_forks,
+            intermediate_join_nodes=intermediate_join_nodes,
             auto_instrument=self.auto_instrument,
         )
 
@@ -948,6 +950,40 @@ def _handle_path(path: Path, last_source_id: NodeID):
     return dominating_forks
 
 
+def _compute_intermediate_join_nodes(
+    nodes: dict[NodeID, AnyNode], parent_forks: dict[JoinID, ParentFork[NodeID]]
+) -> dict[JoinID, set[JoinID]]:
+    """Compute which joins have other joins as intermediate nodes.
+
+    A join J1 is an intermediate node of join J2 if J1 appears in J2's intermediate_nodes
+    (as computed relative to J2's parent fork).
+
+    This information is used to determine:
+    1. Which joins are "final" (have no other joins in their intermediate_nodes)
+    2. When selecting which reducer to proceed with when there are no active tasks
+
+    Args:
+        nodes: All nodes in the graph
+        parent_forks: Parent fork information for each join
+
+    Returns:
+        A mapping from each join to the set of joins that are intermediate to it
+    """
+    intermediate_join_nodes: dict[JoinID, set[JoinID]] = {}
+
+    for join_id, parent_fork in parent_forks.items():
+        intermediate_joins = set[JoinID]()
+        for intermediate_node_id in parent_fork.intermediate_nodes:
+            # Check if this intermediate node is also a join
+            intermediate_node = nodes.get(intermediate_node_id)
+            if isinstance(intermediate_node, Join):
+                # Add it regardless of whether it has the same parent fork
+                intermediate_joins.add(JoinID(intermediate_node_id))
+        intermediate_join_nodes[join_id] = intermediate_joins
+
+    return intermediate_join_nodes
+
+
 def _replace_placeholder_node_ids(nodes: dict[NodeID, AnyNode], edges_by_source: dict[NodeID, list[Path]]):
     node_id_remapping = _build_placeholder_node_id_remapping(nodes)
     replaced_nodes = {
tests/graph/beta/test_graph_execution.py

Lines changed: 38 additions & 1 deletion

@@ -7,7 +7,7 @@
 import pytest
 
 from pydantic_graph.beta import GraphBuilder, StepContext
-from pydantic_graph.beta.join import ReduceFirstValue, reduce_list_append
+from pydantic_graph.beta.join import ReduceFirstValue, reduce_list_append, reduce_list_extend
 
 pytestmark = pytest.mark.anyio
 
@@ -343,3 +343,40 @@ async def step3(ctx: StepContext[ExecutionState, None, list[int]]) -> str:
     assert 'parallel-100' in state.log
     assert 'step3' in state.log
     assert result == 'Result: 152'  # (50+1) + (100+1) = 152
+
+
+async def test_multiple_sequential_joins():
+    g = GraphBuilder(output_type=list[int])
+
+    @g.step
+    async def source(ctx: StepContext[None, None, None]) -> int:
+        return 10
+
+    @g.step
+    async def add_one(ctx: StepContext[None, None, int]) -> list[int]:
+        return [ctx.inputs + 1]
+
+    @g.step
+    async def add_two(ctx: StepContext[None, None, int]) -> list[int]:
+        return [ctx.inputs + 2]
+
+    @g.step
+    async def add_three(ctx: StepContext[None, None, int]) -> list[int]:
+        return [ctx.inputs + 3]
+
+    collect = g.join(reduce_list_extend, initial_factory=list[int], parent_fork_id='source_fork', node_id='collect')
+    mediator = g.join(reduce_list_extend, initial_factory=list[int], node_id='mediator')
+
+    # Broadcasting: send the value from source to all three steps
+    g.add(
+        g.edge_from(g.start_node).to(source),
+        g.edge_from(source).to(add_one, add_two, add_three, fork_id='source_fork'),
+        g.edge_from(add_one, add_two).to(mediator),
+        g.edge_from(mediator).to(collect),
+        g.edge_from(add_three).to(collect),
+        g.edge_from(collect).to(g.end_node),
+    )
+
+    graph = g.build()
+    result = await graph.run()
+    assert sorted(result) == [11, 12, 13]
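
For reference, the topology this test builds, as I read the edges above:

start -> source
source -(source_fork)-> add_one, add_two, add_three
add_one, add_two -> mediator
mediator -> collect
add_three -> collect
collect -> end

Both joins sit downstream of the same fork: 'mediator' reduces add_one and add_two, and 'collect' must combine mediator's output with add_three's under the same 'source_fork' run. Before this commit, completing 'mediator' unconditionally popped the parent fork from the fork stack, so 'collect' could no longer associate the mediated results with the same fork run as add_three's, which appears to be the multiple-joins-downstream-of-one-fork bug the commit title describes.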
