67 changes: 48 additions & 19 deletions python/cudf_polars/cudf_polars/experimental/join.py
@@ -12,7 +12,7 @@
 from cudf_polars.experimental.base import PartitionInfo, get_key_name
 from cudf_polars.experimental.dispatch import generate_ir_tasks, lower_ir_node
 from cudf_polars.experimental.repartition import Repartition
-from cudf_polars.experimental.shuffle import Shuffle
+from cudf_polars.experimental.shuffle import Shuffle, _hash_partition_dataframe
 from cudf_polars.experimental.utils import _concat, _fallback_inform, _lower_ir_fallback
 
 if TYPE_CHECKING:
@@ -344,36 +344,65 @@ def _(
         small_name = get_key_name(right)
         small_size = partition_info[right].count
         large_name = get_key_name(left)
+        large_on = ir.left_on
     else:
         small_side = "Left"
         small_name = get_key_name(left)
         small_size = partition_info[left].count
         large_name = get_key_name(right)
+        large_on = ir.right_on
 
     graph: MutableMapping[Any, Any] = {}
 
     out_name = get_key_name(ir)
     out_size = partition_info[ir].count
-    concat_name = f"concat-{out_name}"
+    split_name = f"split-{out_name}"
+    getit_name = f"getit-{out_name}"
+    inter_name = f"inter-{out_name}"
 
-    # Concatenate the small partitions
-    if small_size > 1:
-        graph[(concat_name, 0)] = (
-            partial(_concat, context=context),
-            *((small_name, j) for j in range(small_size)),
-        )
-        small_name = concat_name
+    # Split each large partition if we have
+    # multiple small partitions (unless this
+    # is an inner join)
+    split_large = ir.options[0] != "Inner" and small_size > 1
 
     for part_out in range(out_size):
-        join_children = [(large_name, part_out), (small_name, 0)]
-        if small_side == "Left":
-            join_children.reverse()
-        graph[(out_name, part_out)] = (
-            partial(ir.do_evaluate, context=context),
-            ir.left_on,
-            ir.right_on,
-            ir.options,
-            *join_children,
-        )
+        if split_large:
+            graph[(split_name, part_out)] = (
+                _hash_partition_dataframe,
+                (large_name, part_out),
+                part_out,
+                small_size,
+                None,
+                large_on,
+            )
+
+        _concat_list = []
+        for j in range(small_size):
+            left_key: tuple[str, int] | tuple[str, int, int]
+            if split_large:
+                left_key = (getit_name, part_out, j)
+                graph[left_key] = (operator.getitem, (split_name, part_out), j)
+            else:
+                left_key = (large_name, part_out)
+            join_children = [left_key, (small_name, j)]
+            if small_side == "Left":
+                join_children.reverse()
+
+            inter_key = (inter_name, part_out, j)
+            graph[(inter_name, part_out, j)] = (
+                partial(ir.do_evaluate, context=context),
+                ir.left_on,
+                ir.right_on,
+                ir.options,
+                *join_children,
+            )
+            _concat_list.append(inter_key)
+        if len(_concat_list) == 1:
+            graph[(out_name, part_out)] = graph.pop(_concat_list[0])
+        else:
+            graph[(out_name, part_out)] = (
+                partial(_concat, context=context),
+                *_concat_list,
+            )
 
     return graph
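
Note on why `split_large` excludes inner joins: an inner join distributes over a row-wise split of either input, so each small-side chunk can be joined against the whole large partition and the partial results concatenated, while other join types would emit unmatched rows once per chunk (hence the hash partitioning above). A minimal sketch of that property using the public polars API; the frames and key names below are made-up illustrations, not cudf-polars internals:

import polars as pl

large = pl.DataFrame({"key": [1, 2, 3], "x": ["a", "b", "c"]})
small_chunks = [
    pl.DataFrame({"key": [1], "y": [10]}),
    pl.DataFrame({"key": [3], "y": [30]}),
]

# Inner join: chunk-wise join + concat matches joining the full small side.
chunked = pl.concat(
    [large.join(chunk, on="key", how="inner") for chunk in small_chunks]
)
whole = large.join(pl.concat(small_chunks), on="key", how="inner")
assert chunked.sort("key").to_dicts() == whole.sort("key").to_dicts()

# A left join does not distribute: key 2 has no match and would appear once
# per chunk, which is why the large side must first be hash-partitioned to
# align with the small chunks.
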
34 changes: 21 additions & 13 deletions python/cudf_polars/cudf_polars/experimental/rapidsmpf/join.py
@@ -149,11 +149,11 @@ async def broadcast_join_node(
     )
     await ch_out.send_metadata(context, output_metadata)
 
-    # Collect small-side
-    small_df = _concat(
-        *await get_small_table(context, small_child, small_ch),
-        context=ir_context,
-    )
+    # Collect small-side chunks
+    small_dfs = await get_small_table(context, small_child, small_ch)
+    if ir.options[0] != "Inner":
+        # TODO: Use local repartitioning for non-inner joins
+        small_dfs = [_concat(*small_dfs, context=ir_context)]
 
     # Stream through large side, joining with the small-side
     while (msg := await large_ch.data.recv(context)) is not None:
@@ -169,14 +169,22 @@
         )
 
         # Perform the join
-        df = await asyncio.to_thread(
-            ir.do_evaluate,
-            *ir._non_child_args,
-            *(
-                [large_df, small_df]
-                if broadcast_side == "right"
-                else [small_df, large_df]
-            ),
+        df = _concat(
+            *[
+                (
+                    await asyncio.to_thread(
+                        ir.do_evaluate,
+                        *ir._non_child_args,
+                        *(
+                            [large_df, small_df]
+                            if broadcast_side == "right"
+                            else [small_df, large_df]
+                        ),
+                        context=ir_context,
+                    )
+                )
+                for small_df in small_dfs
+            ],
             context=ir_context,
         )

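This second hunk applies the same chunk-wise strategy to the streaming path: each large-side message is joined against every collected small-side chunk off the event loop, and the partial results are concatenated (for non-inner joins the chunks were already merged into a single frame above, so the list has one element). A stand-alone sketch of that control flow, with join, concat, and the chunk source as hypothetical placeholders rather than the rapidsmpf API:

import asyncio

async def join_streamed(large_chunks, small_dfs, join, concat):
    # One join per (large chunk, small chunk) pair, run off-thread so the
    # event loop stays free to receive messages; concat per large chunk.
    results = []
    async for large_df in large_chunks:
        parts = [
            await asyncio.to_thread(join, large_df, small_df)
            for small_df in small_dfs
        ]
        results.append(concat(*parts))
    return results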