Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -585,6 +585,7 @@ add_library(
src/join/filter_join_indices_kernel_primitive.cu
src/join/filtered_join.cu
src/join/hash_join.cu
src/join/jit_filter_join_indices.cu
src/join/join.cu
src/join/join_utils.cu
src/join/key_remapping.cu
Expand Down
1 change: 1 addition & 0 deletions cpp/benchmarks/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ ConfigureNVBench(
join/mixed_join.cu
join/distinct_join.cu
join/filter_join_indices.cu
join/jit_filter_join_indices.cpp
join/multiplicity_join.cu
join/sort_merge_join.cu
join/join_heuristics.cu
Expand Down
157 changes: 157 additions & 0 deletions cpp/benchmarks/join/jit_filter_join_indices.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION.
* SPDX-License-Identifier: Apache-2.0
*/

#include <cudf/column/column_factories.hpp>
#include <cudf/join/join.hpp>
#include <cudf/table/table.hpp>
#include <cudf/table/table_view.hpp>
#include <cudf/types.hpp>
#include <cudf/utilities/span.hpp>

#include <benchmarks/common/generate_input.hpp>
#include <benchmarks/fixture/benchmark_fixture.hpp>

#include <nvbench/nvbench.cuh>

#include <memory>
#include <random>
#include <vector>

class JitFilterJoinIndicesBench : public cudf::benchmark {
private:
std::unique_ptr<cudf::table> left_table;
std::unique_ptr<cudf::table> right_table;
std::unique_ptr<rmm::device_uvector<cudf::size_type>> left_indices;
std::unique_ptr<rmm::device_uvector<cudf::size_type>> right_indices;

public:
void SetUp(int64_t num_rows, double selectivity)
{
// Create test tables with integer columns
auto left_col0_data = cudf::detail::make_counting_transform_iterator(
0, [](auto i) { return static_cast<int32_t>(i); });
auto left_col1_data = cudf::detail::make_counting_transform_iterator(
0, [](auto i) { return static_cast<int32_t>(i * 2); });

auto left_col0 = cudf::make_numeric_column(cudf::data_type{cudf::type_id::INT32}, num_rows);
auto left_col1 = cudf::make_numeric_column(cudf::data_type{cudf::type_id::INT32}, num_rows);

thrust::copy(left_col0_data, left_col0_data + num_rows,
left_col0->mutable_view().data<int32_t>());
thrust::copy(left_col1_data, left_col1_data + num_rows,
left_col1->mutable_view().data<int32_t>());

std::vector<std::unique_ptr<cudf::column>> left_columns;
left_columns.push_back(std::move(left_col0));
left_columns.push_back(std::move(left_col1));
left_table = std::make_unique<cudf::table>(std::move(left_columns));

// Create right table with similar pattern
auto right_col0_data = cudf::detail::make_counting_transform_iterator(
0, [](auto i) { return static_cast<int32_t>(i); });
auto right_col1_data = cudf::detail::make_counting_transform_iterator(
0, [selectivity](auto i) {
return static_cast<int32_t>(i * 2 - (selectivity > 0.5 ? 1 : 10));
});

auto right_col0 = cudf::make_numeric_column(cudf::data_type{cudf::type_id::INT32}, num_rows);
auto right_col1 = cudf::make_numeric_column(cudf::data_type{cudf::type_id::INT32}, num_rows);

thrust::copy(right_col0_data, right_col0_data + num_rows,
right_col0->mutable_view().data<int32_t>());
thrust::copy(right_col1_data, right_col1_data + num_rows,
right_col1->mutable_view().data<int32_t>());

std::vector<std::unique_ptr<cudf::column>> right_columns;
right_columns.push_back(std::move(right_col0));
right_columns.push_back(std::move(right_col1));
right_table = std::make_unique<cudf::table>(std::move(right_columns));

// Create join indices (simulate all pairs matching from equality join)
left_indices = std::make_unique<rmm::device_uvector<cudf::size_type>>(num_rows, cudf::get_default_stream());
right_indices = std::make_unique<rmm::device_uvector<cudf::size_type>>(num_rows, cudf::get_default_stream());

auto counting_iter = cudf::detail::make_counting_transform_iterator(
0, [](auto i) { return static_cast<cudf::size_type>(i); });
thrust::copy(counting_iter, counting_iter + num_rows, left_indices->begin());
thrust::copy(counting_iter, counting_iter + num_rows, right_indices->begin());
}

void BenchmarkJitFilterJoinIndices(nvbench::state& state,
cudf::join_kind join_kind,
std::string const& predicate_code)
{
auto const num_rows = static_cast<int64_t>(state.get_int64("num_rows"));
auto const selectivity = state.get_float64("selectivity");

SetUp(num_rows, selectivity);

cudf::device_span<cudf::size_type const> left_span{left_indices->data(), left_indices->size()};
cudf::device_span<cudf::size_type const> right_span{right_indices->data(), right_indices->size()};

state.set_cuda_stream(nvbench::make_cuda_stream_view(cudf::get_default_stream().value()));

state.exec(nvbench::exec_tag::sync, [&](nvbench::launch& launch) {
auto [filtered_left, filtered_right] = cudf::jit_filter_join_indices(
left_table->view(),
right_table->view(),
left_span,
right_span,
predicate_code,
join_kind);
});

state.add_buffer_size(num_rows * sizeof(cudf::size_type), "input_indices");
state.add_buffer_size(filtered_left->size() * sizeof(cudf::size_type), "output_indices");
}
};

void jit_filter_join_indices_inner_join(nvbench::state& state)
{
JitFilterJoinIndicesBench benchmark;
std::string predicate_code = R"(
__device__ bool predicate(int32_t left_val, int32_t right_val) {
return left_val > right_val;
}
)";
benchmark.BenchmarkJitFilterJoinIndices(state, cudf::join_kind::INNER_JOIN, predicate_code);
}

void jit_filter_join_indices_left_join(nvbench::state& state)
{
JitFilterJoinIndicesBench benchmark;
std::string predicate_code = R"(
__device__ bool predicate(int32_t left_val, int32_t right_val) {
return left_val > right_val;
}
)";
benchmark.BenchmarkJitFilterJoinIndices(state, cudf::join_kind::LEFT_JOIN, predicate_code);
}

