11// SPDX-License-Identifier: Apache-2.0
22// SPDX-FileCopyrightText: Copyright the Vortex contributors
33
4+ //! Take implementations for [`BoolVector`].
5+ //!
6+ //! This module includes an optimization for small boolean value arrays (typical of dictionary
7+ //! encoding) that avoids element-wise indexing when possible.
8+
9+ use std:: ops:: BitAnd ;
10+ use std:: ops:: Not ;
11+
12+ use vortex_buffer:: BitBuffer ;
413use vortex_dtype:: UnsignedPType ;
14+ use vortex_mask:: Mask ;
515use vortex_vector:: VectorOps ;
616use vortex_vector:: bool:: BoolVector ;
717use vortex_vector:: primitive:: PVector ;
818
919use crate :: take:: Take ;
1020
21+ // TODO(connor): Figure out good numbers for these heuristics.
22+
23+ /// The maximum length of a values array for which we unconditionally apply the optimized take.
24+ const OPTIMIZED_TAKE_MAX_VALUES_LEN : usize = 8 ;
25+
26+ /// The minimum ratio of `indices.len() / values.len()` for which we apply the optimized take.
27+ const OPTIMIZED_TAKE_MIN_RATIO : usize = 2 ;
28+
29+ /// Returns whether to use the optimized take path based on heuristics.
30+ fn should_use_optimized_take ( values_len : usize , indices_len : usize ) -> bool {
31+ values_len <= OPTIMIZED_TAKE_MAX_VALUES_LEN
32+ || indices_len >= OPTIMIZED_TAKE_MIN_RATIO * values_len
33+ }
34+
1135impl < I : UnsignedPType > Take < PVector < I > > for & BoolVector {
1236 type Output = BoolVector ;
1337
1438 fn take ( self , indices : & PVector < I > ) -> BoolVector {
1539 if indices. validity ( ) . all_true ( ) {
40+ // No null indices, delegate to slice implementation.
1641 self . take ( indices. elements ( ) . as_slice ( ) )
1742 } else {
18- take_nullable ( self , indices)
43+ // Has null indices, need to propagate nulls.
44+ take_with_nullable_indices ( self , indices)
1945 }
2046 }
2147}
@@ -24,26 +50,179 @@ impl<I: UnsignedPType> Take<[I]> for &BoolVector {
2450 type Output = BoolVector ;
2551
2652 fn take ( self , indices : & [ I ] ) -> BoolVector {
27- let taken_bits = self . bits ( ) . take ( indices) ;
28- let taken_validity = self . validity ( ) . take ( indices) ;
53+ if should_use_optimized_take ( self . len ( ) , indices. len ( ) ) {
54+ optimized_take ( self , indices, || self . validity ( ) . take ( indices) )
55+ } else {
56+ default_take ( self , indices)
57+ }
58+ }
59+ }
60+
61+ /// Default element-wise take from a slice of indices.
62+ fn default_take < I : UnsignedPType > ( values : & BoolVector , indices : & [ I ] ) -> BoolVector {
63+ let taken_bits = values. bits ( ) . take ( indices) ;
64+ let taken_validity = values. validity ( ) . take ( indices) ;
65+
66+ debug_assert_eq ! ( taken_bits. len( ) , taken_validity. len( ) ) ;
67+
68+ // SAFETY: Both components were taken with the same indices, so they have the same length.
69+ unsafe { BoolVector :: new_unchecked ( taken_bits, taken_validity) }
70+ }
71+
72+ /// Take with nullable indices, propagating nulls from both values and indices.
73+ fn take_with_nullable_indices < I : UnsignedPType > (
74+ values : & BoolVector ,
75+ indices : & PVector < I > ,
76+ ) -> BoolVector {
77+ let indices_slice = indices. elements ( ) . as_slice ( ) ;
78+ let indices_validity = indices. validity ( ) ;
79+
80+ // Validity must combine value validity with index validity.
81+ let compute_validity = || {
82+ values
83+ . validity ( )
84+ . take ( indices_slice)
85+ . bitand ( indices_validity)
86+ } ;
87+
88+ if should_use_optimized_take ( values. len ( ) , indices. len ( ) ) {
89+ optimized_take ( values, indices_slice, compute_validity)
90+ } else {
91+ // We ignore index nullability when taking the bits since the validity mask handles nulls.
92+ let taken_bits = values. bits ( ) . take ( indices_slice) ;
93+ let taken_validity = compute_validity ( ) ;
2994
3095 debug_assert_eq ! ( taken_bits. len( ) , taken_validity. len( ) ) ;
3196
32- // SAFETY: We called take on both components of the vector with the same indices, so the new
33- // components must have the same length.
97+ // SAFETY: Both components were taken with the same indices, so they have the same length.
3498 unsafe { BoolVector :: new_unchecked ( taken_bits, taken_validity) }
3599 }
36100}
37101
38- fn take_nullable < I : UnsignedPType > ( bvector : & BoolVector , indices : & PVector < I > ) -> BoolVector {
39- // We ignore nullability when taking the bits since we can let the `Mask` implementation
40- // determine which elements are null.
41- let taken_bits = bvector. bits ( ) . take ( indices. elements ( ) . as_slice ( ) ) ;
42- let taken_validity = bvector. validity ( ) . take ( indices) ;
102+ // TODO(connor): Use the generic `compare` implementation when that gets implemented.
103+
104+ /// Creates a [`BitBuffer`] where each bit is set iff the corresponding index equals `target`.
105+ fn broadcast_index_comparison < I : UnsignedPType > ( indices : & [ I ] , target : usize ) -> BitBuffer {
106+ BitBuffer :: collect_bool ( indices. len ( ) , |i| {
107+ // SAFETY: `i` is in bounds since `collect_bool` iterates from 0..len.
108+ let index: usize = unsafe { indices. get_unchecked ( i) . as_ ( ) } ;
109+ index == target
110+ } )
111+ }
112+
113+ /// Optimized take for boolean vectors with small value arrays.
114+ ///
115+ /// Since booleans can only be `true` or `false`, we can optimize these specific cases:
116+ ///
117+ /// - All of the values are `true`, so create a [`BoolVector`] with `n` `true`s.
118+ /// - All of the values are `false`, so create a [`BoolVector`] with `n` `false`s.
119+ /// - There is a single `true` value, so compare indices against that index.
120+ /// - There is a single `false` value, so compare indices against that index and negate.
121+ /// - Otherwise, there are multiple `true`s and `false`s in the `values` vector and we must do a
122+ /// normal `take` on it.
123+ ///
124+ /// The `compute_validity` closure computes the output validity mask, allowing callers to handle
125+ /// nullable vs non-nullable indices differently.
126+ fn optimized_take < I : UnsignedPType > (
127+ values : & BoolVector ,
128+ indices : & [ I ] ,
129+ compute_validity : impl FnOnce ( ) -> Mask ,
130+ ) -> BoolVector {
131+ let len = indices. len ( ) ;
132+ let ( trues, falses) = count_true_and_false_positions ( values) ;
133+
134+ let ( taken_bits, taken_validity) = match ( trues, falses) {
135+ // All values are null.
136+ ( Count :: None , Count :: None ) => ( BitBuffer :: new_unset ( len) , Mask :: new_false ( len) ) ,
137+
138+ // No true values exist, so all output bits are false.
139+ ( Count :: None , _) => ( BitBuffer :: new_unset ( len) , compute_validity ( ) ) ,
140+
141+ // No false values exist, so all output bits are true.
142+ ( _, Count :: None ) => ( BitBuffer :: new_set ( len) , compute_validity ( ) ) ,
143+
144+ // Single true value: output bit is set iff index equals the true position.
145+ ( Count :: One ( true_idx) , _) => {
146+ let bits = broadcast_index_comparison ( indices, true_idx) ;
147+ ( bits, compute_validity ( ) )
148+ }
149+
150+ // Single false value: output bit is set iff index does NOT equal the false position.
151+ ( _, Count :: One ( false_idx) ) => {
152+ let bits = broadcast_index_comparison ( indices, false_idx) . not ( ) ;
153+ ( bits, compute_validity ( ) )
154+ }
155+
156+ // Multiple true and false values, so fall back to the default `take`.
157+ ( Count :: More , Count :: More ) => {
158+ let taken_bits = values. bits ( ) . take ( indices) ;
159+ ( taken_bits, compute_validity ( ) )
160+ }
161+ } ;
43162
44163 debug_assert_eq ! ( taken_bits. len( ) , taken_validity. len( ) ) ;
45164
46- // SAFETY: We used the same indices to take from both components, so they should still have the
47- // same length.
165+ // SAFETY: Both components have length `len` (the length of `indices`).
48166 unsafe { BoolVector :: new_unchecked ( taken_bits, taken_validity) }
49167}
168+
169+ /// Represents the count of true or false values found in a boolean vector.
170+ enum Count {
171+ /// No values of this kind were found.
172+ None ,
173+ /// Exactly one value was found at the given index.
174+ One ( usize ) ,
175+ /// Two or more values were found.
176+ More ,
177+ }
178+
179+ /// Scans a boolean vector to determine how many true and false values exist.
180+ ///
181+ /// Returns `(true_count, false_count)` where each is a [`Count`] indicating none, one (with
182+ /// position), or more than one. Null values are skipped. The scan exits early once both counts
183+ /// reach "more than one".
184+ fn count_true_and_false_positions ( values : & BoolVector ) -> ( Count , Count ) {
185+ let bits = values. bits ( ) ;
186+ let validity = values. validity ( ) ;
187+
188+ let mut first_true: Option < usize > = None ;
189+ let mut found_second_true = false ;
190+ let mut first_false: Option < usize > = None ;
191+ let mut found_second_false = false ;
192+
193+ for idx in 0 ..values. len ( ) {
194+ if !validity. value ( idx) {
195+ continue ;
196+ }
197+
198+ if bits. value ( idx) {
199+ if first_true. is_none ( ) {
200+ first_true = Some ( idx) ;
201+ } else {
202+ found_second_true = true ;
203+ }
204+ } else if first_false. is_none ( ) {
205+ first_false = Some ( idx) ;
206+ } else {
207+ found_second_false = true ;
208+ }
209+
210+ if found_second_true && found_second_false {
211+ break ;
212+ }
213+ }
214+
215+ let true_count = match ( first_true, found_second_true) {
216+ ( None , _) => Count :: None ,
217+ ( Some ( idx) , false ) => Count :: One ( idx) ,
218+ ( Some ( _) , true ) => Count :: More ,
219+ } ;
220+
221+ let false_count = match ( first_false, found_second_false) {
222+ ( None , _) => Count :: None ,
223+ ( Some ( idx) , false ) => Count :: One ( idx) ,
224+ ( Some ( _) , true ) => Count :: More ,
225+ } ;
226+
227+ ( true_count, false_count)
228+ }
0 commit comments