Skip to content

Commit 4da2c72

Browse files
feat[compute]: compare decimal and varbin (#5757)
Signed-off-by: Joe Isaacs <[email protected]>
1 parent 3f9e024 commit 4da2c72

File tree

11 files changed

+932
-8
lines changed

11 files changed

+932
-8
lines changed

vortex-buffer/src/bit/buf.rs

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ use std::ops::RangeBounds;
1111
use crate::Alignment;
1212
use crate::BitBufferMut;
1313
use crate::Buffer;
14+
use crate::BufferMut;
1415
use crate::ByteBuffer;
1516
use crate::bit::BitChunks;
1617
use crate::bit::BitIndexIterator;
@@ -136,6 +137,55 @@ impl BitBuffer {
136137
BitBufferMut::collect_bool(len, f).freeze()
137138
}
138139

140+
/// Maps over each bit in this buffer, calling `f(index, bit_value)` and collecting results.
141+
///
142+
/// This is more efficient than `collect_bool` when you need to read the current bit value,
143+
/// as it unpacks each u64 chunk only once rather than doing random access for each bit.
144+
pub fn map_cmp<F>(&self, mut f: F) -> Self
145+
where
146+
F: FnMut(usize, bool) -> bool,
147+
{
148+
let len = self.len;
149+
let mut buffer: BufferMut<u64> = BufferMut::with_capacity(len.div_ceil(64));
150+
151+
let chunks_count = len / 64;
152+
let remainder = len % 64;
153+
let chunks = self.chunks();
154+
155+
for (chunk_idx, src_chunk) in chunks.iter().enumerate() {
156+
let mut packed = 0u64;
157+
for bit_idx in 0..64 {
158+
let i = bit_idx + chunk_idx * 64;
159+
let bit_value = (src_chunk >> bit_idx) & 1 == 1;
160+
packed |= (f(i, bit_value) as u64) << bit_idx;
161+
}
162+
163+
// SAFETY: Already allocated sufficient capacity
164+
unsafe { buffer.push_unchecked(packed) }
165+
}
166+
167+
if remainder != 0 {
168+
let src_chunk = chunks.remainder_bits();
169+
let mut packed = 0u64;
170+
for bit_idx in 0..remainder {
171+
let i = bit_idx + chunks_count * 64;
172+
let bit_value = (src_chunk >> bit_idx) & 1 == 1;
173+
packed |= (f(i, bit_value) as u64) << bit_idx;
174+
}
175+
176+
// SAFETY: Already allocated sufficient capacity
177+
unsafe { buffer.push_unchecked(packed) }
178+
}
179+
180+
buffer.truncate(len.div_ceil(8));
181+
182+
Self {
183+
buffer: buffer.freeze().into_byte_buffer(),
184+
offset: 0,
185+
len,
186+
}
187+
}
188+
139189
/// Clear all bits in the buffer, preserving existing capacity.
140190
pub fn clear(&mut self) {
141191
self.buffer.clear();
@@ -636,4 +686,55 @@ mod tests {
636686
);
637687
}
638688
}
689+
690+
#[rstest]
691+
#[case(5)]
692+
#[case(8)]
693+
#[case(10)]
694+
#[case(64)]
695+
#[case(65)]
696+
#[case(100)]
697+
#[case(128)]
698+
fn test_map_cmp_identity(#[case] len: usize) {
699+
// map_cmp with identity function should return the same buffer
700+
let buf = BitBuffer::collect_bool(len, |i| i % 3 == 0);
701+
let mapped = buf.map_cmp(|_idx, bit| bit);
702+
703+
assert_eq!(buf.len(), mapped.len());
704+
for i in 0..len {
705+
assert_eq!(buf.value(i), mapped.value(i), "Mismatch at index {}", i);
706+
}
707+
}
708+
709+
#[rstest]
710+
#[case(5)]
711+
#[case(8)]
712+
#[case(64)]
713+
#[case(65)]
714+
#[case(100)]
715+
fn test_map_cmp_negate(#[case] len: usize) {
716+
// map_cmp negating all bits
717+
let buf = BitBuffer::collect_bool(len, |i| i % 2 == 0);
718+
let mapped = buf.map_cmp(|_idx, bit| !bit);
719+
720+
assert_eq!(buf.len(), mapped.len());
721+
for i in 0..len {
722+
assert_eq!(!buf.value(i), mapped.value(i), "Mismatch at index {}", i);
723+
}
724+
}
725+
726+
#[test]
fn test_map_cmp_conditional() {
    // The closure combines the index with the current bit value.
    let len = 100;
    let source = BitBuffer::collect_bool(len, |idx| idx % 2 == 0);

    // A bit survives only if it was set in the source AND its index is a multiple of 4.
    let filtered = source.map_cmp(|idx, current| current && idx % 4 == 0);

    (0..len).for_each(|i| {
        let expected = (i % 2 == 0) && (i % 4 == 0);
        assert_eq!(filtered.value(i), expected, "Mismatch at index {}", i);
    });
}
639740
}
Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3+
4+
//! Compare implementations for BinaryViewScalar.
5+
6+
use vortex_vector::binaryview::BinaryViewScalar;
7+
use vortex_vector::binaryview::BinaryViewType;
8+
use vortex_vector::bool::BoolScalar;
9+
10+
use crate::comparison::Compare;
11+
use crate::comparison::Equal;
12+
use crate::comparison::GreaterThan;
13+
use crate::comparison::GreaterThanOrEqual;
14+
use crate::comparison::LessThan;
15+
use crate::comparison::LessThanOrEqual;
16+
use crate::comparison::NotEqual;
17+
18+
/// Compare two BinaryViewScalars using the provided comparison function.
19+
fn compare_binaryview_scalar<T: BinaryViewType, F>(
20+
lhs: BinaryViewScalar<T>,
21+
rhs: BinaryViewScalar<T>,
22+
cmp: F,
23+
) -> BoolScalar
24+
where
25+
F: Fn(&[u8], &[u8]) -> bool,
26+
{
27+
match (lhs.value(), rhs.value()) {
28+
(Some(l), Some(r)) => {
29+
let l_bytes: &[u8] = AsRef::<T::Slice>::as_ref(l).as_ref();
30+
let r_bytes: &[u8] = AsRef::<T::Slice>::as_ref(r).as_ref();
31+
BoolScalar::new(Some(cmp(l_bytes, r_bytes)))
32+
}
33+
_ => BoolScalar::new(None),
34+
}
35+
}
36+
37+
impl<T: BinaryViewType> Compare<Equal> for BinaryViewScalar<T> {
38+
type Output = BoolScalar;
39+
40+
fn compare(self, rhs: Self) -> Self::Output {
41+
compare_binaryview_scalar(self, rhs, |l, r| l == r)
42+
}
43+
}
44+
45+
impl<T: BinaryViewType> Compare<NotEqual> for BinaryViewScalar<T> {
46+
type Output = BoolScalar;
47+
48+
fn compare(self, rhs: Self) -> Self::Output {
49+
compare_binaryview_scalar(self, rhs, |l, r| l != r)
50+
}
51+
}
52+
53+
impl<T: BinaryViewType> Compare<LessThan> for BinaryViewScalar<T> {
54+
type Output = BoolScalar;
55+
56+
fn compare(self, rhs: Self) -> Self::Output {
57+
compare_binaryview_scalar(self, rhs, |l, r| l < r)
58+
}
59+
}
60+
61+
impl<T: BinaryViewType> Compare<LessThanOrEqual> for BinaryViewScalar<T> {
62+
type Output = BoolScalar;
63+
64+
fn compare(self, rhs: Self) -> Self::Output {
65+
compare_binaryview_scalar(self, rhs, |l, r| l <= r)
66+
}
67+
}
68+
69+
impl<T: BinaryViewType> Compare<GreaterThan> for BinaryViewScalar<T> {
70+
type Output = BoolScalar;
71+
72+
fn compare(self, rhs: Self) -> Self::Output {
73+
compare_binaryview_scalar(self, rhs, |l, r| l > r)
74+
}
75+
}
76+
77+
impl<T: BinaryViewType> Compare<GreaterThanOrEqual> for BinaryViewScalar<T> {
78+
type Output = BoolScalar;
79+
80+
fn compare(self, rhs: Self) -> Self::Output {
81+
compare_binaryview_scalar(self, rhs, |l, r| l >= r)
82+
}
83+
}
84+
85+
#[cfg(test)]
mod tests {
    use vortex_buffer::BufferString;
    use vortex_vector::binaryview::StringType;

