Skip to content

Commit da89aa2

Browse files
committed
wip
Signed-off-by: Joe Isaacs <[email protected]>
1 parent ba0e97a commit da89aa2

File tree

23 files changed

+361
-34
lines changed

23 files changed

+361
-34
lines changed

vortex-compute/src/arithmetic/pvector.rs

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,6 @@ where
161161
mod tests {
162162
use vortex_buffer::buffer;
163163
use vortex_mask::Mask;
164-
use vortex_vector::VectorOps;
165164
use vortex_vector::primitive::PVector;
166165

167166
use crate::arithmetic::Arithmetic;
@@ -175,14 +174,16 @@ mod tests {
175174
let right = PVector::new(buffer![10u32, 20, 30, 40], Mask::new_true(4));
176175

177176
let result = Arithmetic::<WrappingAdd, _>::eval(left, &right);
178-
assert_eq!(result.elements(), &buffer![11u32, 22, 33, 44]);
177+
let expected = PVector::new_valid(buffer![11u32, 22, 33, 44]);
178+
assert_eq!(result, expected);
179179
}
180180

181181
#[test]
182182
fn test_add_scalar() {
183183
let vec = PVector::new(buffer![1u32, 2, 3, 4], Mask::new_true(4));
184184
let result = Arithmetic::<WrappingAdd, _>::eval(vec, &10);
185-
assert_eq!(result.elements(), &buffer![11u32, 12, 13, 14]);
185+
let expected = PVector::new_valid(buffer![11u32, 12, 13, 14]);
186+
assert_eq!(result, expected);
186187
}
187188

188189
#[test]
@@ -192,8 +193,8 @@ mod tests {
192193

193194
let result = Arithmetic::<WrappingAdd, _>::eval(left, &right);
194195
// Validity is AND'd, so if either side is null, result is null
195-
assert_eq!(result.validity(), &Mask::from_iter([true, false, true]));
196-
assert_eq!(result.elements(), &buffer![11u32, 22, 33]);
196+
let expected = PVector::new(buffer![11u32, 22, 33], Mask::from_iter([true, false, true]));
197+
assert_eq!(result, expected);
197198
}
198199

199200
#[test]
@@ -202,14 +203,16 @@ mod tests {
202203
let right = PVector::new(buffer![1u32, 2, 3, 4], Mask::new_true(4));
203204

204205
let result = Arithmetic::<WrappingSub, _>::eval(left, &right);
205-
assert_eq!(result.elements(), &buffer![9u32, 18, 27, 36]);
206+
let expected = PVector::new_valid(buffer![9u32, 18, 27, 36]);
207+
assert_eq!(result, expected);
206208
}
207209

208210
#[test]
209211
fn test_sub_scalar() {
210212
let vec = PVector::new(buffer![10u32, 20, 30, 40], Mask::new_true(4));
211213
let result = Arithmetic::<WrappingSub, _>::eval(vec, &5);
212-
assert_eq!(result.elements(), &buffer![5u32, 15, 25, 35]);
214+
let expected = PVector::new_valid(buffer![5u32, 15, 25, 35]);
215+
assert_eq!(result, expected);
213216
}
214217

215218
#[test]
@@ -218,22 +221,23 @@ mod tests {
218221
let right = PVector::new(buffer![10u32, 20, 30, 40], Mask::new_true(4));
219222

220223
let result = Arithmetic::<WrappingMul, _>::eval(left, &right);
221-
assert_eq!(result.elements(), &buffer![20u32, 60, 120, 200]);
224+
let expected = PVector::new_valid(buffer![20u32, 60, 120, 200]);
225+
assert_eq!(result, expected);
222226
}
223227

224228
#[test]
225229
fn test_mul_scalar() {
226230
let vec = PVector::new(buffer![1u32, 2, 3, 4], Mask::new_true(4));
227231
let result = Arithmetic::<WrappingMul, _>::eval(vec, &10);
228-
assert_eq!(result.elements(), &buffer![10u32, 20, 30, 40]);
232+
let expected = PVector::new_valid(buffer![10u32, 20, 30, 40]);
233+
assert_eq!(result, expected);
229234
}
230235

231236
#[test]
232237
fn test_scalar_preserves_validity() {
233238
let vec = PVector::new(buffer![1u32, 2, 3], Mask::from_iter([true, false, true]));
234239
let result = Arithmetic::<WrappingAdd, _>::eval(vec, &10);
235-
236-
assert_eq!(result.validity(), &Mask::from_iter([true, false, true]));
237-
assert_eq!(result.elements(), &buffer![11u32, 12, 13]);
240+
let expected = PVector::new(buffer![11u32, 12, 13], Mask::from_iter([true, false, true]));
241+
assert_eq!(result, expected);
238242
}
239243
}

