Skip to content

Commit 33088c2

Browse files
authored
feat: ArrayOperations infallible, eager validation + new_unchecked (#4177)
ArrayOperations currently return VortexResult<>, but they really should just be infallible. A failed array op is generally indicative of programmer or encoding error. There's really nothing interesting we can do to handle an out-of-bounds slice() or scalar_at. There's a lot that falls out of this, like fixing a bunch of tests, tweaking our scalar value casting to return Option instead of Result, etc. --------- Signed-off-by: Andrew Duffy <[email protected]>
1 parent 431a8f2 commit 33088c2

File tree

237 files changed

+2995
-2452
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

237 files changed

+2995
-2452
lines changed

Cargo.lock

Lines changed: 4 additions & 4 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

bench-vortex/src/bin/random_access.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,13 +94,13 @@ fn random_access(
9494
.vortex_expect("could not get DOLocationID");
9595
for (idx, loc) in [90i32, 249, 230, 79, 239, 236].iter().enumerate() {
9696
assert_eq!(
97-
pu_location_id.scalar_at(idx).vortex_expect("scalar_at"),
97+
pu_location_id.scalar_at(idx),
9898
Scalar::primitive(*loc, NonNullable)
9999
);
100100
}
101101
for (idx, loc) in [164i32, 231, 25, 224, 243, 239].iter().enumerate() {
102102
assert_eq!(
103-
do_location_id.scalar_at(idx).vortex_expect("scalar_at"),
103+
do_location_id.scalar_at(idx),
104104
Scalar::primitive(*loc, NonNullable)
105105
);
106106
}

encodings/alp/src/alp/array.rs

Lines changed: 156 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,9 @@ use vortex_array::vtable::{
1010
};
1111
use vortex_array::{Array, ArrayRef, Canonical, EncodingId, EncodingRef, vtable};
1212
use vortex_dtype::{DType, PType};
13-
use vortex_error::{VortexResult, vortex_bail};
13+
use vortex_error::{VortexExpect, VortexResult, vortex_ensure};
1414

15+
use crate::ALPFloat;
1516
use crate::alp::{Exponents, decompress};
1617

1718
vtable!(ALP);
@@ -51,17 +52,150 @@ pub struct ALPArray {
5152
pub struct ALPEncoding;
5253

5354
impl ALPArray {
54-
// TODO(ngates): remove try_new and panic on wrong DType?
55+
fn validate(
56+
encoded: &dyn Array,
57+
exponents: Exponents,
58+
patches: Option<&Patches>,
59+
) -> VortexResult<()> {
60+
vortex_ensure!(
61+
matches!(
62+
encoded.dtype(),
63+
DType::Primitive(PType::I32 | PType::I64, _)
64+
),
65+
"ALP encoded ints have invalid DType {}",
66+
encoded.dtype(),
67+
);
68+
69+
// Validate exponents are in-bounds for the float, and that patches have the proper
70+
// length and type.
71+
let Exponents { e, f } = exponents;
72+
match encoded.dtype().as_ptype() {
73+
PType::I32 => {
74+
vortex_ensure!(exponents.e <= f32::MAX_EXPONENT, "e out of bounds: {e}");
75+
vortex_ensure!(exponents.f <= f32::MAX_EXPONENT, "f out of bounds: {f}");
76+
if let Some(patches) = patches {
77+
Self::validate_patches::<f32>(patches, encoded)?;
78+
}
79+
}
80+
PType::I64 => {
81+
vortex_ensure!(e <= f64::MAX_EXPONENT, "e out of bounds: {e}");
82+
vortex_ensure!(f <= f64::MAX_EXPONENT, "f out of bounds: {f}");
83+
84+
if let Some(patches) = patches {
85+
Self::validate_patches::<f64>(patches, encoded)?;
86+
}
87+
}
88+
_ => unreachable!(),
89+
}
90+
91+
// Validate patches
92+
if let Some(patches) = patches {
93+
vortex_ensure!(
94+
patches.array_len() == encoded.len(),
95+
"patches array_len != encoded len: {} != {}",
96+
patches.array_len(),
97+
encoded.len()
98+
);
99+
100+
// Verify that the patches DType are of the proper DType.
101+
}
102+
103+
Ok(())
104+
}
105+
106+
/// Validate that any patches provided are valid for the ALPArray.
107+
fn validate_patches<T: ALPFloat>(patches: &Patches, encoded: &dyn Array) -> VortexResult<()> {
108+
vortex_ensure!(
109+
patches.array_len() == encoded.len(),
110+
"patches array_len != encoded len: {} != {}",
111+
patches.array_len(),
112+
encoded.len()
113+
);
114+
115+
let expected_type = DType::Primitive(T::PTYPE, encoded.dtype().nullability());
116+
vortex_ensure!(
117+
patches.dtype() == &expected_type,
118+
"Expected patches type {expected_type}, actual {}",
119+
patches.dtype(),
120+
);
121+
122+
Ok(())
123+
}
124+
}
125+
126+
impl ALPArray {
127+
/// Build a new `ALPArray` from components, panicking on validation failure.
128+
///
129+
/// See [`ALPArray::try_new`] for reference on preconditions that must pass before
130+
/// calling this method.
131+
pub fn new(encoded: ArrayRef, exponents: Exponents, patches: Option<Patches>) -> Self {
132+
Self::try_new(encoded, exponents, patches).vortex_expect("ALPArray new")
133+
}
134+
135+
/// Build a new `ALPArray` from components:
136+
///
137+
/// * `encoded` contains the ALP-encoded ints. Any null values are replaced with placeholders
138+
/// * `exponents` are the ALP exponents, valid range depends on the data type
139+
/// * `patches` are any patch values that don't cleanly encode using the ALP conversion function
140+
///
141+
/// This method validates the inputs and will return an error if any validation fails.
142+
///
143+
/// # Validation
144+
///
145+
/// * The `encoded` array must be either `i32` or `i64`
146+
/// * If `i32`, any `patches` must have DType `f32` with same nullability
147+
/// * If `i64`, then `patches`must have DType `f64` with same nullability
148+
/// * `exponents` must be in the valid range depending on if the ALPArray is of type `f32` or
149+
/// `f64`.
150+
/// * `patches` must have an `array_len` equal to the length of `encoded`
151+
///
152+
/// Any failure of these preconditions will result in an error being returned.
153+
///
154+
/// # Examples
155+
///
156+
/// ```
157+
/// # use vortex_alp::{ALPArray, Exponents};
158+
/// # use vortex_array::IntoArray;
159+
/// # use vortex_buffer::buffer;
160+
///
161+
/// // Returns error because buffer has wrong PType.
162+
/// let result = ALPArray::try_new(
163+
/// buffer![1i8].into_array(),
164+
/// Exponents { e: 1, f: 1 },
165+
/// None
166+
/// );
167+
/// assert!(result.is_err());
168+
///
169+
/// // Returns error because Exponents are out of bounds for f32
170+
/// let result = ALPArray::try_new(
171+
/// buffer![1i32, 2i32].into_array(),
172+
/// Exponents { e: 100, f: 100 },
173+
/// None
174+
/// );
175+
/// assert!(result.is_err());
176+
///
177+
/// // Success!
178+
/// let value = ALPArray::try_new(
179+
/// buffer![0i32].into_array(),
180+
/// Exponents { e: 1, f: 1 },
181+
/// None
182+
/// ).unwrap();
183+
///
184+
/// assert_eq!(value.scalar_at(0), 0f32.into());
185+
/// ```
55186
pub fn try_new(
56187
encoded: ArrayRef,
57188
exponents: Exponents,
58189
patches: Option<Patches>,
59190
) -> VortexResult<Self> {
191+
Self::validate(&encoded, exponents, patches.as_ref())?;
192+
60193
let dtype = match encoded.dtype() {
61194
DType::Primitive(PType::I32, nullability) => DType::Primitive(PType::F32, *nullability),
62195
DType::Primitive(PType::I64, nullability) => DType::Primitive(PType::F64, *nullability),
63-
d => vortex_bail!(MismatchedTypes: "int32 or int64", d),
196+
_ => unreachable!(),
64197
};
198+
65199
Ok(Self {
66200
dtype,
67201
encoded,
@@ -71,6 +205,25 @@ impl ALPArray {
71205
})
72206
}
73207

208+
/// Build a new `ALPArray` from components without validation.
209+
///
210+
/// See [`ALPArray::try_new`] for information about the preconditions that should be checked
211+
/// **before** calling this method.
212+
pub(crate) unsafe fn new_unchecked(
213+
encoded: ArrayRef,
214+
exponents: Exponents,
215+
patches: Option<Patches>,
216+
dtype: DType,
217+
) -> Self {
218+
Self {
219+
dtype,
220+
encoded,
221+
exponents,
222+
patches,
223+
stats_set: Default::default(),
224+
}
225+
}
226+
74227
pub fn ptype(&self) -> PType {
75228
self.dtype.as_ptype()
76229
}

encodings/alp/src/alp/compress.rs

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,15 @@ pub fn alp_encode(parray: &PrimitiveArray, exponents: Option<Exponents>) -> Vort
4343
_ => vortex_bail!("ALP can only encode f32 and f64"),
4444
};
4545

46-
ALPArray::try_new(encoded, exponents, patches)
46+
// SAFETY: alp_encode_components_typed must return well-formed components
47+
unsafe {
48+
Ok(ALPArray::new_unchecked(
49+
encoded,
50+
exponents,
51+
patches,
52+
parray.dtype().clone(),
53+
))
54+
}
4755
}
4856

4957
#[allow(clippy::cast_possible_truncation)]
@@ -65,7 +73,7 @@ where
6573

6674
let validity = values.validity_mask()?;
6775
// exceptional_positions may contain exceptions at invalid positions (which contain garbage
68-
// data). We remove invalid exceptional positions in order to keep the Patches small.
76+
// data). We remove null exceptions in order to keep the Patches small.
6977
let (valid_exceptional_positions, valid_exceptional_values): (Buffer<u64>, Buffer<T>) =
7078
match validity {
7179
Mask::AllTrue(_) => (exceptional_positions, exceptional_values),
@@ -194,10 +202,10 @@ mod tests {
194202
assert_eq!(encoded.exponents(), Exponents { e: 16, f: 13 });
195203

196204
let decoded = decompress(&encoded).unwrap();
197-
assert_eq!(decoded.scalar_at(0).unwrap(), array.scalar_at(0).unwrap());
198-
assert_eq!(decoded.scalar_at(1).unwrap(), array.scalar_at(1).unwrap());
205+
assert_eq!(decoded.scalar_at(0), array.scalar_at(0));
206+
assert_eq!(decoded.scalar_at(1), array.scalar_at(1));
199207
assert!(!decoded.is_valid(2).unwrap());
200-
assert_eq!(decoded.scalar_at(3).unwrap(), array.scalar_at(3).unwrap());
208+
assert_eq!(decoded.scalar_at(3), array.scalar_at(3));
201209
}
202210

203211
#[test]
@@ -216,12 +224,12 @@ mod tests {
216224
assert_eq!(encoded.exponents(), Exponents { e: 16, f: 13 });
217225

218226
for idx in 0..3 {
219-
let s = encoded.scalar_at(idx).unwrap();
227+
let s = encoded.scalar_at(idx);
220228
assert!(s.is_valid());
221229
}
222230

223231
assert!(!encoded.is_valid(4).unwrap());
224-
let s = encoded.scalar_at(4).unwrap();
232+
let s = encoded.scalar_at(4);
225233
assert!(s.is_null());
226234

227235
let _decoded = decompress(&encoded).unwrap();
@@ -249,9 +257,9 @@ mod tests {
249257
decompressed.as_slice::<f64>()
250258
);
251259
assert_eq!(original.validity(), decompressed.validity());
252-
assert_eq!(original.scalar_at(0).unwrap(), Scalar::null_typed::<f64>());
253-
assert_eq!(original.scalar_at(1).unwrap(), Scalar::null_typed::<f64>());
254-
assert_eq!(original.scalar_at(2).unwrap(), Scalar::null_typed::<f64>());
260+
assert_eq!(original.scalar_at(0), Scalar::null_typed::<f64>());
261+
assert_eq!(original.scalar_at(1), Scalar::null_typed::<f64>());
262+
assert_eq!(original.scalar_at(2), Scalar::null_typed::<f64>());
255263
}
256264

257265
#[test]

encodings/alp/src/alp/compute/cast.rs

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,18 @@ impl CastKernel for ALPVTable {
2222
.with_nullability(dtype.nullability()),
2323
)?;
2424

25-
Ok(Some(
26-
ALPArray::try_new(new_encoded, array.exponents(), array.patches().cloned())?
25+
// SAFETY: casting nullability doesn't alter the invariants
26+
unsafe {
27+
Ok(Some(
28+
ALPArray::new_unchecked(
29+
new_encoded,
30+
array.exponents(),
31+
array.patches().cloned(),
32+
dtype.clone(),
33+
)
2734
.into_array(),
28-
))
35+
))
36+
}
2937
} else {
3038
Ok(None)
3139
}

encodings/alp/src/alp/compute/filter.rs

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,16 @@ impl FilterKernel for ALPVTable {
1616
.transpose()?
1717
.flatten();
1818

19-
Ok(
20-
ALPArray::try_new(filter(array.encoded(), mask)?, array.exponents(), patches)?
21-
.to_array(),
22-
)
19+
// SAFETY: filtering the values does not change correctness
20+
unsafe {
21+
Ok(ALPArray::new_unchecked(
22+
filter(array.encoded(), mask)?,
23+
array.exponents(),
24+
patches,
25+
array.dtype().clone(),
26+
)
27+
.to_array())
28+
}
2329
}
2430
}
2531

encodings/alp/src/alp/compute/mask.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ impl MaskKernel for ALPVTable {
2424
)
2525
})
2626
.transpose()?;
27-
Ok(ALPArray::try_new(masked_encoded, array.exponents(), masked_patches)?.to_array())
27+
Ok(ALPArray::new(masked_encoded, array.exponents(), masked_patches).to_array())
2828
}
2929
}
3030

encodings/alp/src/alp/compute/take.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
// SPDX-FileCopyrightText: Copyright the Vortex contributors
33

44
use vortex_array::compute::{TakeKernel, TakeKernelAdapter, take};
5-
use vortex_array::{Array, ArrayRef, register_kernel};
5+
use vortex_array::{Array, ArrayRef, IntoArray, register_kernel};
66
use vortex_error::VortexResult;
77

88
use crate::{ALPArray, ALPVTable};
@@ -23,7 +23,7 @@ impl TakeKernel for ALPVTable {
2323
)
2424
})
2525
.transpose()?;
26-
Ok(ALPArray::try_new(taken_encoded, array.exponents(), taken_patches)?.to_array())
26+
Ok(ALPArray::new(taken_encoded, array.exponents(), taken_patches).into_array())
2727
}
2828
}
2929

0 commit comments

Comments
 (0)