Skip to content

Commit 2c4d537

Browse files
committed
pull patch filter to alp_encode_components
1 parent 2c0af3b commit 2c4d537

File tree

3 files changed

+41
-101
lines changed

3 files changed

+41
-101
lines changed

encodings/alp/src/alp/compress.rs

Lines changed: 15 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
use vortex_array::array::PrimitiveArray;
22
use vortex_array::patches::Patches;
3-
use vortex_array::validity::Validity;
3+
use vortex_array::validity::{ArrayValidity as _, Validity};
44
use vortex_array::variants::PrimitiveArrayTrait;
55
use vortex_array::{ArrayDType, ArrayData, IntoArrayData, IntoArrayVariant};
6+
use vortex_buffer::Buffer;
67
use vortex_dtype::{NativePType, PType};
78
use vortex_error::{vortex_bail, VortexResult};
89
use vortex_scalar::ScalarType;
@@ -25,6 +26,7 @@ macro_rules! match_each_alp_float_ptype {
2526
})
2627
}
2728

29+
#[allow(clippy::cast_possible_truncation)]
2830
pub fn alp_encode_components<T>(
2931
values: &PrimitiveArray,
3032
exponents: Option<Exponents>,
@@ -34,14 +36,24 @@ where
3436
T::ALPInt: NativePType,
3537
T: ScalarType,
3638
{
37-
let (exponents, encoded, exc_pos, exc) =
38-
T::encode(values.as_slice::<T>(), &values.validity(), exponents)?;
39+
let (exponents, encoded, exc_pos, exc) = T::encode(values.as_slice::<T>(), exponents);
3940
let len = encoded.len();
4041
let patches_validity = if values.dtype().is_nullable() {
4142
Validity::AllValid
4243
} else {
4344
Validity::NonNullable
4445
};
46+
let exc_pos = exc_pos
47+
.into_iter()
48+
.filter(|index| values.is_valid(*index as usize))
49+
.collect::<Buffer<u64>>();
50+
let exc = exc
51+
.into_iter()
52+
.enumerate()
53+
.filter(|(index, _)| values.is_valid(*index))
54+
.map(|x| x.1)
55+
.collect::<Buffer<T>>();
56+
4557
Ok((
4658
exponents,
4759
PrimitiveArray::new(encoded, values.validity()).into_array(),
@@ -154,39 +166,6 @@ mod tests {
154166
assert_eq!(values.as_slice(), decoded.as_slice::<f64>());
155167
}
156168

157-
#[test]
158-
#[allow(clippy::approx_constant)] // ALP doesn't like E
159-
fn test_compress_ignores_invalid_exceptional_values() {
160-
let values = buffer![1.234f64, 2.718, f64::consts::PI, 4.0];
161-
let array = PrimitiveArray::new(values, Validity::from_iter([true, true, false, true]));
162-
let encoded = alp_encode(&array).unwrap();
163-
assert!(encoded.patches().is_none());
164-
assert_eq!(
165-
encoded
166-
.encoded()
167-
.into_primitive()
168-
.unwrap()
169-
.as_slice::<i64>(),
170-
vec![1234i64, 2718, 3142, 4000] // fill forward
171-
);
172-
assert_eq!(encoded.exponents(), Exponents { e: 16, f: 13 });
173-
174-
let decoded = decompress(encoded).unwrap();
175-
assert_eq!(
176-
scalar_at(&decoded, 0).unwrap(),
177-
scalar_at(&array, 0).unwrap()
178-
);
179-
assert_eq!(
180-
scalar_at(&decoded, 1).unwrap(),
181-
scalar_at(&array, 1).unwrap()
182-
);
183-
assert!(!decoded.is_valid(2));
184-
assert_eq!(
185-
scalar_at(&decoded, 3).unwrap(),
186-
scalar_at(&array, 3).unwrap()
187-
);
188-
}
189-
190169
#[test]
191170
#[allow(clippy::approx_constant)] // ALP doesn't like E
192171
fn test_nullable_patched_scalar_at() {

encodings/alp/src/alp/mod.rs

Lines changed: 23 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,7 @@ mod compute;
1111

1212
pub use array::*;
1313
pub use compress::*;
14-
use vortex_array::array::PrimitiveArray;
15-
use vortex_array::validity::Validity;
16-
use vortex_array::IntoArrayData as _;
1714
use vortex_buffer::{Buffer, BufferMut};
18-
use vortex_error::VortexResult;
1915

2016
const SAMPLE_SIZE: usize = 32;
2117

@@ -59,47 +55,24 @@ pub trait ALPFloat: private::Sealed + Float + Display + 'static {
5955
/// Convert from the integer type back to the float type using `as`.
6056
fn from_int(n: Self::ALPInt) -> Self;
6157

62-
fn sampled_find_best_exponents(
63-
values: &[Self],
64-
validity: &Validity,
65-
) -> VortexResult<Exponents> {
66-
if values.len() <= SAMPLE_SIZE {
67-
Self::find_best_exponents(values, validity)
68-
} else {
69-
let validity = validity.take(
70-
&PrimitiveArray::from_iter(
71-
(0..values.len())
72-
.step_by(values.len() / SAMPLE_SIZE)
73-
.map(|x| x as u64)
74-
.take(SAMPLE_SIZE),
75-
)
76-
.into_array(),
77-
)?;
78-
let values = values
79-
.iter()
80-
.step_by(values.len() / SAMPLE_SIZE)
81-
.take(SAMPLE_SIZE)
82-
.cloned()
83-
.collect_vec();
84-
Self::find_best_exponents(&values, &validity)
85-
}
86-
}
87-
88-
fn find_best_exponents(values: &[Self], validity: &Validity) -> VortexResult<Exponents> {
58+
fn find_best_exponents(values: &[Self]) -> Exponents {
8959
let mut best_exp = Exponents { e: 0, f: 0 };
9060
let mut best_nbytes: usize = usize::MAX;
9161

92-
assert!(
93-
values.len() <= SAMPLE_SIZE,
94-
"{} <= {}",
95-
values.len(),
96-
SAMPLE_SIZE
97-
);
62+
let sample = (values.len() > SAMPLE_SIZE).then(|| {
63+
values
64+
.iter()
65+
.step_by(values.len() / SAMPLE_SIZE)
66+
.cloned()
67+
.collect_vec()
68+
});
9869

9970
for e in (0..Self::MAX_EXPONENT).rev() {
10071
for f in 0..e {
101-
let (_, encoded, _, exc_patches) =
102-
Self::encode(values, validity, Some(Exponents { e, f }))?;
72+
let (_, encoded, _, exc_patches) = Self::encode(
73+
sample.as_deref().unwrap_or(values),
74+
Some(Exponents { e, f }),
75+
);
10376

10477
let size = Self::estimate_encoded_size(&encoded, &exc_patches);
10578
if size < best_nbytes {
@@ -111,7 +84,7 @@ pub trait ALPFloat: private::Sealed + Float + Display + 'static {
11184
}
11285
}
11386

114-
Ok(best_exp)
87+
best_exp
11588
}
11689

11790
#[inline]
@@ -139,16 +112,11 @@ pub trait ALPFloat: private::Sealed + Float + Display + 'static {
139112
encoded_bytes + patch_bytes
140113
}
141114

142-
#[allow(clippy::type_complexity)]
143115
fn encode(
144116
values: &[Self],
145-
validity: &Validity,
146117
exponents: Option<Exponents>,
147-
) -> VortexResult<(Exponents, Buffer<Self::ALPInt>, Buffer<u64>, Buffer<Self>)> {
148-
let exponents = match exponents {
149-
Some(exponents) => exponents,
150-
None => Self::sampled_find_best_exponents(values, validity)?,
151-
};
118+
) -> (Exponents, Buffer<Self::ALPInt>, Buffer<u64>, Buffer<Self>) {
119+
let exp = exponents.unwrap_or_else(|| Self::find_best_exponents(values));
152120

153121
let mut encoded_output = BufferMut::<Self::ALPInt>::with_capacity(values.len());
154122
let mut patch_indices = BufferMut::<u64>::with_capacity(values.len());
@@ -161,21 +129,20 @@ pub trait ALPFloat: private::Sealed + Float + Display + 'static {
161129
for chunk in values.chunks(encode_chunk_size) {
162130
encode_chunk_unchecked(
163131
chunk,
164-
exponents,
132+
exp,
165133
&mut encoded_output,
166134
&mut patch_indices,
167135
&mut patch_values,
168136
&mut fill_value,
169-
validity,
170137
);
171138
}
172139

173-
Ok((
174-
exponents,
140+
(
141+
exp,
175142
encoded_output.freeze(),
176143
patch_indices.freeze(),
177144
patch_values.freeze(),
178-
))
145+
)
179146
}
180147

181148
#[inline]
@@ -224,7 +191,6 @@ fn encode_chunk_unchecked<T: ALPFloat>(
224191
patch_indices: &mut BufferMut<u64>,
225192
patch_values: &mut BufferMut<T>,
226193
fill_value: &mut Option<T::ALPInt>,
227-
validity: &Validity,
228194
) {
229195
let num_prev_encoded = encoded_output.len();
230196
let num_prev_patches = patch_indices.len();
@@ -258,13 +224,12 @@ fn encode_chunk_unchecked<T: ALPFloat>(
258224
// write() is only safe to call more than once because the values are primitive (i.e., Drop is a no-op)
259225
patch_indices_mut[chunk_patch_index].write(i as u64);
260226
patch_values_mut[chunk_patch_index].write(chunk[i - num_prev_encoded]);
261-
let is_valid_and_an_exception =
262-
(decoded != chunk[i - num_prev_encoded]) && validity.is_valid(i);
263-
chunk_patch_index += is_valid_and_an_exception as usize;
227+
chunk_patch_index += (decoded != chunk[i - num_prev_encoded]) as usize;
264228
}
229+
assert_eq!(chunk_patch_index, chunk_patch_count);
265230
unsafe {
266-
patch_indices.set_len(num_prev_patches + chunk_patch_index);
267-
patch_values.set_len(num_prev_patches + chunk_patch_index);
231+
patch_indices.set_len(num_prev_patches + chunk_patch_count);
232+
patch_values.set_len(num_prev_patches + chunk_patch_count);
268233
}
269234
}
270235

vortex-array/src/array/primitive/patch.rs

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,9 @@ impl PrimitiveArray {
1515
let patch_indices = patch_indices.into_primitive()?;
1616
let patch_values = patch_values.into_primitive()?;
1717

18-
let patched_validity = match patch_values.validity() {
19-
Validity::NonNullable => self.validity(),
20-
patch_validity => {
21-
self.validity()
22-
.patch(self.len(), patch_indices.as_ref(), patch_validity)?
23-
}
24-
};
18+
let patched_validity =
19+
self.validity()
20+
.patch(self.len(), patch_indices.as_ref(), patch_values.validity())?;
2521

2622
match_each_integer_ptype!(patch_indices.ptype(), |$I| {
2723
match_each_native_ptype!(self.ptype(), |$T| {

0 commit comments

Comments
 (0)