    use super::*;

    /// Builds a string scalar that holds `text`.
    fn string_scalar(text: &'static str) -> BinaryViewScalar<StringType> {
        BinaryViewScalar::<StringType>::new(Some(BufferString::from(text)))
    }

    #[test]
    fn test_string_scalar_equal() {
        // Identical contents compare equal.
        let outcome = Compare::<Equal>::compare(string_scalar("hello"), string_scalar("hello"));

        assert_eq!(outcome.value(), Some(true));
    }

    #[test]
    fn test_string_scalar_not_equal() {
        // Different contents are reported as not equal.
        let outcome = Compare::<NotEqual>::compare(string_scalar("hello"), string_scalar("world"));

        assert_eq!(outcome.value(), Some(true));
    }

    #[test]
    fn test_string_scalar_less_than() {
        // "apple" sorts before "banana" byte-wise.
        let outcome = Compare::<LessThan>::compare(string_scalar("apple"), string_scalar("banana"));

        assert_eq!(outcome.value(), Some(true));
    }

    #[test]
    fn test_string_scalar_with_null() {
        // A scalar without a value on either side yields a result without a value.
        let missing = BinaryViewScalar::<StringType>::new(None);
        let outcome = Compare::<Equal>::compare(string_scalar("hello"), missing);

        assert_eq!(outcome.value(), None);
    }
}

0 commit comments

Comments
 (0)