Skip to content

Commit 91b6bee

Browse files
authored
Feature: add Decimal::bind (#5210)
Basically identical to `Primitive` Signed-off-by: Connor Tsui <[email protected]>
1 parent 91f20ae commit 91b6bee

File tree

5 files changed

+112
-13
lines changed

5 files changed

+112
-13
lines changed

vortex-array/src/arrays/decimal/vtable/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ use crate::{EncodingId, EncodingRef, vtable};
88
mod array;
99
mod canonical;
1010
mod operations;
11+
mod operator;
1112
mod serde;
1213
mod validity;
1314
mod visitor;
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3+
4+
use vortex_compute::filter::Filter;
5+
use vortex_dtype::PrecisionScale;
6+
use vortex_error::VortexResult;
7+
use vortex_scalar::match_each_decimal_value_type;
8+
use vortex_vector::decimal::DVector;
9+
10+
use crate::arrays::{DecimalArray, DecimalVTable, MaskedVTable};
11+
use crate::execution::{BatchKernelRef, BindCtx, kernel};
12+
use crate::vtable::{OperatorVTable, ValidityHelper};
13+
use crate::{ArrayRef, IntoArray};
14+
15+
impl OperatorVTable<DecimalVTable> for DecimalVTable {
16+
fn bind(
17+
array: &DecimalArray,
18+
selection: Option<&ArrayRef>,
19+
ctx: &mut dyn BindCtx,
20+
) -> VortexResult<BatchKernelRef> {
21+
let mask = ctx.bind_selection(array.len(), selection)?;
22+
let validity = ctx.bind_validity(array.validity(), array.len(), selection)?;
23+
24+
match_each_decimal_value_type!(array.values_type(), |D| {
25+
let elements = array.buffer::<D>();
26+
let ps = PrecisionScale::<D>::try_from(&array.decimal_dtype())?;
27+
28+
Ok(kernel(move || {
29+
let mask = mask.execute()?;
30+
let validity = validity.execute()?;
31+
32+
// Note that validity already has the mask applied so we only need to apply it to
33+
// the elements.
34+
let elements = elements.filter(&mask);
35+
36+
Ok(DVector::<D>::try_new(ps, elements, validity)?.into())
37+
}))
38+
})
39+
}
40+
41+
fn reduce_parent(
42+
array: &DecimalArray,
43+
parent: &ArrayRef,
44+
_child_idx: usize,
45+
) -> VortexResult<Option<ArrayRef>> {
46+
// Push-down masking of `validity` from the parent `MaskedArray`.
47+
if let Some(masked) = parent.as_opt::<MaskedVTable>() {
48+
let masked_array = match_each_decimal_value_type!(array.values_type(), |D| {
49+
// SAFETY: Since we are only flipping some bits in the validity, all invariants that
50+
// were upheld are still upheld.
51+
unsafe {
52+
DecimalArray::new_unchecked(
53+
array.buffer::<D>(),
54+
array.decimal_dtype(),
55+
array.validity().clone().and(masked.validity().clone()),
56+
)
57+
}
58+
.into_array()
59+
});
60+
61+
return Ok(Some(masked_array));
62+
}
63+
64+
Ok(None)
65+
}
66+
}

vortex-array/src/arrays/primitive/vtable/operator.rs

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
// SPDX-License-Identifier: Apache-2.0
22
// SPDX-FileCopyrightText: Copyright the Vortex contributors
33

4+
use vortex_buffer::Buffer;
45
use vortex_compute::filter::Filter;
56
use vortex_dtype::match_each_native_ptype;
67
use vortex_error::VortexResult;
@@ -20,8 +21,8 @@ impl OperatorVTable<PrimitiveVTable> for PrimitiveVTable {
2021
let mask = ctx.bind_selection(array.len(), selection)?;
2122
let validity = ctx.bind_validity(array.validity(), array.len(), selection)?;
2223

23-
match_each_native_ptype!(array.ptype(), |T| {
24-
let elements = array.buffer::<T>();
24+
match_each_native_ptype!(array.ptype(), |P| {
25+
let elements = array.buffer::<P>();
2526
Ok(kernel(move || {
2627
let mask = mask.execute()?;
2728
let validity = validity.execute()?;
@@ -30,7 +31,7 @@ impl OperatorVTable<PrimitiveVTable> for PrimitiveVTable {
3031
// the elements.
3132
let elements = elements.filter(&mask);
3233

33-
Ok(PVector::try_new(elements, validity)?.into())
34+
Ok(PVector::<P>::try_new(elements, validity)?.into())
3435
}))
3536
})
3637
}
@@ -40,16 +41,21 @@ impl OperatorVTable<PrimitiveVTable> for PrimitiveVTable {
4041
parent: &ArrayRef,
4142
_child_idx: usize,
4243
) -> VortexResult<Option<ArrayRef>> {
43-
// Push-down masking of validity from parent MaskedVTable.
44+
// Push-down masking of `validity` from the parent `MaskedArray`.
4445
if let Some(masked) = parent.as_opt::<MaskedVTable>() {
45-
return Ok(Some(
46-
PrimitiveArray::from_byte_buffer(
47-
array.byte_buffer().clone(),
48-
array.ptype(),
49-
array.validity().clone().and(masked.validity().clone()),
50-
)
51-
.into_array(),
52-
));
46+
let masked_array = match_each_native_ptype!(array.ptype(), |T| {
47+
// SAFETY: Since we are only flipping some bits in the validity, all invariants that
48+
// were upheld are still upheld.
49+
unsafe {
50+
PrimitiveArray::new_unchecked(
51+
Buffer::<T>::from_byte_buffer(array.byte_buffer().clone()),
52+
array.validity().clone().and(masked.validity().clone()),
53+
)
54+
}
55+
.into_array()
56+
});
57+
58+
return Ok(Some(masked_array));
5359
}
5460

5561
Ok(None)

vortex-vector/src/decimal/vector.rs

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
//! Definition and implementation of [`DecimalVector`].
55
6-
use vortex_dtype::{DecimalTypeDowncast, DecimalTypeUpcast, NativeDecimalType, i256};
6+
use vortex_dtype::{DecimalType, DecimalTypeDowncast, DecimalTypeUpcast, NativeDecimalType, i256};
77
use vortex_error::vortex_panic;
88
use vortex_mask::Mask;
99

@@ -27,6 +27,20 @@ pub enum DecimalVector {
2727
D256(DVector<i256>),
2828
}
2929

30+
impl DecimalVector {
31+
/// Returns the [`DecimalType`] of the decimal vector.
32+
pub fn decimal_type(&self) -> DecimalType {
33+
match self {
34+
Self::D8(_) => DecimalType::I8,
35+
Self::D16(_) => DecimalType::I16,
36+
Self::D32(_) => DecimalType::I32,
37+
Self::D64(_) => DecimalType::I64,
38+
Self::D128(_) => DecimalType::I128,
39+
Self::D256(_) => DecimalType::I256,
40+
}
41+
}
42+
}
43+
3044
impl VectorOps for DecimalVector {
3145
type Mutable = DecimalVectorMut;
3246

vortex-vector/src/decimal/vector_mut.rs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,18 @@ pub enum DecimalVectorMut {
3131
}
3232

3333
impl DecimalVectorMut {
34+
/// Returns the [`DecimalType`] of the decimal vector.
35+
pub fn decimal_type(&self) -> DecimalType {
36+
match self {
37+
Self::D8(_) => DecimalType::I8,
38+
Self::D16(_) => DecimalType::I16,
39+
Self::D32(_) => DecimalType::I32,
40+
Self::D64(_) => DecimalType::I64,
41+
Self::D128(_) => DecimalType::I128,
42+
Self::D256(_) => DecimalType::I256,
43+
}
44+
}
45+
3446
/// Create a new mutable decimal vector with the given primitive type and capacity.
3547
pub fn with_capacity(decimal_dtype: &DecimalDType, capacity: usize) -> Self {
3648
let decimal_type = DecimalType::smallest_decimal_value_type(decimal_dtype);

0 commit comments

Comments
 (0)