Skip to content

Commit 1ebacee

Browse files
committed
optimize bool take
Signed-off-by: Connor Tsui <[email protected]>
1 parent 32f66e5 commit 1ebacee

File tree

1 file changed

+191
-12
lines changed
  • vortex-compute/src/take/vector

1 file changed

+191
-12
lines changed
Lines changed: 191 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,47 @@
11
// SPDX-License-Identifier: Apache-2.0
22
// SPDX-FileCopyrightText: Copyright the Vortex contributors
33

4+
//! Take implementations for [`BoolVector`].
5+
//!
6+
//! This module includes an optimization for small boolean value arrays (typical of dictionary
7+
//! encoding) that avoids element-wise indexing when possible.
8+
9+
use std::ops::BitAnd;
10+
use std::ops::Not;
11+
12+
use vortex_buffer::BitBuffer;
413
use vortex_dtype::UnsignedPType;
14+
use vortex_mask::Mask;
515
use vortex_vector::VectorOps;
616
use vortex_vector::bool::BoolVector;
717
use vortex_vector::primitive::PVector;
818

919
use crate::take::Take;
1020

21+
// TODO(connor): Figure out good numbers for these heuristics.
22+
23+
/// The maximum length of a values array for which we unconditionally apply the optimized take.
24+
const OPTIMIZED_TAKE_MAX_VALUES_LEN: usize = 8;
25+
26+
/// The minimum ratio of `indices.len() / values.len()` for which we apply the optimized take.
27+
const OPTIMIZED_TAKE_MIN_RATIO: usize = 2;
28+
29+
/// Returns whether to use the optimized take path based on heuristics.
30+
fn should_use_optimized_take(values_len: usize, indices_len: usize) -> bool {
31+
values_len <= OPTIMIZED_TAKE_MAX_VALUES_LEN
32+
|| indices_len >= OPTIMIZED_TAKE_MIN_RATIO * values_len
33+
}
34+
1135
impl<I: UnsignedPType> Take<PVector<I>> for &BoolVector {
1236
type Output = BoolVector;
1337

1438
fn take(self, indices: &PVector<I>) -> BoolVector {
1539
if indices.validity().all_true() {
40+
// No null indices, delegate to slice implementation.
1641
self.take(indices.elements().as_slice())
1742
} else {
18-
take_nullable(self, indices)
43+
// Has null indices, need to propagate nulls.
44+
take_with_nullable_indices(self, indices)
1945
}
2046
}
2147
}
@@ -24,26 +50,179 @@ impl<I: UnsignedPType> Take<[I]> for &BoolVector {
2450
type Output = BoolVector;
2551

2652
fn take(self, indices: &[I]) -> BoolVector {
27-
let taken_bits = self.bits().take(indices);
28-
let taken_validity = self.validity().take(indices);
53+
if should_use_optimized_take(self.len(), indices.len()) {
54+
optimized_take(self, indices, || self.validity().take(indices))
55+
} else {
56+
default_take(self, indices)
57+
}
58+
}
59+
}
60+
61+
/// Default element-wise take from a slice of indices.
62+
fn default_take<I: UnsignedPType>(values: &BoolVector, indices: &[I]) -> BoolVector {
63+
let taken_bits = values.bits().take(indices);
64+
let taken_validity = values.validity().take(indices);
65+
66+
debug_assert_eq!(taken_bits.len(), taken_validity.len());
67+
68+
// SAFETY: Both components were taken with the same indices, so they have the same length.
69+
unsafe { BoolVector::new_unchecked(taken_bits, taken_validity) }
70+
}
71+
72+
/// Take with nullable indices, propagating nulls from both values and indices.
73+
fn take_with_nullable_indices<I: UnsignedPType>(
74+
values: &BoolVector,
75+
indices: &PVector<I>,
76+
) -> BoolVector {
77+
let indices_slice = indices.elements().as_slice();
78+
let indices_validity = indices.validity();
79+
80+
// Validity must combine value validity with index validity.
81+
let compute_validity = || {
82+
values
83+
.validity()
84+
.take(indices_slice)
85+
.bitand(indices_validity)
86+
};
87+
88+
if should_use_optimized_take(values.len(), indices.len()) {
89+
optimized_take(values, indices_slice, compute_validity)
90+
} else {
91+
// We ignore index nullability when taking the bits since the validity mask handles nulls.
92+
let taken_bits = values.bits().take(indices_slice);
93+
let taken_validity = compute_validity();
2994

3095
debug_assert_eq!(taken_bits.len(), taken_validity.len());
3196

32-
// SAFETY: We called take on both components of the vector with the same indices, so the new
33-
// components must have the same length.
97+
// SAFETY: Both components were taken with the same indices, so they have the same length.
3498
unsafe { BoolVector::new_unchecked(taken_bits, taken_validity) }
3599
}
36100
}
37101