vortex-compute/src/comparison/pvector.rs

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ mod tests {
3838
use vortex_buffer::bitbuffer;
3939
use vortex_buffer::buffer;
4040
use vortex_mask::Mask;
41+
use vortex_vector::bool::BoolVector;
4142

4243
use super::*;
4344
use crate::comparison::Equal;
@@ -53,7 +54,8 @@ mod tests {
5354
let right = PVector::new(buffer![1u32, 2, 5, 4], Mask::new_true(4));
5455

5556
let result = Compare::<Equal>::compare(&left, &right);
56-
assert_eq!(result.bits(), &bitbuffer![1 1 0 1]);
57+
let expected = BoolVector::new(bitbuffer![1 1 0 1], Mask::new_true(4));
58+
assert_eq!(result, expected);
5759
}
5860

5961
#[test]
@@ -62,7 +64,8 @@ mod tests {
6264
let right = PVector::new(buffer![1u32, 2, 5, 4], Mask::new_true(4));
6365

6466
let result = Compare::<NotEqual>::compare(&left, &right);
65-
assert_eq!(result.bits(), &bitbuffer![0 0 1 0]);
67+
let expected = BoolVector::new(bitbuffer![0 0 1 0], Mask::new_true(4));
68+
assert_eq!(result, expected);
6669
}
6770

6871
#[test]
@@ -71,7 +74,8 @@ mod tests {
7174
let right = PVector::new(buffer![2u32, 2, 1, 5], Mask::new_true(4));
7275

7376
let result = Compare::<LessThan>::compare(&left, &right);
74-
assert_eq!(result.bits(), &bitbuffer![1 0 0 1]);
77+
let expected = BoolVector::new(bitbuffer![1 0 0 1], Mask::new_true(4));
78+
assert_eq!(result, expected);
7579
}
7680

7781
#[test]
@@ -80,7 +84,8 @@ mod tests {
8084
let right = PVector::new(buffer![2u32, 2, 1, 5], Mask::new_true(4));
8185

8286
let result = Compare::<LessThanOrEqual>::compare(&left, &right);
83-
assert_eq!(result.bits(), &bitbuffer![1 1 0 1]);
87+
let expected = BoolVector::new(bitbuffer![1 1 0 1], Mask::new_true(4));
88+
assert_eq!(result, expected);
8489
}
8590

8691
#[test]
@@ -89,7 +94,8 @@ mod tests {
8994
let right = PVector::new(buffer![1u32, 2, 3, 4], Mask::new_true(4));
9095

9196
let result = Compare::<GreaterThan>::compare(&left, &right);
92-
assert_eq!(result.bits(), &bitbuffer![1 0 0 1]);
97+
let expected = BoolVector::new(bitbuffer![1 0 0 1], Mask::new_true(4));
98+
assert_eq!(result, expected);
9399
}
94100

95101
#[test]
@@ -98,7 +104,8 @@ mod tests {
98104
let right = PVector::new(buffer![1u32, 2, 3, 4], Mask::new_true(4));
99105

100106
let result = Compare::<GreaterThanOrEqual>::compare(&left, &right);
101-
assert_eq!(result.bits(), &bitbuffer![1 1 0 1]);
107+
let expected = BoolVector::new(bitbuffer![1 1 0 1], Mask::new_true(4));
108+
assert_eq!(result, expected);
102109
}
103110

104111
#[test]
@@ -108,7 +115,7 @@ mod tests {
108115

109116
let result = Compare::<Equal>::compare(&left, &right);
110117
// Validity is AND'd, so if either side is null, result validity is null
111-
assert_eq!(result.validity(), &Mask::from_iter([true, false, true]));
112-
assert_eq!(result.bits(), &bitbuffer![1 1 1]);
118+
let expected = BoolVector::new(bitbuffer![1 1 1], Mask::from_iter([true, false, true]));
119+
assert_eq!(result, expected);
113120
}
114121
}

