Skip to content

Commit 4c329f5

Browse files
authored
Vector comparison compute (#5100)
Signed-off-by: Nicholas Gates <[email protected]>
1 parent fed289c commit 4c329f5

File tree

9 files changed

+510
-8
lines changed

9 files changed

+510
-8
lines changed

vortex-compute/Cargo.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,9 @@ vortex-vector = { workspace = true }
2828
num-traits = { workspace = true }
2929

3030
[features]
31-
default = ["arithmetic", "filter", "logical"]
31+
default = ["arithmetic", "comparison", "filter", "logical"]
3232

3333
arithmetic = []
34+
comparison = []
3435
filter = []
3536
logical = []
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3+
4+
use std::ops::BitAnd;
5+
6+
use vortex_buffer::{BitBuffer, BufferMut};
7+
use vortex_vector::{BoolVector, VectorOps};
8+
9+
use crate::comparison::{
10+
Compare, Equal, GreaterThan, GreaterThanOrEqual, LessThan, LessThanOrEqual, NotEqual,
11+
};
12+
13+
impl<Op> Compare<Op> for &BoolVector
14+
where
15+
Op: BitComparisonOperator,
16+
{
17+
type Output = BoolVector;
18+
19+
fn compare(self, rhs: Self) -> Self::Output {
20+
let validity = self.validity().bitand(rhs.validity());
21+
22+
let lhs = self.bits().chunks();
23+
let rhs = rhs.bits().chunks();
24+
25+
// Reserve one extra chunk to account for partial padding chunk at the end.
26+
let mut buffer = BufferMut::<u64>::with_capacity(lhs.chunk_len() + 1);
27+
buffer.extend(
28+
lhs.iter_padded()
29+
.zip(rhs.iter_padded())
30+
.map(|(a_chunk, b_chunk)| Op::apply(&a_chunk, &b_chunk)),
31+
);
32+
let bits = BitBuffer::new(buffer.freeze().into_byte_buffer(), self.len());
33+
34+
BoolVector::new(bits, validity)
35+
}
36+
}
37+
38+
/// A comparison operator evaluated bitwise over packed boolean words.
///
/// Every bit position of the two inputs is treated as an independent pair of boolean
/// operands; the same bit position of the returned word carries the comparison result.
pub trait BitComparisonOperator {
    /// Combine one 64-bit word from each side into a word of per-bit comparison results.
    fn apply(a: &u64, b: &u64) -> u64;
}
41+
42+
/// Bitwise `a == b`: a result bit is set where both inputs carry the same bit (XNOR).
impl BitComparisonOperator for Equal {
    fn apply(a: &u64, b: &u64) -> u64 {
        let differing = *a ^ *b;
        !differing
    }
}
47+
/// Bitwise `a != b`: a result bit is set where the two inputs differ (XOR).
impl BitComparisonOperator for NotEqual {
    fn apply(a: &u64, b: &u64) -> u64 {
        *a ^ *b
    }
}
52+
/// Bitwise `a < b` with `false < true`: a result bit is set only where `a` is 0 and `b` is 1.
impl BitComparisonOperator for LessThan {
    fn apply(a: &u64, b: &u64) -> u64 {
        let a_is_false = !*a;
        a_is_false & *b
    }
}
57+
/// Bitwise `a <= b` with `false < true`: a result bit is cleared only where `a` is 1 and
/// `b` is 0.
impl BitComparisonOperator for LessThanOrEqual {
    fn apply(a: &u64, b: &u64) -> u64 {
        // De Morgan form of `!(a & !b)`.
        !*a | *b
    }
}
62+
/// Bitwise `a > b` with `true > false`: a result bit is set only where `a` is 1 and `b` is 0.
impl BitComparisonOperator for GreaterThan {
    fn apply(a: &u64, b: &u64) -> u64 {
        let b_is_false = !*b;
        *a & b_is_false
    }
}
67+
/// Bitwise `a >= b` with `true > false`: a result bit is cleared only where `a` is 0 and
/// `b` is 1.
impl BitComparisonOperator for GreaterThanOrEqual {
    fn apply(a: &u64, b: &u64) -> u64 {
        // De Morgan form of `!(!a & b)`.
        *a | !*b
    }
}
Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3+
4+
use vortex_buffer::BitBuffer;
5+
6+
use crate::comparison::{Compare, ComparisonOperator};
7+
8+
/// Adapter to implement `Compare` for any `ComparableCollection`.
9+
pub(crate) struct ComparableCollectionAdapter<C>(pub C);
10+
11+
impl<Op, C> Compare<Op> for ComparableCollectionAdapter<C>
12+
where
13+
C: ComparableCollection,
14+
Op: ComparisonOperator<C::Item>,
15+
{
16+
type Output = BitBuffer;
17+
18+
fn compare(self, rhs: Self) -> Self::Output {
19+
assert_eq!(self.0.len(), rhs.0.len());
20+
21+
BitBuffer::from_iter((0..self.0.len()).map(|i| {
22+
let left = unsafe { self.0.item_unchecked(i) };
23+
let right = unsafe { rhs.0.item_unchecked(i) };
24+
Op::apply(&left, &right)
25+
}))
26+
}
27+
}
28+
29+
/// A fixed-length, index-addressable collection whose items can be compared position by
/// position.
///
/// This is not a marker trait: implementors must report their length and provide unchecked
/// access to the item stored at a given index.
pub trait ComparableCollection {
    /// The item type that can be compared.
    type Item;

    /// Get the length of the comparable collection.
    fn len(&self) -> usize;

    /// Get the item at the specified index without bounds checking.
    ///
    /// # Safety
    ///
    /// `index` must be strictly less than the value returned by [`len`](Self::len).
    /// Implementations are allowed to skip bounds checks, so calling this with an
    /// out-of-range index is undefined behavior.
    unsafe fn item_unchecked(&self, index: usize) -> Self::Item;
}
40+
41+
impl<T: Copy> ComparableCollection for &[T] {
42+
type Item = T;
43+
44+
fn len(&self) -> usize {
45+
<[T]>::len(self)
46+
}
47+
48+
unsafe fn item_unchecked(&self, index: usize) -> Self::Item {
49+
unsafe { *self.get_unchecked(index) }
50+
}
51+
}
52+
53+
impl<Op, T> Compare<Op> for &[T]
54+
where
55+
T: Copy,
56+
Op: ComparisonOperator<T>,
57+
{
58+
type Output = BitBuffer;
59+
60+
fn compare(self, rhs: Self) -> Self::Output {
61+
Compare::<Op>::compare(
62+
ComparableCollectionAdapter(self),
63+
ComparableCollectionAdapter(rhs),
64+
)
65+
}
66+
}
67+
68+
#[cfg(test)]
mod tests {
    use vortex_buffer::bitbuffer;

    use super::*;
    use crate::comparison::{
        Equal, GreaterThan, GreaterThanOrEqual, LessThan, LessThanOrEqual, NotEqual,
    };

    #[test]
    fn test_slice_equal() {
        let left: &[u32] = &[1, 2, 3, 4];
        let right: &[u32] = &[1, 2, 5, 4];

        let result = Compare::<Equal>::compare(left, right);
        assert_eq!(result, bitbuffer![1 1 0 1]);
    }

    #[test]
    fn test_slice_not_equal() {
        let left: &[u32] = &[1, 2, 3, 4];
        let right: &[u32] = &[1, 2, 5, 4];

        let result = Compare::<NotEqual>::compare(left, right);
        assert_eq!(result, bitbuffer![0 0 1 0]);
    }

    #[test]
    fn test_slice_less_than() {
        let left: &[u32] = &[1, 2, 3, 4];
        let right: &[u32] = &[2, 2, 1, 5];

        let result = Compare::<LessThan>::compare(left, right);
        assert_eq!(result, bitbuffer![1 0 0 1]);
    }

    // Same inputs as `test_slice_less_than`; only the equal pair at index 1 flips.
    #[test]
    fn test_slice_less_than_or_equal() {
        let left: &[u32] = &[1, 2, 3, 4];
        let right: &[u32] = &[2, 2, 1, 5];

        let result = Compare::<LessThanOrEqual>::compare(left, right);
        assert_eq!(result, bitbuffer![1 1 0 1]);
    }

    #[test]
    fn test_slice_greater_than() {
        let left: &[u32] = &[3, 2, 1, 5];
        let right: &[u32] = &[1, 2, 3, 4];

        let result = Compare::<GreaterThan>::compare(left, right);
        assert_eq!(result, bitbuffer![1 0 0 1]);
    }

    // Same inputs as `test_slice_greater_than`; only the equal pair at index 1 flips.
    #[test]
    fn test_slice_greater_than_or_equal() {
        let left: &[u32] = &[3, 2, 1, 5];
        let right: &[u32] = &[1, 2, 3, 4];

        let result = Compare::<GreaterThanOrEqual>::compare(left, right);
        assert_eq!(result, bitbuffer![1 1 0 1]);
    }

    // The adapter asserts equal lengths before reading any item.
    #[test]
    #[should_panic]
    fn test_slice_length_mismatch_panics() {
        let left: &[u32] = &[1, 2, 3];
        let right: &[u32] = &[1, 2];

        let _ = Compare::<Equal>::compare(left, right);
    }
}
Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3+
4+
//! Comparison operations for Vortex vectors.
5+
6+
use vortex_dtype::half::f16;
7+
8+
mod bool;
9+
mod collection;
10+
mod pvector;
11+
12+
/// A binary comparison between `Self` and `Rhs`, selected by the operator type `Op`.
///
/// `Op` is a type-level tag naming the comparison to perform; it is never passed as a value.
pub trait Compare<Op, Rhs = Self> {
    /// The type produced by the comparison.
    type Output;

    /// Compare `self` against `rhs` using the operator `Op`.
    fn compare(self, rhs: Rhs) -> Self::Output;
}
20+
21+
/// A scalar comparison operator over items of type `T`.
///
/// Implemented by the operator marker types; the operator is chosen purely by type.
pub trait ComparisonOperator<T> {
    /// Evaluate the operator for the pair `(a, b)` and return whether it holds.
    fn apply(a: &T, b: &T) -> bool;
}
26+
27+
/// A marker type for equality comparison operations.
#[derive(Debug, Clone, Copy, Default)]
pub struct Equal;

/// A marker type for inequality comparison operations.
#[derive(Debug, Clone, Copy, Default)]
pub struct NotEqual;

/// A marker type for less-than comparison operations.
#[derive(Debug, Clone, Copy, Default)]
pub struct LessThan;

/// A marker type for less-than-or-equal comparison operations.
#[derive(Debug, Clone, Copy, Default)]
pub struct LessThanOrEqual;

/// A marker type for greater-than comparison operations.
#[derive(Debug, Clone, Copy, Default)]
pub struct GreaterThan;

/// A marker type for greater-than-or-equal comparison operations.
#[derive(Debug, Clone, Copy, Default)]
pub struct GreaterThanOrEqual;
39+
40+
/// Scalar items that can be tested for equality and ordered.
///
/// This is not a marker trait: every comparison operator in this module is derived from the
/// two required methods below. `GreaterThan` and `LessThanOrEqual` call `is_less_than` with
/// the operands swapped, while `GreaterThanOrEqual` and `LessThanOrEqual` negate an
/// `is_less_than` result. Those derivations only give the expected answers when
/// `is_less_than` is a total order that agrees with `is_equal`, so implementations must
/// provide one.
pub trait ComparableItem {
    /// Check if two items are equal.
    fn is_equal(lhs: &Self, rhs: &Self) -> bool;

    /// Check if the `lhs` item is less than the `rhs` item.
    fn is_less_than(lhs: &Self, rhs: &Self) -> bool;
}
48+
49+
/// `a == b`, answered directly by [`ComparableItem::is_equal`].
impl<T: ComparableItem> ComparisonOperator<T> for Equal {
    fn apply(a: &T, b: &T) -> bool {
        <T as ComparableItem>::is_equal(a, b)
    }
}
54+
55+
/// `a != b`, computed as the negation of [`ComparableItem::is_equal`].
impl<T: ComparableItem> ComparisonOperator<T> for NotEqual {
    fn apply(a: &T, b: &T) -> bool {
        let equal = <T as ComparableItem>::is_equal(a, b);
        !equal
    }
}
60+
61+
/// `a < b`, answered directly by [`ComparableItem::is_less_than`].
impl<T: ComparableItem> ComparisonOperator<T> for LessThan {
    fn apply(a: &T, b: &T) -> bool {
        <T as ComparableItem>::is_less_than(a, b)
    }
}
66+
67+
/// `a >= b`, computed as "not (`a < b`)" via [`ComparableItem::is_less_than`].
impl<T: ComparableItem> ComparisonOperator<T> for GreaterThanOrEqual {
    fn apply(a: &T, b: &T) -> bool {
        let strictly_less = <T as ComparableItem>::is_less_than(a, b);
        !strictly_less
    }
}
72+
73+
/// `a > b`, computed as `b < a` via [`ComparableItem::is_less_than`] with swapped operands.
impl<T: ComparableItem> ComparisonOperator<T> for GreaterThan {
    fn apply(a: &T, b: &T) -> bool {
        let (smaller, larger) = (b, a);
        <T as ComparableItem>::is_less_than(smaller, larger)
    }
}
78+
79+
/// `a <= b`, computed as "not (`b < a`)" via [`ComparableItem::is_less_than`].
impl<T: ComparableItem> ComparisonOperator<T> for LessThanOrEqual {
    fn apply(a: &T, b: &T) -> bool {
        let strictly_greater = <T as ComparableItem>::is_less_than(b, a);
        !strictly_greater
    }
}
84+
85+
macro_rules! impl_integer {
86+
($T:ty) => {
87+
impl ComparableItem for $T {
88+
#[inline(always)]
89+
fn is_equal(lhs: &Self, rhs: &Self) -> bool {
90+
lhs == rhs
91+
}
92+
93+
#[inline(always)]
94+
fn is_less_than(lhs: &Self, rhs: &Self) -> bool {
95+
lhs < rhs
96+
}
97+
}
98+
};
99+
}
100+
101+
impl_integer!(i8);
102+
impl_integer!(i16);
103+
impl_integer!(i32);
104+
impl_integer!(i64);
105+
impl_integer!(i128);
106+
impl_integer!(u8);
107+
impl_integer!(u16);
108+
impl_integer!(u32);
109+
impl_integer!(u64);
110+
impl_integer!(u128);
111+
112+
macro_rules! impl_float {
113+
($T:ty) => {
114+
impl ComparableItem for $T {
115+
#[inline(always)]
116+
fn is_equal(lhs: &Self, rhs: &Self) -> bool {
117+
lhs.to_bits().eq(&rhs.to_bits())
118+
}
119+
120+
#[inline(always)]
121+
fn is_less_than(lhs: &Self, rhs: &Self) -> bool {
122+
lhs.total_cmp(rhs).is_lt()
123+
}
124+
}
125+
};
126+
}
127+
128+
impl_float!(f16);
129+
impl_float!(f32);
130+
impl_float!(f64);

0 commit comments

Comments
 (0)