void jit_filter_join_indices_full_join(nvbench::state& state)
{
JitFilterJoinIndicesBench benchmark;
std::string predicate_code = R"(
__device__ bool predicate(int32_t left_val, int32_t right_val) {
return left_val > right_val;
}
)";
benchmark.BenchmarkJitFilterJoinIndices(state, cudf::join_kind::FULL_JOIN, predicate_code);
}

NVBENCH_BENCH(jit_filter_join_indices_inner_join)
.set_name("jit_filter_join_indices_inner")
.add_int64_axis("num_rows", {10'000, 100'000, 1'000'000})
.add_float64_axis("selectivity", {0.1, 0.5, 0.9});

NVBENCH_BENCH(jit_filter_join_indices_left_join)
.set_name("jit_filter_join_indices_left")
.add_int64_axis("num_rows", {10'000, 100'000, 1'000'000})
.add_float64_axis("selectivity", {0.1, 0.5, 0.9});

NVBENCH_BENCH(jit_filter_join_indices_full_join)
.set_name("jit_filter_join_indices_full")
.add_int64_axis("num_rows", {10'000, 100'000, 1'000'000})
.add_float64_axis("selectivity", {0.1, 0.5, 0.9});
72 changes: 72 additions & 0 deletions cpp/include/cudf/join/join.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,78 @@ filter_join_indices(cudf::table_view const& left,
rmm::cuda_stream_view stream = cudf::get_default_stream(),
rmm::device_async_resource_ref mr = cudf::get_current_device_resource_ref());

