Skip to content

Commit 2c0af3b

Browse files
committed
patches has same dtype as array
1 parent ef3a022 commit 2c0af3b

File tree

2 files changed

+23
-5
lines changed

2 files changed

+23
-5
lines changed

encodings/alp/src/alp/array.rs

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -48,16 +48,24 @@ impl ALPArray {
4848
let mut children = Vec::with_capacity(2);
4949
children.push(encoded);
5050
if let Some(patches) = &patches {
51-
if patches.dtype().is_nullable() {
52-
vortex_bail!(MismatchedTypes: "patches should be non-nullable", patches.dtype());
51+
if patches.dtype() != &dtype {
52+
vortex_bail!(MismatchedTypes: dtype, patches.dtype());
5353
}
54+
55+
if !matches!(
56+
patches.values().logical_validity(),
57+
LogicalValidity::AllValid(_)
58+
) {
59+
vortex_bail!("ALPArray: patches must not contain invalid entries");
60+
}
61+
5462
children.push(patches.indices().clone());
5563
children.push(patches.values().clone());
5664
}
5765

5866
let patches = patches
5967
.as_ref()
60-
.map(|p| p.to_metadata(length, &dtype.as_nonnullable()))
68+
.map(|p| p.to_metadata(length, &dtype))
6169
.transpose()?;
6270

6371
Self::try_from_parts(
@@ -96,7 +104,7 @@ impl ALPArray {
96104
.child(1, &p.indices_dtype(), p.len())
97105
.vortex_expect("ALPArray: patch indices"),
98106
self.as_ref()
99-
.child(2, &self.dtype().as_nonnullable(), p.len())
107+
.child(2, &self.dtype(), p.len())
100108
.vortex_expect("ALPArray: patch values"),
101109
)
102110
})

encodings/alp/src/alp/compress.rs

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
use vortex_array::array::PrimitiveArray;
22
use vortex_array::patches::Patches;
3+
use vortex_array::validity::Validity;
34
use vortex_array::variants::PrimitiveArrayTrait;
45
use vortex_array::{ArrayDType, ArrayData, IntoArrayData, IntoArrayVariant};
56
use vortex_dtype::{NativePType, PType};
@@ -36,12 +37,21 @@ where
3637
let (exponents, encoded, exc_pos, exc) =
3738
T::encode(values.as_slice::<T>(), &values.validity(), exponents)?;
3839
let len = encoded.len();
40+
let patches_validity = if values.dtype().is_nullable() {
41+
Validity::AllValid
42+
} else {
43+
Validity::NonNullable
44+
};
3945
Ok((
4046
exponents,
4147
PrimitiveArray::new(encoded, values.validity()).into_array(),
4248
(!exc.is_empty()).then(|| {
4349
let position_arr = exc_pos.into_array();
44-
Patches::new(len, position_arr, exc.into_array())
50+
Patches::new(
51+
len,
52+
position_arr,
53+
PrimitiveArray::new(exc, patches_validity).into_array(),
54+
)
4555
}),
4656
))
4757
}

0 commit comments

Comments
 (0)