Skip to content

Commit a29972f

Browse files
authored
Nullability of take (#2341)
1 parent a8cfcb7 commit a29972f

File tree

43 files changed

+695
-568
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

43 files changed

+695
-568
lines changed

docs/rust/quickstart.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ Use :func:`~vortex.compress` to compress the Vortex array and check the relative
4545

4646
>>> cvtx = vx.compress(vtx)
4747
>>> cvtx.nbytes
48-
15385
48+
15061
4949
>>> cvtx.nbytes / vtx.nbytes
5050
0.10...
5151

encodings/alp/src/alp/array.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,10 @@ impl ValidityVTable<ALPArray> for ALPEncoding {
140140
array.encoded().all_valid()
141141
}
142142

143+
fn all_invalid(&self, array: &ALPArray) -> VortexResult<bool> {
144+
array.encoded().all_invalid()
145+
}
146+
143147
fn validity_mask(&self, array: &ALPArray) -> VortexResult<Mask> {
144148
array.encoded().validity_mask()
145149
}

encodings/alp/src/alp_rd/array.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,11 @@ impl ValidityVTable<ALPRDArray> for ALPRDEncoding {
240240
array.left_parts().all_valid()
241241
}
242242

243+
fn all_invalid(&self, array: &ALPRDArray) -> VortexResult<bool> {
244+
// Use validity from left_parts
245+
array.left_parts().all_invalid()
246+
}
247+
243248
fn validity_mask(&self, array: &ALPRDArray) -> VortexResult<Mask> {
244249
// Use validity from left_parts
245250
array.left_parts().validity_mask()

encodings/alp/src/alp_rd/compute/take.rs

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,17 +21,15 @@ impl TakeFn<ALPRDArray> for ALPRDEncoding {
2121
p.cast_values(&values_dtype)
2222
})
2323
.transpose()?;
24-
2524
let right_parts = fill_null(
2625
take(array.right_parts(), indices)?,
2726
Scalar::new(array.right_parts().dtype().clone(), ScalarValue::from(0)),
2827
)?;
28+
2929
Ok(ALPRDArray::try_new(
30-
if taken_left_parts.dtype().is_nullable() {
31-
array.dtype().as_nullable()
32-
} else {
33-
array.dtype().clone()
34-
},
30+
array
31+
.dtype()
32+
.with_nullability(taken_left_parts.dtype().nullability()),
3533
taken_left_parts,
3634
array.left_parts_dict(),
3735
right_parts,

encodings/bytebool/src/array.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,10 @@ impl ValidityVTable<ByteBoolArray> for ByteBoolEncoding {
122122
array.validity().all_valid()
123123
}
124124

125+
fn all_invalid(&self, array: &ByteBoolArray) -> VortexResult<bool> {
126+
array.validity().all_invalid()
127+
}
128+
125129
fn validity_mask(&self, array: &ByteBoolArray) -> VortexResult<Mask> {
126130
array.validity().to_logical(array.len())
127131
}

encodings/datetime-parts/src/array.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,10 @@ impl ValidityVTable<DateTimePartsArray> for DateTimePartsEncoding {
144144
array.days().all_valid()
145145
}
146146

147+
fn all_invalid(&self, array: &DateTimePartsArray) -> VortexResult<bool> {
148+
array.days().all_invalid()
149+
}
150+
147151
fn validity_mask(&self, array: &DateTimePartsArray) -> VortexResult<Mask> {
148152
array.days().validity_mask()
149153
}

encodings/dict/benches/dict_compare.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ use vortex_array::array::{ConstantArray, PrimitiveArray, VarBinArray, VarBinView
1111
use vortex_array::compute::{compare, Operator};
1212
use vortex_array::validity::Validity;
1313
use vortex_buffer::Buffer;
14-
use vortex_dict::dict_encode;
14+
use vortex_dict::builders::dict_encode;
1515

1616
fn gen_primitive_dict(len: usize, uniqueness: f64) -> PrimitiveArray {
1717
let mut rng = thread_rng();

encodings/dict/benches/dict_compress.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ use vortex_array::array::{PrimitiveArray, VarBinArray, VarBinViewArray};
99
use vortex_array::validity::Validity;
1010
use vortex_array::IntoCanonical;
1111
use vortex_buffer::Buffer;
12-
use vortex_dict::dict_encode;
12+
use vortex_dict::builders::dict_encode;
1313

1414
pub fn gen_primitive_dict(len: usize, uniqueness: f64) -> PrimitiveArray {
1515
let mut rng = thread_rng();
-41 Bytes
Binary file not shown.

encodings/dict/src/array.rs

Lines changed: 30 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ use vortex_array::{
1111
encoding_ids, impl_encoding, Array, Canonical, IntoArray, IntoArrayVariant, IntoCanonical,
1212
SerdeMetadata,
1313
};
14-
use vortex_dtype::{match_each_integer_ptype, DType, Nullability, PType};
14+
use vortex_dtype::{match_each_integer_ptype, DType, PType};
1515
use vortex_error::{vortex_bail, VortexExpect as _, VortexResult};
1616
use vortex_mask::{AllOr, Mask};
1717

@@ -22,85 +22,41 @@ impl_encoding!(
2222
SerdeMetadata<DictMetadata>
2323
);
2424

25-
#[derive(
26-
Copy,
27-
Clone,
28-
Debug,
29-
Serialize,
30-
Deserialize,
31-
rkyv::Archive,
32-
rkyv::Portable,
33-
rkyv::Serialize,
34-
rkyv::Deserialize,
35-
rkyv::bytecheck::CheckBytes,
36-
)]
37-
#[rkyv(as = DictNullability)]
38-
#[bytecheck(crate = rkyv::bytecheck)]
39-
#[repr(u8)]
40-
enum DictNullability {
41-
NonNullable,
42-
NullableCodes,
43-
NullableValues,
44-
BothNullable,
45-
}
46-
47-
impl DictNullability {
48-
fn from_dtypes(codes_dtype: &DType, values_dtype: &DType) -> Self {
49-
match (codes_dtype.is_nullable(), values_dtype.is_nullable()) {
50-
(true, true) => Self::BothNullable,
51-
(true, false) => Self::NullableCodes,
52-
(false, true) => Self::NullableValues,
53-
(false, false) => Self::NonNullable,
54-
}
55-
}
56-
57-
fn codes_nullability(&self) -> Nullability {
58-
match self {
59-
DictNullability::NonNullable => Nullability::NonNullable,
60-
DictNullability::NullableCodes => Nullability::Nullable,
61-
DictNullability::NullableValues => Nullability::NonNullable,
62-
DictNullability::BothNullable => Nullability::Nullable,
63-
}
64-
}
65-
66-
fn values_nullability(&self) -> Nullability {
67-
match self {
68-
DictNullability::NonNullable => Nullability::NonNullable,
69-
DictNullability::NullableCodes => Nullability::NonNullable,
70-
DictNullability::NullableValues => Nullability::Nullable,
71-
DictNullability::BothNullable => Nullability::Nullable,
72-
}
73-
}
74-
}
75-
7625
#[derive(Debug, Clone, Serialize, Deserialize)]
7726
pub struct DictMetadata {
7827
codes_ptype: PType,
7928
values_len: usize, // TODO(ngates): make this a u32
80-
dict_nullability: DictNullability,
8129
}
8230

8331
impl DictArray {
84-
pub fn try_new(codes: Array, values: Array) -> VortexResult<Self> {
32+
pub fn try_new(mut codes: Array, values: Array) -> VortexResult<Self> {
8533
if !codes.dtype().is_unsigned_int() {
8634
vortex_bail!(MismatchedTypes: "unsigned int", codes.dtype());
8735
}
8836

89-
let dtype = if codes.dtype().is_nullable() {
90-
values.dtype().as_nullable()
37+
let dtype = values.dtype();
38+
if dtype.is_nullable() {
39+
// If the values are nullable, we force codes to be nullable as well.
40+
codes = try_cast(&codes, &codes.dtype().as_nullable())?;
9141
} else {
92-
values.dtype().clone()
93-
};
94-
let dict_nullability = DictNullability::from_dtypes(codes.dtype(), values.dtype());
42+
// If the values are non-nullable, we assert the codes are non-nullable as well.
43+
if codes.dtype().is_nullable() {
44+
vortex_bail!("Cannot have nullable codes for non-nullable dict array");
45+
}
46+
}
47+
assert_eq!(
48+
codes.dtype().nullability(),
49+
values.dtype().nullability(),
50+
"Mismatched nullability between codes and values"
51+
);
9552

9653
Self::try_from_parts(
97-
dtype,
54+
dtype.clone(),
9855
codes.len(),
9956
SerdeMetadata(DictMetadata {
10057
codes_ptype: PType::try_from(codes.dtype())
10158
.vortex_expect("codes dtype must be uint"),
10259
values_len: values.len(),
103-
dict_nullability,
10460
}),
10561
None,
10662
Some([codes, values].into()),
@@ -113,10 +69,7 @@ impl DictArray {
11369
self.as_ref()
11470
.child(
11571
0,
116-
&DType::Primitive(
117-
self.metadata().codes_ptype,
118-
self.metadata().dict_nullability.codes_nullability(),
119-
),
72+
&DType::Primitive(self.metadata().codes_ptype, self.dtype().nullability()),
12073
self.len(),
12174
)
12275
.vortex_expect("DictArray is missing its codes child array")
@@ -125,13 +78,7 @@ impl DictArray {
12578
#[inline]
12679
pub fn values(&self) -> Array {
12780
self.as_ref()
128-
.child(
129-
1,
130-
&self
131-
.dtype()
132-
.with_nullability(self.metadata().dict_nullability.values_nullability()),
133-
self.metadata().values_len,
134-
)
81+
.child(1, self.dtype(), self.metadata().values_len)
13582
.vortex_expect("DictArray is missing its values child array")
13683
}
13784
}
@@ -147,10 +94,10 @@ impl CanonicalVTable<DictArray> for DictEncoding {
14794
// copies of the view pointers.
14895
DType::Utf8(_) | DType::Binary(_) => {
14996
let canonical_values: Array = array.values().into_canonical()?.into_array();
150-
try_cast(take(canonical_values, array.codes())?, array.dtype())?.into_canonical()
97+
take(canonical_values, array.codes())?.into_canonical()
15198
}
15299
// Non-string case: take and then canonicalize
153-
_ => try_cast(take(array.values(), array.codes())?, array.dtype())?.into_canonical(),
100+
_ => take(array.values(), array.codes())?.into_canonical(),
154101
}
155102
}
156103
}
@@ -182,6 +129,14 @@ impl ValidityVTable<DictArray> for DictEncoding {
182129
Ok(array.codes().all_valid()? && array.values().all_valid()?)
183130
}
184131

132+
fn all_invalid(&self, array: &DictArray) -> VortexResult<bool> {
133+
if !array.dtype().is_nullable() {
134+
return Ok(false);
135+
}
136+
137+
Ok(array.codes().all_invalid()? || array.values().all_invalid()?)
138+
}
139+
185140
fn validity_mask(&self, array: &DictArray) -> VortexResult<Mask> {
186141
let codes_validity = array.codes().validity_mask()?;
187142
match codes_validity.boolean_buffer() {
@@ -231,7 +186,6 @@ mod test {
231186
use vortex_error::vortex_panic;
232187
use vortex_mask::AllOr;
233188

234-
use crate::array::DictNullability::BothNullable;
235189
use crate::{DictArray, DictMetadata};
236190

237191
#[cfg_attr(miri, ignore)]
@@ -242,7 +196,6 @@ mod test {
242196
SerdeMetadata(DictMetadata {
243197
codes_ptype: PType::U64,
244198
values_len: usize::MAX,
245-
dict_nullability: BothNullable,
246199
}),
247200
);
248201
}
@@ -255,7 +208,7 @@ mod test {
255208
Validity::from(BooleanBuffer::from(vec![true, false, true, false, true])),
256209
)
257210
.into_array(),
258-
buffer![3, 6, 9].into_array(),
211+
PrimitiveArray::new(buffer![3, 6, 9], Validity::AllValid).into_array(),
259212
)
260213
.unwrap();
261214
let mask = dict.validity_mask().unwrap();

0 commit comments

Comments
 (0)