/**
* @brief JIT-based filtering of join result indices using string predicate.
*
* This function provides a JIT-compiled alternative to filter_join_indices(),
* taking a string-based predicate that gets compiled to optimized GPU code.
*
* The behavior depends on the join type (same as filter_join_indices):
* - INNER_JOIN: Only pairs that satisfy the predicate and have valid indices are kept.
* - LEFT_JOIN: All left rows are preserved. Failed predicates nullify right indices.
* - FULL_JOIN: All rows from both sides are preserved. Failed predicates create separate pairs.
*
* ## Usage Pattern
*
* Similar to filter_join_indices but uses JIT compilation for better performance:
*
* @code{.cpp}
* // Step 1: Perform equality-based hash join (same as before)
* auto hash_joiner = cudf::hash_join(right_equality_table, null_equality::EQUAL);
* auto [left_indices, right_indices] = hash_joiner.inner_join(left_equality_table);
*
* // Step 2: Apply JIT-compiled conditional filter
* std::string predicate_code = R"(
* __device__ bool predicate(double left_val, double right_val) {
* return left_val > right_val;
* }
* )";
* auto [filtered_left, filtered_right] = cudf::jit_filter_join_indices(
* left_conditional_table, // Table with columns referenced by predicate
* right_conditional_table, // Table with columns referenced by predicate
* *left_indices, // Indices from hash join
* *right_indices, // Indices from hash join
* predicate_code, // JIT-compiled predicate function
* cudf::join_kind::INNER_JOIN);
* @endcode
*
* ## Predicate Function Requirements
*
* The predicate_code must define a device function with signature:
* ```cpp
* __device__ bool predicate(T1 left_col0, T2 left_col1, ..., T1 right_col0, T2 right_col1, ...)
* ```
* Where the parameters correspond to columns in left table followed by right table.
*
* @throw std::invalid_argument if join_kind is not INNER_JOIN, LEFT_JOIN, or FULL_JOIN.
* @throw std::invalid_argument if left_indices and right_indices have different sizes.
* @throw cudf::jit_compilation_error if predicate_code fails to compile.
*
* @param left The left table for predicate evaluation (conditional columns only).
* @param right The right table for predicate evaluation (conditional columns only).
* @param left_indices Device span of row indices in the left table from hash join.
* @param right_indices Device span of row indices in the right table from hash join.
* @param predicate_code String containing CUDA device code for predicate function.
* @param join_kind The type of join operation. Must be INNER_JOIN, LEFT_JOIN, or FULL_JOIN.
* @param is_ptx Whether predicate_code contains PTX assembly instead of CUDA C++.
* @param stream CUDA stream used for kernel launches and memory operations.
* @param mr Device memory resource used to allocate output indices.
*
* @return A pair of device vectors [filtered_left_indices, filtered_right_indices]
* corresponding to rows that satisfy the join semantics and predicate.
*/
std::pair<std::unique_ptr<rmm::device_uvector<size_type>>,
std::unique_ptr<rmm::device_uvector<size_type>>>
jit_filter_join_indices(cudf::table_view const& left,
cudf::table_view const& right,
cudf::device_span<size_type const> left_indices,
cudf::device_span<size_type const> right_indices,
std::string const& predicate_code,
cudf::join_kind join_kind,
bool is_ptx = false,
rmm::cuda_stream_view stream = cudf::get_default_stream(),
rmm::device_async_resource_ref mr = cudf::get_current_device_resource_ref());

/** @} */ // end of group

} // namespace CUDF_EXPORT cudf
72 changes: 72 additions & 0 deletions cpp/src/jit/accessors.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -131,5 +131,77 @@ struct scalar_accessor {
}
};