vortex-vector/src/binaryview/scalar.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ use crate::binaryview::BinaryViewVectorMut;
1111
use crate::binaryview::StringType;
1212

1313
/// A scalar value for types that implement [`BinaryViewType`].
14-
#[derive(Clone, Debug)]
14+
#[derive(Clone, Debug, PartialEq)]
1515
pub struct BinaryViewScalar<T: BinaryViewType>(Option<T::Scalar>);
1616

1717
impl<T: BinaryViewType> BinaryViewScalar<T> {

vortex-vector/src/binaryview/types.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ pub trait BinaryViewType: Debug + Sized + Send + Sync + 'static + private::Seale
6363
}
6464

6565
/// [`BinaryType`] for UTF-8 strings.
66-
#[derive(Clone, Debug)]
66+
#[derive(Clone, Debug, PartialEq, Eq)]
6767
pub struct StringType;
6868
impl BinaryViewType for StringType {
6969
type Slice = str;
@@ -96,7 +96,7 @@ impl BinaryViewType for StringType {
9696
}
9797

9898
/// [`BinaryType`] for raw binary data.
99-
#[derive(Clone, Debug)]
99+
#[derive(Clone, Debug, PartialEq, Eq)]
100100
pub struct BinaryType;
101101
impl BinaryViewType for BinaryType {
102102
type Slice = [u8];

vortex-vector/src/binaryview/vector.rs

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,49 @@ pub struct BinaryViewVector<T: BinaryViewType> {
3838
_marker: std::marker::PhantomData<T>,
3939
}
4040

41+
impl<T: BinaryViewType> PartialEq for BinaryViewVector<T> {
42+
fn eq(&self, other: &Self) -> bool {
43+
if self.views.len() != other.views.len() {
44+
return false;
45+
}
46+
// Validity patterns must match
47+
if self.validity != other.validity {
48+
return false;
49+
}
50+
// Compare all views, OR with !validity to ignore invalid positions
51+
self.views
52+
.iter()
53+
.zip(other.views.iter())
54+
.enumerate()
55+
.all(|(i, (self_view, other_view))| {
56+
// If invalid, treat as equal
57+
if !self.validity.value(i) {
58+
return true;
59+
}
60+
// For valid elements, compare the actual byte content via the view
61+
let self_bytes: &[u8] = if self_view.is_inlined() {
62+
self_view.as_inlined().value()
63+
} else {
64+
let view_ref = self_view.as_view();
65+
let buffer = &self.buffers[view_ref.buffer_index as usize];
66+
&buffer[view_ref.as_range()]
67+
};
68+
69+
let other_bytes: &[u8] = if other_view.is_inlined() {
70+
other_view.as_inlined().value()
71+
} else {
72+
let view_ref = other_view.as_view();
73+
let buffer = &other.buffers[view_ref.buffer_index as usize];
74+
&buffer[view_ref.as_range()]
75+
};
76+
77+
self_bytes == other_bytes
78+
})
79+
}
80+
}
81+
82+
impl<T: BinaryViewType> Eq for BinaryViewVector<T> {}
83+
4184
impl<T: BinaryViewType> BinaryViewVector<T> {
4285
/// Creates a new [`BinaryViewVector`] from the provided components.
4386
///

vortex-vector/src/bool/scalar.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ use crate::VectorMutOps;
88
use crate::bool::BoolVectorMut;
99

1010
/// A scalar value for boolean types.
11-
#[derive(Clone, Debug)]
11+
#[derive(Clone, Debug, PartialEq, Eq)]
1212
pub struct BoolScalar(Option<bool>);
1313

1414
impl BoolScalar {

vortex-vector/src/bool/vector.rs

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,44 @@ use crate::bool::BoolVectorMut;
2020
/// An immutable vector of boolean values.
2121
///
2222
/// Internally, this `BoolVector` is a wrapper around a [`BitBuffer`] and a validity mask.
23-
#[derive(Debug, Clone)]
23+
#[derive(Debug, Clone, Eq)]
2424
pub struct BoolVector {
2525
/// The bits that we use to represent booleans.
2626
pub(super) bits: BitBuffer,
2727
/// The validity mask (where `true` represents an element is **not** null).
2828
pub(super) validity: Mask,
2929
}
3030

31+
impl PartialEq for BoolVector {
32+
fn eq(&self, other: &Self) -> bool {
33+
if self.len() != other.len() {
34+
return false;
35+
}
36+
// Validity patterns must match
37+
if self.validity != other.validity {
38+
return false;
39+
}
40+
// Use XNOR comparison: bits are equal where !(lhs ^ rhs) is true
41+
let lhs_chunks = self.bits.chunks();
42+
let rhs_chunks = other.bits.chunks();
43+
let validity_bits = self.validity.to_bit_buffer();
44+
let validity_chunks = validity_bits.chunks();
45+
46+
// For equality: check that !(lhs ^ rhs) & validity == validity at each chunk
47+
for ((lhs, rhs), valid) in lhs_chunks
48+
.iter_padded()
49+
.zip(rhs_chunks.iter_padded())
50+
.zip(validity_chunks.iter_padded())
51+
{
52+
let equal_bits = !(lhs ^ rhs); // XNOR: true where bits are equal
53+
if (equal_bits & valid) != valid {
54+
return false;
55+
}
56+
}
57+
true
58+
}
59+
}
60+
3161
impl BoolVector {
3262
/// Creates a new [`BoolVector`] from the given bits and validity mask.
3363
///

vortex-vector/src/decimal/generic.rs

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,30 @@ pub struct DVector<D> {
3636
pub(super) validity: Mask,
3737
}
3838

39+
impl<D: NativeDecimalType + PartialEq> PartialEq for DVector<D> {
40+
fn eq(&self, other: &Self) -> bool {
41+
if self.elements.len() != other.elements.len() {
42+
return false;
43+
}
44+
// Precision and scale must match
45+
if self.ps != other.ps {
46+
return false;
47+
}
48+
// Validity patterns must match
49+
if self.validity != other.validity {
50+
return false;
51+
}
52+
// Compare all elements, OR with !validity to ignore invalid positions
53+
self.elements
54+
.iter()
55+
.zip(other.elements.iter())
56+
.enumerate()
57+
.all(|(i, (a, b))| !self.validity.value(i) || a == b)
58+
}
59+
}
60+
61+
impl<D: NativeDecimalType + Eq> Eq for DVector<D> {}
62+
3963
impl<D: NativeDecimalType> DVector<D> {
4064
/// Creates a new [`DVector<D>`] from the given [`PrecisionScale`], elements buffer, and
4165
/// validity mask.

vortex-vector/src/decimal/scalar.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ use crate::VectorMutOps;
1414
use crate::decimal::DVectorMut;
1515

1616
/// Represents a decimal scalar value.
17-
#[derive(Clone, Debug)]
17+
#[derive(Clone, Debug, PartialEq, Eq)]
1818
pub enum DecimalScalar {
1919
/// 8-bit decimal scalar.
2020
D8(DScalar<i8>),

vortex-vector/src/decimal/vector.rs

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ use crate::decimal::DecimalVectorMut;
2121
use crate::match_each_dvector;
2222

2323
/// An enum over all supported decimal mutable vector types.
24-
#[derive(Clone, Debug)]
24+
#[derive(Clone, Debug, Eq)]
2525
pub enum DecimalVector {
2626
/// A decimal vector with 8-bit integer representation.
2727
D8(DVector<i8>),
@@ -37,6 +37,20 @@ pub enum DecimalVector {
3737
D256(DVector<i256>),
3838
}
3939

40+
impl PartialEq for DecimalVector {
41+
fn eq(&self, other: &Self) -> bool {
42+
match (self, other) {
43+
(DecimalVector::D8(a), DecimalVector::D8(b)) => a == b,
44+
(DecimalVector::D16(a), DecimalVector::D16(b)) => a == b,
45+
(DecimalVector::D32(a), DecimalVector::D32(b)) => a == b,
46+
(DecimalVector::D64(a), DecimalVector::D64(b)) => a == b,
47+
(DecimalVector::D128(a), DecimalVector::D128(b)) => a == b,
48+
(DecimalVector::D256(a), DecimalVector::D256(b)) => a == b,
49+
_ => false,
50+
}
51+
}
52+
}
53+
4054
impl DecimalVector {
4155
/// Returns the precision of the decimal vector.
4256
pub fn precision(&self) -> u8 {

0 commit comments

Comments
 (0)