38-
fn take_nullable<I: UnsignedPType>(bvector: &BoolVector, indices: &PVector<I>) -> BoolVector {
39-
// We ignore nullability when taking the bits since we can let the `Mask` implementation
40-
// determine which elements are null.
41-
let taken_bits = bvector.bits().take(indices.elements().as_slice());
42-
let taken_validity = bvector.validity().take(indices);
102+
// TODO(connor): Use the generic `compare` implementation when that gets implemented.
103+
104+
/// Creates a [`BitBuffer`] where each bit is set iff the corresponding index equals `target`.
105+
fn broadcast_index_comparison<I: UnsignedPType>(indices: &[I], target: usize) -> BitBuffer {
106+
BitBuffer::collect_bool(indices.len(), |i| {
107+
// SAFETY: `i` is in bounds since `collect_bool` iterates from 0..len.
108+
let index: usize = unsafe { indices.get_unchecked(i).as_() };
109+
index == target
110+
})
111+
}
112+
113+
/// Optimized take for boolean vectors with small value arrays.
114+
///
115+
/// Since booleans can only be `true` or `false`, we can optimize these specific cases:
116+
///
117+
/// - All of the values are `true`, so create a [`BoolVector`] with `n` `true`s.
118+
/// - All of the values are `false`, so create a [`BoolVector`] with `n` `false`s.
119+
/// - There is a single `true` value, so compare indices against that index.
120+
/// - There is a single `false` value, so compare indices against that index and negate.
121+
/// - Otherwise, there are multiple `true`s and `false`s in the `values` vector and we must do a
122+
/// normal `take` on it.
123+
///
124+
/// The `compute_validity` closure computes the output validity mask, allowing callers to handle
125+
/// nullable vs non-nullable indices differently.
126+
fn optimized_take<I: UnsignedPType>(
127+
values: &BoolVector,
128+
indices: &[I],
129+
compute_validity: impl FnOnce() -> Mask,
130+
) -> BoolVector {
131+
let len = indices.len();
132+
let (trues, falses) = count_true_and_false_positions(values);
133+
134+
let (taken_bits, taken_validity) = match (trues, falses) {
135+
// All values are null.
136+
(Count::None, Count::None) => (BitBuffer::new_unset(len), Mask::new_false(len)),
137+
138+
// No true values exist, so all output bits are false.
139+
(Count::None, _) => (BitBuffer::new_unset(len), compute_validity()),
140+
141+
// No false values exist, so all output bits are true.
142+
(_, Count::None) => (BitBuffer::new_set(len), compute_validity()),
143+
144+
// Single true value: output bit is set iff index equals the true position.
145+
(Count::One(true_idx), _) => {
146+
let bits = broadcast_index_comparison(indices, true_idx);
147+
(bits, compute_validity())
148+
}
149+
150+
// Single false value: output bit is set iff index does NOT equal the false position.
151+
(_, Count::One(false_idx)) => {
152+
let bits = broadcast_index_comparison(indices, false_idx).not();
153+
(bits, compute_validity())
154+
}
155+
156+
// Multiple true and false values, so fall back to the default `take`.
157+
(Count::More, Count::More) => {
158+
let taken_bits = values.bits().take(indices);
159+
(taken_bits, compute_validity())
160+
}
161+
};
43162

44163
debug_assert_eq!(taken_bits.len(), taken_validity.len());
45164

46-
// SAFETY: We used the same indices to take from both components, so they should still have the
47-
// same length.
165+
// SAFETY: Both components have length `len` (the length of `indices`).
48166
unsafe { BoolVector::new_unchecked(taken_bits, taken_validity) }
49167
}
168+
169+
/// Represents the count of true or false values found in a boolean vector.
170+
enum Count {
171+
/// No values of this kind were found.
172+
None,
173+
/// Exactly one value was found at the given index.
174+
One(usize),
175+
/// Two or more values were found.
176+
More,
177+
}
178+
179+
/// Scans a boolean vector to determine how many true and false values exist.
180+
///
181+
/// Returns `(true_count, false_count)` where each is a [`Count`] indicating none, one (with
182+
/// position), or more than one. Null values are skipped. The scan exits early once both counts
183+
/// reach "more than one".
184+
fn count_true_and_false_positions(values: &BoolVector) -> (Count, Count) {
185+
let bits = values.bits();
186+
let validity = values.validity();
187+
188+
let mut first_true: Option<usize> = None;
189+
let mut found_second_true = false;
190+
let mut first_false: Option<usize> = None;
191+
let mut found_second_false = false;
192+
193+
for idx in 0..values.len() {
194+
if !validity.value(idx) {
195+
continue;
196+
}
197+
198+
if bits.value(idx) {
199+
if first_true.is_none() {
200+
first_true = Some(idx);
201+
} else {
202+
found_second_true = true;
203+
}
204+
} else if first_false.is_none() {
205+
first_false = Some(idx);
206+
} else {
207+
found_second_false = true;
208+
}
209+
210+
if found_second_true && found_second_false {
211+
break;
212+
}
213+
}
214+
215+
let true_count = match (first_true, found_second_true) {
216+
(None, _) => Count::None,
217+
(Some(idx), false) => Count::One(idx),
218+
(Some(_), true) => Count::More,
219+
};
220+
221+
let false_count = match (first_false, found_second_false) {
222+
(None, _) => Count::None,
223+
(Some(idx), false) => Count::One(idx),
224+
(Some(_), true) => Count::More,
225+
};
226+
227+
(true_count, false_count)
228+
}

0 commit comments

Comments
 (0)