// Join-specific accessors for indexed table access
template <typename T, int32_t Index>
struct join_left_column_accessor {
using type = T;
static constexpr int32_t index = Index;

static __device__ T element(cudf::column_device_view_core const* tables,
cudf::size_type row_idx,
cudf::size_type /* thread_idx */)
{
return tables[index].template element<T>(row_idx);
}

static __device__ bool is_null(cudf::column_device_view_core const* tables,
cudf::size_type row_idx,
cudf::size_type /* thread_idx */)
{
return tables[index].is_null(row_idx);
}

static __device__ bool is_valid(cudf::column_device_view_core const* tables,
cudf::size_type row_idx,
cudf::size_type /* thread_idx */)
{
return tables[index].is_valid(row_idx);
}

static __device__ cuda::std::optional<T> nullable_element(cudf::column_device_view_core const* tables,
cudf::size_type row_idx,
cudf::size_type thread_idx)
{
if (is_null(tables, row_idx, thread_idx)) { return cuda::std::nullopt; }
return element(tables, row_idx, thread_idx);
}
};

// Join-specific accessors for right table columns
template <typename T, int32_t Index>
struct join_right_column_accessor {
using type = T;
static constexpr int32_t index = Index;

static __device__ T element(cudf::column_device_view_core const* tables,
cudf::size_type row_idx,
cudf::size_type /* thread_idx */)
{
return tables[index].template element<T>(row_idx);
}

static __device__ bool is_null(cudf::column_device_view_core const* tables,
cudf::size_type row_idx,
cudf::size_type /* thread_idx */)
{
return tables[index].is_null(row_idx);
}

static __device__ bool is_valid(cudf::column_device_view_core const* tables,
cudf::size_type row_idx,
cudf::size_type /* thread_idx */)
{
return tables[index].is_valid(row_idx);
}

static __device__ cuda::std::optional<T> nullable_element(cudf::column_device_view_core const* tables,
cudf::size_type row_idx,
cudf::size_type thread_idx)
{
if (is_null(tables, row_idx, thread_idx)) { return cuda::std::nullopt; }
return element(tables, row_idx, thread_idx);
}
};

} // namespace jit
} // namespace cudf
66 changes: 66 additions & 0 deletions cpp/src/join/jit/filter_join_kernel.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION.
* SPDX-License-Identifier: Apache-2.0
*/

#include <cudf/column/column_device_view_base.cuh>
#include <cudf/detail/utilities/grid_1d.cuh>
#include <cudf/types.hpp>
#include <cudf/utilities/span.hpp>

#include <cuda/std/cstddef>

#include <jit/accessors.cuh>
#include <jit/span.cuh>

// clang-format off
// This header is an inlined header that defines the GENERIC_JOIN_FILTER_OP function. It is placed here
// so the symbols in the headers above can be used by it.
#include <cudf/detail/operation-udf.hpp>
// clang-format on

namespace cudf {
namespace join {
namespace jit {

template <bool has_user_data, typename... InputAccessors>
CUDF_KERNEL void filter_join_kernel(cudf::device_span<cudf::size_type const> left_indices,
cudf::device_span<cudf::size_type const> right_indices,
cudf::column_device_view_core const* left_tables,
cudf::column_device_view_core const* right_tables,
bool* predicate_results,
void* user_data)
{
auto const start = cudf::detail::grid_1d::global_thread_id();
auto const stride = cudf::detail::grid_1d::grid_stride();
auto const size = left_indices.size();

for (auto i = start; i < size; i += stride) {
auto const left_idx = left_indices[i];
auto const right_idx = right_indices[i];

// Skip if either index is JoinNoMatch
if (left_idx == cudf::JoinNoMatch || right_idx == cudf::JoinNoMatch) {
predicate_results[i] = false;
continue;
}

bool result = false;

if constexpr (has_user_data) {
GENERIC_JOIN_FILTER_OP(user_data, i, &result,
InputAccessors::element(left_tables, left_idx, i)...,
InputAccessors::element(right_tables, right_idx, i)...);
} else {
GENERIC_JOIN_FILTER_OP(&result,
InputAccessors::element(left_tables, left_idx, i)...,
InputAccessors::element(right_tables, right_idx, i)...);
}

predicate_results[i] = result;
}
}

} // namespace jit
} // namespace join
} // namespace cudf
Loading