Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
170 changes: 33 additions & 137 deletions cpp/src/io/parquet/decode_fixed.cu
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION.
* SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION.
* SPDX-License-Identifier: Apache-2.0
*/
#include "page_data.cuh"
Expand Down Expand Up @@ -272,19 +272,20 @@ __device__ inline void decode_fixed_width_split_values(
*
* @tparam decode_block_size Size of the thread block
* @tparam level_t Definition level type
* @tparam state_buf State buffer type
* @tparam is_nested Whether the type is nested
*
* @param target_value_count The target value count to process
* @param s Pointer to page state
* @param sb Pointer to state buffer
* @param s Pointer to page state
* @param def Pointer to the definition levels
* @param t Thread index
*
* @return Maximum depth valid count after processing
* @return Maximum depth valid count after skipping
*/
template <int decode_block_size, typename level_t, bool is_nested, int rolling_buf_size>
__device__ int skip_validity_and_row_indices_nonlist(
int value_count, int32_t target_value_count, page_state_s* s, level_t const* const def, int t)
template <int decode_block_size, typename level_t, bool is_nested>
__device__ int skip_validity_and_row_indices_nonlist(int32_t target_value_count,
page_state_s* s,
level_t const* const def,
int t)
{
int const max_def_level = [&]() {
if constexpr (is_nested) {
Expand All @@ -295,20 +296,12 @@ __device__ int skip_validity_and_row_indices_nonlist(
}();

int max_depth_valid_count = 0;
int value_count = 0;
while (value_count < target_value_count) {
int const batch_size = min(decode_block_size, target_value_count - value_count);

// definition level
int const is_valid = [&]() {
if (t >= batch_size) {
return 0;
} else if (def) {
int const def_level =
static_cast<int>(def[rolling_index<rolling_buf_size>(value_count + t)]);
return (def_level >= max_def_level) ? 1 : 0;
}
return 1;
}();
int const is_valid = (t >= batch_size) ? 0 : ((def[value_count + t] >= max_def_level) ? 1 : 0);

// thread and block validity count
using block_scan = cub::BlockScan<int, decode_block_size>;
Expand Down Expand Up @@ -363,14 +356,7 @@ __device__ int update_validity_and_row_indices_nested(
int const batch_size = min(max_batch_size, capped_target_value_count - value_count);

// definition level
int const d = [&]() {
if (t >= batch_size) {
return -1;
} else if (def) {
return static_cast<int>(def[rolling_index<state_buf::nz_buf_size>(value_count + t)]);
}
return 1;
}();
int const def_level = (t >= batch_size) ? -1 : def[value_count + t];

int const thread_value_count = t;
int const block_value_count = batch_size;
Expand All @@ -389,7 +375,7 @@ __device__ int update_validity_and_row_indices_nested(
for (int d_idx = 0; d_idx <= max_depth; d_idx++) {
auto& ni = s->nesting_info[d_idx];

int const is_valid = ((d >= ni.max_def_level) && in_row_bounds) ? 1 : 0;
int const is_valid = ((def_level >= ni.max_def_level) && in_row_bounds) ? 1 : 0;

// thread and block validity count
using block_scan = cub::BlockScan<int, decode_block_size>;
Expand Down Expand Up @@ -501,16 +487,8 @@ __device__ int update_validity_and_row_indices_flat(
int const in_row_bounds = (row_index < last_row);

// use definition level & row bounds to determine if is valid
int const is_valid = [&]() {
if (t >= batch_size) {
return 0;
} else if (def) {
int const def_level =
static_cast<int>(def[rolling_index<state_buf::nz_buf_size>(value_count + t)]);
return ((def_level > 0) && in_row_bounds) ? 1 : 0;
}
return in_row_bounds;
}();
int const is_valid =
((t >= batch_size) || !in_row_bounds) ? 0 : ((def[value_count + t] > 0) ? 1 : 0);

// thread and block validity count
using block_scan = cub::BlockScan<int, decode_block_size>;
Expand Down Expand Up @@ -624,20 +602,14 @@ __device__ int update_validity_and_row_indices_lists(int32_t target_value_count,
auto const [def_level, start_depth, end_depth] = [&]() {
if (!within_batch) { return cuda::std::make_tuple(-1, -1, -1); }

int const level_index = rolling_index<state_buf::nz_buf_size>(value_count + t);
int const rep_level = static_cast<int>(rep[level_index]);
auto const rep_level = rep[value_count + t];
int const start_depth = s->nesting_info[rep_level].start_depth;

if constexpr (!nullable) {
return cuda::std::make_tuple(-1, start_depth, max_depth);
} else {
if (def != nullptr) {
int const def_level = static_cast<int>(def[level_index]);
return cuda::std::make_tuple(
def_level, start_depth, s->nesting_info[def_level].end_depth);
} else {
return cuda::std::make_tuple(1, start_depth, max_depth);
}
int const def_level = def[value_count + t];
return cuda::std::make_tuple(def_level, start_depth, s->nesting_info[def_level].end_depth);
}
}();

Expand Down Expand Up @@ -817,33 +789,6 @@ __device__ int update_validity_and_row_indices_lists(int32_t target_value_count,
return max_depth_valid_count;
}

// is the page marked nullable or not
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Moved to a common header

// Whether the page's column is marked nullable: a max definition level > 0
// means definition levels are present, so values may be null.
__device__ inline bool is_nullable(page_state_s* s)
{
  return s->col.max_level[level_type::DEFINITION] > 0;
}

// For a nullable page, conservatively decide whether it could contain nulls by
// inspecting the first RLE run of the definition-level stream.
// Returns true unless the initial run provably covers every value in the page
// with the maximum definition level (i.e. everything is non-null).
__device__ inline bool maybe_has_nulls(page_state_s* s)
{
  auto const lvl      = level_type::DEFINITION;
  auto const init_run = s->initial_rle_run[lvl];

  // A literal (bit-packed) run can mix levels arbitrarily — assume nulls.
  if (is_literal_run(init_run)) { return true; }

  // A repeated run that does not span the whole page (run length is encoded in
  // the upper bits, hence the >> 1) leaves later values unknown — assume nulls.
  if ((init_run >> 1) != s->page.num_input_values) { return true; }

  // The run covers the page; check the repeated value itself. With zero level
  // bits the stored value is implicitly 0.
  auto const lvl_bits = s->col.level_bits[lvl];
  auto const run_val  = (lvl_bits == 0) ? 0 : s->initial_rle_value[lvl];

  // If the repeated level is not the max definition level, every value is null.
  return run_val != s->col.max_level[lvl];
}

template <typename state_buf, typename thread_group>
inline __device__ void bool_plain_decode(page_state_s* s,
state_buf* sb,
Expand Down Expand Up @@ -891,13 +836,9 @@ template <int decode_block_size_t,
bool has_bools_t,
bool has_nesting_t,
typename level_t,
typename def_decoder_t,
typename rep_decoder_t,
typename dict_stream_t,
typename bool_stream_t>
__device__ void skip_ahead_in_decoding(page_state_s* s,
def_decoder_t& def_decoder,
rep_decoder_t& rep_decoder,
dict_stream_t& dict_stream,
bool_stream_t& bool_stream,
bool bools_are_rle_stream,
Expand All @@ -922,10 +863,7 @@ __device__ void skip_ahead_in_decoding(page_state_s* s,
if constexpr (has_lists_t) {
auto const skipped_leaf_values = s->page.skipped_leaf_values;
if (skipped_leaf_values > 0) {
if (should_process_nulls) {
skip_decode<rolling_buf_size>(def_decoder, skipped_leaf_values, t);
}
processed_count = skip_decode<rolling_buf_size>(rep_decoder, skipped_leaf_values, t);
processed_count = skipped_leaf_values;
if constexpr (has_dict_t) {
skip_decode<rolling_buf_size>(dict_stream, skipped_leaf_values, t);
} else if constexpr (has_bools_t) {
Expand All @@ -937,25 +875,15 @@ __device__ void skip_ahead_in_decoding(page_state_s* s,

// Non-lists
int const first_row = s->first_row;
if (first_row <= 0) { return; }
if (!should_process_nulls) {
processed_count = first_row;
valid_count = first_row;
} else {
while (processed_count < first_row) {
auto to_process = min(rolling_buf_size, first_row - processed_count);
int next_processed_count = processed_count + def_decoder.decode_next(t, to_process);

int num_valids = skip_validity_and_row_indices_nonlist<decode_block_size_t,
level_t,
has_nesting_t,
rolling_buf_size>(
processed_count, next_processed_count, s, def, t);

valid_count += num_valids;
processed_count = next_processed_count;
}
}
if (first_row <= 0) { return; } // Nothing to skip

// Count the number of valids we're skipping.
processed_count = first_row;
valid_count =
!should_process_nulls
? first_row
: skip_validity_and_row_indices_nonlist<decode_block_size_t, level_t, has_nesting_t>(
first_row, s, def, t);

if constexpr (has_dict_t) {
skip_decode<rolling_buf_size>(dict_stream, valid_count, t);
Expand Down Expand Up @@ -1134,45 +1062,21 @@ CUDF_KERNEL void __launch_bounds__(decode_block_size_t, 8)
// shared buffer. all shared memory is suballocated out of here
constexpr int rle_run_buffer_bytes =
cudf::util::round_up_unsafe(rle_run_buffer_size * sizeof(rle_run), size_t{16});
constexpr int shared_buf_size =
rle_run_buffer_bytes * (static_cast<int>(has_dict_t) + static_cast<int>(has_bools_t) +
static_cast<int>(has_lists_t) + 1);
constexpr int shared_buf_size = cuda::std::max(
1, rle_run_buffer_bytes * (static_cast<int>(has_dict_t) + static_cast<int>(has_bools_t)));
__shared__ __align__(16) uint8_t shared_buf[shared_buf_size];

// setup all shared memory buffers
int shared_offset = 0;

auto rep_runs = reinterpret_cast<rle_run*>(shared_buf + shared_offset);
if constexpr (has_lists_t) { shared_offset += rle_run_buffer_bytes; }

auto dict_runs = reinterpret_cast<rle_run*>(shared_buf + shared_offset);
if constexpr (has_dict_t) { shared_offset += rle_run_buffer_bytes; }

auto bool_runs = reinterpret_cast<rle_run*>(shared_buf + shared_offset);
if constexpr (has_bools_t) { shared_offset += rle_run_buffer_bytes; }

auto def_runs = reinterpret_cast<rle_run*>(shared_buf + shared_offset);

// initialize the stream decoders (requires values computed in setup_local_page_info)
rle_stream<level_t, decode_block_size_t, rolling_buf_size> def_decoder{def_runs};
// get the level data
level_t* const def = reinterpret_cast<level_t*>(pp->lvl_decode_buf[level_type::DEFINITION]);
if (should_process_nulls) {
def_decoder.init(s->col.level_bits[level_type::DEFINITION],
s->abs_lvl_start[level_type::DEFINITION],
s->abs_lvl_end[level_type::DEFINITION],
def,
s->page.num_input_values);
}

rle_stream<level_t, decode_block_size_t, rolling_buf_size> rep_decoder{rep_runs};
level_t* const rep = reinterpret_cast<level_t*>(pp->lvl_decode_buf[level_type::REPETITION]);
if constexpr (has_lists_t) {
rep_decoder.init(s->col.level_bits[level_type::REPETITION],
s->abs_lvl_start[level_type::REPETITION],
s->abs_lvl_end[level_type::REPETITION],
rep,
s->page.num_input_values);
}

rle_stream<uint32_t, decode_block_size_t, rolling_buf_size> dict_stream{dict_runs};
if constexpr (has_dict_t) {
Expand Down Expand Up @@ -1212,8 +1116,6 @@ CUDF_KERNEL void __launch_bounds__(decode_block_size_t, 8)
has_bools_t,
has_nesting_t,
level_t>(s,
def_decoder,
rep_decoder,
dict_stream,
bool_stream,
bools_are_rle_stream,
Expand All @@ -1231,14 +1133,11 @@ CUDF_KERNEL void __launch_bounds__(decode_block_size_t, 8)
(s->input_row_count <= last_row)) {
int next_valid_count;
block.sync();
processed_count += min(rolling_buf_size, s->page.num_input_values - processed_count);
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same for all cases


// only need to process definition levels if this is a nullable column
if (should_process_nulls) {
processed_count += def_decoder.decode_next(t);
block.sync();

if constexpr (has_lists_t) {
rep_decoder.decode_next(t);
block.sync();
next_valid_count =
update_validity_and_row_indices_lists<decode_block_size_t, true, level_t>(
processed_count, s, sb, def, rep, t);
Expand All @@ -1255,16 +1154,13 @@ CUDF_KERNEL void __launch_bounds__(decode_block_size_t, 8)
// nz_idx. decode_fixed_width_values would be the only work that happens.
else {
if constexpr (has_lists_t) {
processed_count += rep_decoder.decode_next(t);
block.sync();
next_valid_count =
update_validity_and_row_indices_lists<decode_block_size_t, false, level_t>(
processed_count, s, sb, nullptr, rep, t);
} else {
// direct copy: no nulls, no lists, no need to update validity or row indices
// This ASSUMES that s->row_index_lower_bound is always -1!
// Its purpose is to handle rows than span page boundaries, which only happen for lists.
processed_count += min(rolling_buf_size, s->page.num_input_values - processed_count);
int const capped_target_value_count = min(processed_count, last_row);
if (t == 0) { s->input_row_count = capped_target_value_count; }
next_valid_count = capped_target_value_count;
Expand Down
Loading
Loading