Skip to content

Commit 4e28866

Browse files
feat[vector]: Eq (#5681)
Signed-off-by: Joe Isaacs <[email protected]>
1 parent 96dce46 commit 4e28866

File tree

23 files changed

+1016
-34
lines changed

23 files changed

+1016
-34
lines changed

vortex-compute/src/arithmetic/pvector.rs

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,6 @@ where
161161
mod tests {
162162
use vortex_buffer::buffer;
163163
use vortex_mask::Mask;
164-
use vortex_vector::VectorOps;
165164
use vortex_vector::primitive::PVector;
166165

167166
use crate::arithmetic::Arithmetic;
@@ -175,14 +174,16 @@ mod tests {
175174
let right = PVector::new(buffer![10u32, 20, 30, 40], Mask::new_true(4));
176175

177176
let result = Arithmetic::<WrappingAdd, _>::eval(left, &right);
178-
assert_eq!(result.elements(), &buffer![11u32, 22, 33, 44]);
177+
let expected = PVector::from(buffer![11u32, 22, 33, 44]);
178+
assert_eq!(result, expected);
179179
}
180180

181181
#[test]
182182
fn test_add_scalar() {
183183
let vec = PVector::new(buffer![1u32, 2, 3, 4], Mask::new_true(4));
184184
let result = Arithmetic::<WrappingAdd, _>::eval(vec, &10);
185-
assert_eq!(result.elements(), &buffer![11u32, 12, 13, 14]);
185+
let expected = PVector::from(buffer![11u32, 12, 13, 14]);
186+
assert_eq!(result, expected);
186187
}
187188

188189
#[test]
@@ -192,8 +193,8 @@ mod tests {
192193

193194
let result = Arithmetic::<WrappingAdd, _>::eval(left, &right);
194195
// Validity is AND'd, so if either side is null, result is null
195-
assert_eq!(result.validity(), &Mask::from_iter([true, false, true]));
196-
assert_eq!(result.elements(), &buffer![11u32, 22, 33]);
196+
let expected = PVector::new(buffer![11u32, 22, 33], Mask::from_iter([true, false, true]));
197+
assert_eq!(result, expected);
197198
}
198199

199200
#[test]
@@ -202,14 +203,16 @@ mod tests {
202203
let right = PVector::new(buffer![1u32, 2, 3, 4], Mask::new_true(4));
203204

204205
let result = Arithmetic::<WrappingSub, _>::eval(left, &right);
205-
assert_eq!(result.elements(), &buffer![9u32, 18, 27, 36]);
206+
let expected = PVector::from(buffer![9u32, 18, 27, 36]);
207+
assert_eq!(result, expected);
206208
}
207209

208210
#[test]
209211
fn test_sub_scalar() {
210212
let vec = PVector::new(buffer![10u32, 20, 30, 40], Mask::new_true(4));
211213
let result = Arithmetic::<WrappingSub, _>::eval(vec, &5);
212-
assert_eq!(result.elements(), &buffer![5u32, 15, 25, 35]);
214+
let expected = PVector::from(buffer![5u32, 15, 25, 35]);
215+
assert_eq!(result, expected);
213216
}
214217

215218
#[test]
@@ -218,22 +221,23 @@ mod tests {
218221
let right = PVector::new(buffer![10u32, 20, 30, 40], Mask::new_true(4));
219222

220223
let result = Arithmetic::<WrappingMul, _>::eval(left, &right);
221-
assert_eq!(result.elements(), &buffer![20u32, 60, 120, 200]);
224+
let expected = PVector::from(buffer![20u32, 60, 120, 200]);
225+
assert_eq!(result, expected);
222226
}
223227

224228
#[test]
225229
fn test_mul_scalar() {
226230
let vec = PVector::new(buffer![1u32, 2, 3, 4], Mask::new_true(4));
227231
let result = Arithmetic::<WrappingMul, _>::eval(vec, &10);
228-
assert_eq!(result.elements(), &buffer![10u32, 20, 30, 40]);
232+
let expected = PVector::from(buffer![10u32, 20, 30, 40]);
233+
assert_eq!(result, expected);
229234
}
230235

231236
#[test]
232237
fn test_scalar_preserves_validity() {
233238
let vec = PVector::new(buffer![1u32, 2, 3], Mask::from_iter([true, false, true]));
234239
let result = Arithmetic::<WrappingAdd, _>::eval(vec, &10);
235-
236-
assert_eq!(result.validity(), &Mask::from_iter([true, false, true]));
237-
assert_eq!(result.elements(), &buffer![11u32, 12, 13]);
240+
let expected = PVector::new(buffer![11u32, 12, 13], Mask::from_iter([true, false, true]));
241+
assert_eq!(result, expected);
238242
}
239243
}

vortex-compute/src/comparison/pvector.rs

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ mod tests {
3838
use vortex_buffer::bitbuffer;
3939
use vortex_buffer::buffer;
4040
use vortex_mask::Mask;
41+
use vortex_vector::bool::BoolVector;
4142

4243
use super::*;
4344
use crate::comparison::Equal;
@@ -53,7 +54,8 @@ mod tests {
5354
let right = PVector::new(buffer![1u32, 2, 5, 4], Mask::new_true(4));
5455

5556
let result = Compare::<Equal>::compare(&left, &right);
56-
assert_eq!(result.bits(), &bitbuffer![1 1 0 1]);
57+
let expected = BoolVector::new(bitbuffer![1 1 0 1], Mask::new_true(4));
58+
assert_eq!(result, expected);
5759
}
5860

5961
#[test]
@@ -62,7 +64,8 @@ mod tests {
6264
let right = PVector::new(buffer![1u32, 2, 5, 4], Mask::new_true(4));
6365

6466
let result = Compare::<NotEqual>::compare(&left, &right);
65-
assert_eq!(result.bits(), &bitbuffer![0 0 1 0]);
67+
let expected = BoolVector::new(bitbuffer![0 0 1 0], Mask::new_true(4));
68+
assert_eq!(result, expected);
6669
}
6770

6871
#[test]
@@ -71,7 +74,8 @@ mod tests {
7174
let right = PVector::new(buffer![2u32, 2, 1, 5], Mask::new_true(4));
7275

7376
let result = Compare::<LessThan>::compare(&left, &right);
74-
assert_eq!(result.bits(), &bitbuffer![1 0 0 1]);
77+
let expected = BoolVector::new(bitbuffer![1 0 0 1], Mask::new_true(4));
78+
assert_eq!(result, expected);
7579
}
7680

7781
#[test]
@@ -80,7 +84,8 @@ mod tests {
8084
let right = PVector::new(buffer![2u32, 2, 1, 5], Mask::new_true(4));
8185

8286
let result = Compare::<LessThanOrEqual>::compare(&left, &right);
83-
assert_eq!(result.bits(), &bitbuffer![1 1 0 1]);
87+
let expected = BoolVector::new(bitbuffer![1 1 0 1], Mask::new_true(4));
88+
assert_eq!(result, expected);
8489
}
8590

8691
#[test]
@@ -89,7 +94,8 @@ mod tests {
8994
let right = PVector::new(buffer![1u32, 2, 3, 4], Mask::new_true(4));
9095

9196
let result = Compare::<GreaterThan>::compare(&left, &right);
92-
assert_eq!(result.bits(), &bitbuffer![1 0 0 1]);
97+
let expected = BoolVector::new(bitbuffer![1 0 0 1], Mask::new_true(4));
98+
assert_eq!(result, expected);
9399
}
94100

95101
#[test]
@@ -98,7 +104,8 @@ mod tests {
98104
let right = PVector::new(buffer![1u32, 2, 3, 4], Mask::new_true(4));
99105

100106
let result = Compare::<GreaterThanOrEqual>::compare(&left, &right);
101-
assert_eq!(result.bits(), &bitbuffer![1 1 0 1]);
107+
let expected = BoolVector::new(bitbuffer![1 1 0 1], Mask::new_true(4));
108+
assert_eq!(result, expected);
102109
}
103110

104111
#[test]
@@ -108,7 +115,7 @@ mod tests {
108115

109116
let result = Compare::<Equal>::compare(&left, &right);
110117
// Validity is AND'd, so if either side is null, result validity is null
111-
assert_eq!(result.validity(), &Mask::from_iter([true, false, true]));
112-
assert_eq!(result.bits(), &bitbuffer![1 1 1]);
118+
let expected = BoolVector::new(bitbuffer![1 1 1], Mask::from_iter([true, false, true]));
119+
assert_eq!(result, expected);
113120
}
114121
}

vortex-vector/src/binaryview/scalar.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ use crate::binaryview::BinaryViewVectorMut;
1212
use crate::binaryview::StringType;
1313

1414
/// A scalar value for types that implement [`BinaryViewType`].
15-
#[derive(Clone, Debug)]
15+
#[derive(Clone, Debug, PartialEq)]
1616
pub struct BinaryViewScalar<T: BinaryViewType>(Option<T::Scalar>);
1717

1818
impl<T: BinaryViewType> BinaryViewScalar<T> {

vortex-vector/src/binaryview/types.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ pub trait BinaryViewType: Debug + Sized + Send + Sync + 'static + private::Seale
7070
}
7171

7272
/// [`BinaryViewType`] for UTF-8 strings.
73-
#[derive(Clone, Debug)]
73+
#[derive(Clone, Debug, PartialEq, Eq)]
7474
pub struct StringType;
7575
impl BinaryViewType for StringType {
7676
type Slice = str;
@@ -111,7 +111,7 @@ impl BinaryViewType for StringType {
111111
}
112112

113113
/// [`BinaryViewType`] for raw binary data.
114-
#[derive(Clone, Debug)]
114+
#[derive(Clone, Debug, PartialEq, Eq)]
115115
pub struct BinaryType;
116116
impl BinaryViewType for BinaryType {
117117
type Slice = [u8];

vortex-vector/src/binaryview/vector.rs

Lines changed: 189 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,49 @@ pub struct BinaryViewVector<T: BinaryViewType> {
3838
_marker: std::marker::PhantomData<T>,
3939
}
4040

41+
impl<T: BinaryViewType> PartialEq for BinaryViewVector<T> {
42+
fn eq(&self, other: &Self) -> bool {
43+
if self.views.len() != other.views.len() {
44+
return false;
45+
}
46+
// Validity patterns must match
47+
if self.validity != other.validity {
48+
return false;
49+
}
50+
// Compare all views, OR with !validity to ignore invalid positions
51+
self.views
52+
.iter()
53+
.zip(other.views.iter())
54+
.enumerate()
55+
.all(|(i, (self_view, other_view))| {
56+
// If invalid, treat as equal
57+
if !self.validity.value(i) {
58+
return true;
59+
}
60+
// For valid elements, compare the actual byte content via the view
61+
let self_bytes: &[u8] = if self_view.is_inlined() {
62+
self_view.as_inlined().value()
63+
} else {
64+
let view_ref = self_view.as_view();
65+
let buffer = &self.buffers[view_ref.buffer_index as usize];
66+
&buffer[view_ref.as_range()]
67+
};
68+
69+
let other_bytes: &[u8] = if other_view.is_inlined() {
70+
other_view.as_inlined().value()
71+
} else {
72+
let view_ref = other_view.as_view();
73+
let buffer = &other.buffers[view_ref.buffer_index as usize];
74+
&buffer[view_ref.as_range()]
75+
};
76+
77+
self_bytes == other_bytes
78+
})
79+
}
80+
}
81+
82+
impl<T: BinaryViewType> Eq for BinaryViewVector<T> {}
83+
4184
impl<T: BinaryViewType> BinaryViewVector<T> {
4285
/// Creates a new [`BinaryViewVector`] from the provided components.
4386
///
@@ -388,4 +431,150 @@ mod tests {
388431

389432
assert!(shared_vec.try_into_mut().is_ok());
390433
}
434+
435+
#[test]
436+
fn test_binaryview_eq_identical_inlined() {
437+
// Test equality with inlined strings (<=12 bytes).
438+
let mut v1 = StringVectorMut::with_capacity(3);
439+
v1.append_values("hello", 1);
440+
v1.append_values("world", 1);
441+
v1.append_values("test", 1);
442+
let v1 = v1.freeze();
443+
444+
let mut v2 = StringVectorMut::with_capacity(3);
445+
v2.append_values("hello", 1);
446+
v2.append_values("world", 1);
447+
v2.append_values("test", 1);
448+
let v2 = v2.freeze();
449+
450+
assert_eq!(v1, v2);
451+
}
452+
453+
#[test]
454+
fn test_binaryview_eq_identical_outlined() {
455+
// Test equality with outlined strings (>12 bytes).
456+
let mut v1 = StringVectorMut::with_capacity(2);
457+
v1.append_values("this is a longer string that won't be inlined", 1);
458+
v1.append_values("another long string for testing purposes", 1);
459+
let v1 = v1.freeze();
460+
461+
let mut v2 = StringVectorMut::with_capacity(2);
462+
v2.append_values("this is a longer string that won't be inlined", 1);
463+
v2.append_values("another long string for testing purposes", 1);
464+
let v2 = v2.freeze();
465+
466+
assert_eq!(v1, v2);
467+
}
468+
469+
#[test]
470+
fn test_binaryview_eq_different_length() {
471+
let mut v1 = StringVectorMut::with_capacity(3);
472+
v1.append_values("a", 1);
473+
v1.append_values("b", 1);
474+
v1.append_values("c", 1);
475+
let v1 = v1.freeze();
476+
477+
let mut v2 = StringVectorMut::with_capacity(2);
478+
v2.append_values("a", 1);
479+
v2.append_values("b", 1);
480+
let v2 = v2.freeze();
481+
482+
assert_ne!(v1, v2);
483+
}
484+
485+
#[test]
486+
fn test_binaryview_eq_different_validity() {
487+
let mut v1 = StringVectorMut::with_capacity(3);
488+
v1.append_values("a", 1);
489+
v1.append_values("b", 1);
490+
v1.append_values("c", 1);
491+
let v1 = v1.freeze();
492+
493+
let mut v2 = StringVectorMut::with_capacity(3);
494+
v2.append_values("a", 1);
495+
v2.append_nulls(1);
496+
v2.append_values("c", 1);
497+
let v2 = v2.freeze();
498+
499+
assert_ne!(v1, v2);
500+
}
501+
502+
#[test]
503+
fn test_binaryview_eq_different_values() {
504+
let mut v1 = StringVectorMut::with_capacity(3);
505+
v1.append_values("hello", 1);
506+
v1.append_values("world", 1);
507+
v1.append_values("test", 1);
508+
let v1 = v1.freeze();
509+
510+
let mut v2 = StringVectorMut::with_capacity(3);
511+
v2.append_values("hello", 1);
512+
v2.append_values("DIFFERENT", 1);
513+
v2.append_values("test", 1);
514+
let v2 = v2.freeze();
515+
516+
assert_ne!(v1, v2);
517+
}
518+
519+
#[test]
520+
fn test_binaryview_eq_ignores_invalid_positions_inlined() {
521+
// Two vectors with different values at invalid positions should be equal.
522+
let mut v1 = StringVectorMut::with_capacity(3);
523+
v1.append_values("hello", 1);
524+
v1.append_values("value_a", 1); // This will be masked as invalid
525+
v1.append_values("test", 1);
526+
let mut v1 = v1.freeze();
527+
// Mask position 1 as invalid
528+
v1.mask_validity(&Mask::from_iter([true, false, true]));
529+
530+
let mut v2 = StringVectorMut::with_capacity(3);
531+
v2.append_values("hello", 1);
532+
v2.append_values("value_b", 1); // Different value at invalid position
533+
v2.append_values("test", 1);
534+
let mut v2 = v2.freeze();
535+
v2.mask_validity(&Mask::from_iter([true, false, true]));
536+
537+
assert_eq!(v1, v2);
538+
}
539+
540+
#[test]
541+
fn test_binaryview_eq_ignores_invalid_positions_outlined() {
542+
// Test with outlined strings at invalid positions.
543+
let mut v1 = StringVectorMut::with_capacity(3);
544+
v1.append_values("this is a very long string that will be outlined", 1);
545+
v1.append_values("another long value that differs between vectors A", 1);
546+
v1.append_values("yet another long string for the test", 1);
547+
let mut v1 = v1.freeze();
548+
v1.mask_validity(&Mask::from_iter([true, false, true]));
549+
550+
let mut v2 = StringVectorMut::with_capacity(3);
551+
v2.append_values("this is a very long string that will be outlined", 1);
552+
v2.append_values("different long value at the invalid position B", 1);
553+
v2.append_values("yet another long string for the test", 1);
554+
let mut v2 = v2.freeze();
555+
v2.mask_validity(&Mask::from_iter([true, false, true]));
556+
557+
assert_eq!(v1, v2);
558+
}
559+
560+
#[test]
561+
fn test_binaryview_eq_empty() {
562+
let v1 = StringVectorMut::with_capacity(0).freeze();
563+
let v2 = StringVectorMut::with_capacity(0).freeze();
564+
565+
assert_eq!(v1, v2);
566+
}
567+
568+
#[test]
569+
fn test_binaryview_eq_all_nulls() {
570+
let mut v1 = StringVectorMut::with_capacity(3);
571+
v1.append_nulls(3);
572+
let v1 = v1.freeze();
573+
574+
let mut v2 = StringVectorMut::with_capacity(3);
575+
v2.append_nulls(3);
576+
let v2 = v2.freeze();
577+
578+
assert_eq!(v1, v2);
579+
}
391580
}

0 commit comments

Comments
 (0)