From 03228782059282de355a7ada358293f1505090c1 Mon Sep 17 00:00:00 2001 From: "Richard (Rick) Zamora" Date: Mon, 1 Dec 2025 09:22:54 -0600 Subject: [PATCH] Revert "Simplify broadcast-join algorithm in cudf-polars (#20724)" This reverts commit 67d4a41dda7cee02bddab249bc5fbe0b2ec44903. --- .../cudf_polars/experimental/join.py | 67 +++++++++++++------ .../experimental/rapidsmpf/join.py | 34 ++++++---- 2 files changed, 69 insertions(+), 32 deletions(-) diff --git a/python/cudf_polars/cudf_polars/experimental/join.py b/python/cudf_polars/cudf_polars/experimental/join.py index 499786ffcaa..9cd06330be1 100644 --- a/python/cudf_polars/cudf_polars/experimental/join.py +++ b/python/cudf_polars/cudf_polars/experimental/join.py @@ -12,7 +12,7 @@ from cudf_polars.experimental.base import PartitionInfo, get_key_name from cudf_polars.experimental.dispatch import generate_ir_tasks, lower_ir_node from cudf_polars.experimental.repartition import Repartition -from cudf_polars.experimental.shuffle import Shuffle +from cudf_polars.experimental.shuffle import Shuffle, _hash_partition_dataframe from cudf_polars.experimental.utils import _concat, _fallback_inform, _lower_ir_fallback if TYPE_CHECKING: @@ -344,36 +344,65 @@ def _( small_name = get_key_name(right) small_size = partition_info[right].count large_name = get_key_name(left) + large_on = ir.left_on else: small_side = "Left" small_name = get_key_name(left) small_size = partition_info[left].count large_name = get_key_name(right) + large_on = ir.right_on graph: MutableMapping[Any, Any] = {} out_name = get_key_name(ir) out_size = partition_info[ir].count - concat_name = f"concat-{out_name}" + split_name = f"split-{out_name}" + getit_name = f"getit-{out_name}" + inter_name = f"inter-{out_name}" - # Concatenate the small partitions - if small_size > 1: - graph[(concat_name, 0)] = ( - partial(_concat, context=context), - *((small_name, j) for j in range(small_size)), - ) - small_name = concat_name + # Split each large partition if we have + # multiple small partitions (unless this + # is an inner join) + split_large = ir.options[0] != "Inner" and small_size > 1 for part_out in range(out_size): - join_children = [(large_name, part_out), (small_name, 0)] - if small_side == "Left": - join_children.reverse() - graph[(out_name, part_out)] = ( - partial(ir.do_evaluate, context=context), - ir.left_on, - ir.right_on, - ir.options, - *join_children, - ) + if split_large: + graph[(split_name, part_out)] = ( + _hash_partition_dataframe, + (large_name, part_out), + part_out, + small_size, + None, + large_on, + ) + + _concat_list = [] + for j in range(small_size): + left_key: tuple[str, int] | tuple[str, int, int] + if split_large: + left_key = (getit_name, part_out, j) + graph[left_key] = (operator.getitem, (split_name, part_out), j) + else: + left_key = (large_name, part_out) + join_children = [left_key, (small_name, j)] + if small_side == "Left": + join_children.reverse() + + inter_key = (inter_name, part_out, j) + graph[(inter_name, part_out, j)] = ( + partial(ir.do_evaluate, context=context), + ir.left_on, + ir.right_on, + ir.options, + *join_children, + ) + _concat_list.append(inter_key) + if len(_concat_list) == 1: + graph[(out_name, part_out)] = graph.pop(_concat_list[0]) + else: + graph[(out_name, part_out)] = ( + partial(_concat, context=context), + *_concat_list, + ) return graph diff --git a/python/cudf_polars/cudf_polars/experimental/rapidsmpf/join.py b/python/cudf_polars/cudf_polars/experimental/rapidsmpf/join.py index 360045ffc0f..d7a48fe5a3a 100644 --- a/python/cudf_polars/cudf_polars/experimental/rapidsmpf/join.py +++ b/python/cudf_polars/cudf_polars/experimental/rapidsmpf/join.py @@ -149,11 +149,11 @@ async def broadcast_join_node( ) await ch_out.send_metadata(context, output_metadata) - # Collect small-side - small_df = _concat( - *await get_small_table(context, small_child, small_ch), - context=ir_context, - ) + # Collect small-side chunks + small_dfs = await get_small_table(context, small_child, small_ch) + if ir.options[0] != "Inner": + # TODO: Use local repartitioning for non-inner joins + small_dfs = [_concat(*small_dfs, context=ir_context)] # Stream through large side, joining with the small-side while (msg := await large_ch.data.recv(context)) is not None: @@ -169,14 +169,22 @@ async def broadcast_join_node( ) # Perform the join - df = await asyncio.to_thread( - ir.do_evaluate, - *ir._non_child_args, - *( - [large_df, small_df] - if broadcast_side == "right" - else [small_df, large_df] - ), + df = _concat( + *[ + ( + await asyncio.to_thread( + ir.do_evaluate, + *ir._non_child_args, + *( + [large_df, small_df] + if broadcast_side == "right" + else [small_df, large_df] + ), + context=ir_context, + ) + ) + for small_df in small_dfs + ], context=ir_context, )