Skip to content

Commit 1b89b2f

Browse files
authored
fix: Decimal ToArrow conversion (#3367)
1 parent 9fd706b commit 1b89b2f

File tree

18 files changed

+634
-122
lines changed

18 files changed

+634
-122
lines changed

.github/workflows/labels.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ jobs:
1313
contains(github.event.pull_request.labels.*.name, 'chore') == false &&
1414
contains(github.event.pull_request.labels.*.name, 'bug') == false &&
1515
contains(github.event.pull_request.labels.*.name, 'feature') == false &&
16+
contains(github.event.pull_request.labels.*.name, 'fix') == false &&
1617
contains(github.event.pull_request.labels.*.name, 'performance') == false &&
1718
contains(github.event.pull_request.labels.*.name, 'break') == false &&
1819
contains(github.event.pull_request.labels.*.name, 'wire-break') == false

encodings/decimal-byte-parts/src/decimal_byte_parts/mod.rs

Lines changed: 0 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
mod compute;
22
mod serde;
33

4-
use std::iter;
5-
64
use itertools::Itertools;
75
use vortex_array::arrays::DecimalArray;
86
use vortex_array::stats::{ArrayStats, StatsSetRef};
@@ -82,27 +80,6 @@ impl DecimalBytePartsArray {
8280
vortex_bail!("decimal bytes parts 2nd to 4th must be non-nullable u64 primitive typed")
8381
}
8482

85-
let primitive_bit_width = iter::once(&msp)
86-
.chain(&lower_parts)
87-
.map(|a| {
88-
PType::try_from(a.dtype())
89-
.vortex_expect("already checked")
90-
.bit_width()
91-
})
92-
.sum();
93-
94-
if decimal_dtype.required_bit_width() > primitive_bit_width {
95-
vortex_bail!(
96-
"cannot represent a decimal {decimal_dtype} as primitive parts {:?}, decimal bit width {}, primitive bit width {}",
97-
iter::once(&msp)
98-
.chain(&lower_parts)
99-
.map(|a| a.dtype())
100-
.collect_vec(),
101-
decimal_dtype.required_bit_width(),
102-
primitive_bit_width
103-
)
104-
}
105-
10683
let nullable = msp.dtype().nullability();
10784
Ok(Self {
10885
msp,

java/vortex-jni/src/main/java/dev/vortex/api/Array.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
*/
1616
package dev.vortex.api;
1717

18+
import java.math.BigDecimal;
1819
import org.apache.arrow.memory.BufferAllocator;
1920
import org.apache.arrow.vector.VectorSchemaRoot;
2021

@@ -52,6 +53,8 @@ public interface Array extends AutoCloseable {
5253

5354
double getDouble(int index);
5455

56+
BigDecimal getBigDecimal(int index);
57+
5558
String getUTF8(int index);
5659

5760
void getUTF8_ptr_len(int index, long[] ptr, int[] len);

java/vortex-jni/src/main/java/dev/vortex/jni/JNIArray.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import com.google.common.base.Preconditions;
1919
import dev.vortex.api.Array;
2020
import dev.vortex.api.DType;
21+
import java.math.BigDecimal;
2122
import java.util.OptionalLong;
2223
import org.apache.arrow.c.ArrowArray;
2324
import org.apache.arrow.c.ArrowSchema;
@@ -130,6 +131,11 @@ public double getDouble(int index) {
130131
return NativeArrayMethods.getDouble(pointer.getAsLong(), index);
131132
}
132133

134+
@Override
135+
public BigDecimal getBigDecimal(int index) {
136+
return NativeArrayMethods.getBigDecimal(pointer.getAsLong(), index);
137+
}
138+
133139
@Override
134140
public String getUTF8(int index) {
135141
return NativeArrayMethods.getUTF8(pointer.getAsLong(), index);

java/vortex-jni/src/main/java/dev/vortex/jni/NativeArrayMethods.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
*/
1616
package dev.vortex.jni;
1717

18+
import java.math.BigDecimal;
19+
1820
public final class NativeArrayMethods {
1921
static {
2022
NativeLoader.loadJni();
@@ -58,6 +60,8 @@ private NativeArrayMethods() {}
5860

5961
public static native double getDouble(long pointer, int index);
6062

63+
public static native BigDecimal getBigDecimal(long pointer, int index);
64+
6165
public static native String getUTF8(long pointer, int index);
6266

6367
/**

vortex-array/src/arrays/constant/canonical.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ use crate::arrays::constant::ConstantArray;
1111
use crate::arrays::primitive::PrimitiveArray;
1212
use crate::arrays::{
1313
BinaryView, BoolArray, ConstantVTable, DecimalArray, ExtensionArray, ListArray, NullArray,
14-
StructArray, VarBinViewArray, precision_to_storage_size,
14+
StructArray, VarBinViewArray, smallest_storage_type,
1515
};
1616
use crate::builders::{ArrayBuilderExt, builder_with_capacity};
1717
use crate::validity::Validity;
@@ -57,7 +57,7 @@ impl CanonicalVTable<ConstantVTable> for ConstantVTable {
5757
})
5858
}
5959
DType::Decimal(decimal_type, ..) => {
60-
let size = precision_to_storage_size(decimal_type);
60+
let size = smallest_storage_type(decimal_type);
6161
let decimal = scalar.as_decimal();
6262
let Some(value) = decimal.decimal_value() else {
6363
let all_null = match_each_decimal_value_type!(size, |$D| {

vortex-array/src/arrays/decimal/mod.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,8 @@ impl VTable for DecimalVTable {
4444
#[derive(Clone, Debug)]
4545
pub struct DecimalEncoding;
4646

47-
/// Maps a decimal precision into the small type that can represent it.
48-
pub fn precision_to_storage_size(decimal_dtype: &DecimalDType) -> DecimalValueType {
47+
/// Maps a decimal precision into the smallest type that can represent it.
48+
pub fn smallest_storage_type(decimal_dtype: &DecimalDType) -> DecimalValueType {
4949
match decimal_dtype.precision() {
5050
1..=2 => DecimalValueType::I8,
5151
3..=4 => DecimalValueType::I16,

vortex-array/src/arrow/compute/to_arrow/canonical.rs

Lines changed: 101 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,10 @@ use arrow_array::{
1414
use arrow_buffer::{ScalarBuffer, i256};
1515
use arrow_schema::{DataType, Field, FieldRef, Fields};
1616
use itertools::Itertools;
17-
use num_traits::AsPrimitive;
17+
use num_traits::{AsPrimitive, ToPrimitive};
1818
use vortex_buffer::Buffer;
1919
use vortex_dtype::{DType, NativePType, PType};
20-
use vortex_error::{VortexExpect, VortexResult, vortex_bail};
20+
use vortex_error::{VortexExpect, VortexResult, vortex_bail, vortex_err};
2121
use vortex_scalar::DecimalValueType;
2222

2323
use crate::arrays::{
@@ -104,8 +104,34 @@ impl Kernel for ToArrowCanonical {
104104
{
105105
to_arrow_primitive::<Float64Type>(array)
106106
}
107-
(Canonical::Decimal(array), DataType::Decimal128(..)) => to_arrow_decimal128(array),
108-
(Canonical::Decimal(array), DataType::Decimal256(..)) => to_arrow_decimal256(array),
107+
(Canonical::Decimal(array), DataType::Decimal128(precision, scale)) => {
108+
if array.decimal_dtype().precision() != *precision
109+
|| array.decimal_dtype().scale() != *scale
110+
{
111+
vortex_bail!(
112+
"ToArrowCanonical: target precision/scale {}/{} does not match array precision/scale {}/{}",
113+
precision,
114+
scale,
115+
array.decimal_dtype().precision(),
116+
array.decimal_dtype().scale()
117+
);
118+
}
119+
to_arrow_decimal128(array)
120+
}
121+
(Canonical::Decimal(array), DataType::Decimal256(precision, scale)) => {
122+
if array.decimal_dtype().precision() != *precision
123+
|| array.decimal_dtype().scale() != *scale
124+
{
125+
vortex_bail!(
126+
"ToArrowCanonical: target precision/scale {}/{} does not match array precision/scale {}/{}",
127+
precision,
128+
scale,
129+
array.decimal_dtype().precision(),
130+
array.decimal_dtype().scale()
131+
);
132+
}
133+
to_arrow_decimal256(array)
134+
}
109135
(Canonical::Struct(array), DataType::Struct(fields)) => {
110136
to_arrow_struct(array, fields.as_ref())
111137
}
@@ -188,9 +214,14 @@ fn to_arrow_decimal128(array: DecimalArray) -> VortexResult<ArrowArrayRef> {
188214
DecimalValueType::I32 => array.buffer::<i32>().into_iter().map(|x| x.as_()).collect(),
189215
DecimalValueType::I64 => array.buffer::<i64>().into_iter().map(|x| x.as_()).collect(),
190216
DecimalValueType::I128 => array.buffer::<i128>(),
191-
DecimalValueType::I256 => {
192-
vortex_bail!("i256 decimals cannot be converted to Arrow i128 decimal")
193-
}
217+
DecimalValueType::I256 => array
218+
.buffer::<vortex_scalar::i256>()
219+
.into_iter()
220+
.map(|x| {
221+
x.to_i128()
222+
.ok_or_else(|| vortex_err!("i256 to i128 narrowing cannot be done safely"))
223+
})
224+
.try_collect()?,
194225
_ => vortex_bail!("unknown value type {:?}", array.values_type()),
195226
};
196227
Ok(Arc::new(
@@ -206,10 +237,14 @@ fn to_arrow_decimal256(array: DecimalArray) -> VortexResult<ArrowArrayRef> {
206237
let null_buffer = array.validity_mask()?.to_null_buffer();
207238
let buffer: Buffer<i256> = match array.values_type() {
208239
DecimalValueType::I8 => array.buffer::<i8>().into_iter().map(|x| x.as_()).collect(),
209-
DecimalValueType::I16 => array.buffer::<i8>().into_iter().map(|x| x.as_()).collect(),
210-
DecimalValueType::I32 => array.buffer::<i8>().into_iter().map(|x| x.as_()).collect(),
211-
DecimalValueType::I64 => array.buffer::<i8>().into_iter().map(|x| x.as_()).collect(),
212-
DecimalValueType::I128 => array.buffer::<i8>().into_iter().map(|x| x.as_()).collect(),
240+
DecimalValueType::I16 => array.buffer::<i16>().into_iter().map(|x| x.as_()).collect(),
241+
DecimalValueType::I32 => array.buffer::<i32>().into_iter().map(|x| x.as_()).collect(),
242+
DecimalValueType::I64 => array.buffer::<i64>().into_iter().map(|x| x.as_()).collect(),
243+
DecimalValueType::I128 => array
244+
.buffer::<i128>()
245+
.into_iter()
246+
.map(|x| vortex_scalar::i256::from_i128(x).into())
247+
.collect(),
213248
DecimalValueType::I256 => Buffer::<i256>::from_byte_buffer(array.byte_buffer()),
214249
_ => vortex_bail!("unknown type {:?}", array.values_type()),
215250
};
@@ -334,15 +369,19 @@ where
334369

335370
#[cfg(test)]
336371
mod tests {
337-
use arrow_array::Decimal128Array;
372+
use arrow_array::{Array, Decimal128Array, Decimal256Array};
373+
use arrow_buffer::i256;
338374
use arrow_schema::{DataType, Field};
375+
use rstest::rstest;
339376
use vortex_buffer::buffer;
340377
use vortex_dtype::{DecimalDType, FieldNames};
378+
use vortex_scalar::NativeDecimalType;
341379

342380
use crate::IntoArray;
343381
use crate::arrays::{DecimalArray, PrimitiveArray, StructArray};
344382
use crate::arrow::IntoArrowArray;
345383
use crate::arrow::compute::to_arrow;
384+
use crate::builders::{ArrayBuilder, DecimalBuilder};
346385
use crate::validity::Validity;
347386

348387
#[test]
@@ -398,4 +437,54 @@ mod tests {
398437

399438
assert!(struct_a.into_array().into_arrow(&arrow_dt).is_err());
400439
}
440+
441+
#[rstest]
442+
#[case(0i8)]
443+
#[case(0i16)]
444+
#[case(0i32)]
445+
#[case(0i64)]
446+
#[case(0i128)]
447+
#[case(vortex_scalar::i256::ZERO)]
448+
fn to_arrow_decimal128<T: NativeDecimalType>(#[case] _decimal_type: T) {
449+
let mut decimal = DecimalBuilder::new::<T>(2, 1, false.into());
450+
decimal.append_value(10);
451+
decimal.append_value(11);
452+
decimal.append_value(12);
453+
454+
let decimal = decimal.finish();
455+
456+
let arrow_array = decimal.into_arrow(&DataType::Decimal128(2, 1)).unwrap();
457+
let arrow_decimal = arrow_array
458+
.as_any()
459+
.downcast_ref::<Decimal128Array>()
460+
.unwrap();
461+
assert_eq!(arrow_decimal.value(0), 10);
462+
assert_eq!(arrow_decimal.value(1), 11);
463+
assert_eq!(arrow_decimal.value(2), 12);
464+
}
465+
466+
#[rstest]
467+
#[case(0i8)]
468+
#[case(0i16)]
469+
#[case(0i32)]
470+
#[case(0i64)]
471+
#[case(0i128)]
472+
#[case(vortex_scalar::i256::ZERO)]
473+
fn to_arrow_decimal256<T: NativeDecimalType>(#[case] _decimal_type: T) {
474+
let mut decimal = DecimalBuilder::new::<T>(2, 1, false.into());
475+
decimal.append_value(10);
476+
decimal.append_value(11);
477+
decimal.append_value(12);
478+
479+
let decimal = decimal.finish();
480+
481+
let arrow_array = decimal.into_arrow(&DataType::Decimal256(2, 1)).unwrap();
482+
let arrow_decimal = arrow_array
483+
.as_any()
484+
.downcast_ref::<Decimal256Array>()
485+
.unwrap();
486+
assert_eq!(arrow_decimal.value(0), i256::from_i128(10));
487+
assert_eq!(arrow_decimal.value(1), i256::from_i128(11));
488+
assert_eq!(arrow_decimal.value(2), i256::from_i128(12));
489+
}
401490
}

vortex-array/src/arrow/convert.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -98,24 +98,24 @@ impl_from_arrow_primitive!(Float32Type);
9898
impl_from_arrow_primitive!(Float64Type);
9999

100100
impl FromArrowArray<&ArrowPrimitiveArray<Decimal128Type>> for ArrayRef {
101-
fn from_arrow(array: &ArrowPrimitiveArray<Decimal128Type>, _nullable: bool) -> Self {
101+
fn from_arrow(array: &ArrowPrimitiveArray<Decimal128Type>, nullable: bool) -> Self {
102102
let decimal_type = DecimalDType::new(array.precision(), array.scale());
103103
let buffer = Buffer::from_arrow_scalar_buffer(array.values().clone());
104-
let validity = nulls(array.nulls(), false);
104+
let validity = nulls(array.nulls(), nullable);
105105
DecimalArray::new(buffer, decimal_type, validity).into_array()
106106
}
107107
}
108108

109109
impl FromArrowArray<&ArrowPrimitiveArray<Decimal256Type>> for ArrayRef {
110-
fn from_arrow(array: &ArrowPrimitiveArray<Decimal256Type>, _nullable: bool) -> Self {
110+
fn from_arrow(array: &ArrowPrimitiveArray<Decimal256Type>, nullable: bool) -> Self {
111111
let decimal_type = DecimalDType::new(array.precision(), array.scale());
112112
let buffer = Buffer::from_arrow_scalar_buffer(array.values().clone());
113113
// SAFETY: Our i256 implementation has the same bit-pattern representation as the
114114
// arrow_buffer::i256 type. It is safe to treat values held inside the buffer as values
115115
// of either type.
116116
let buffer =
117117
unsafe { std::mem::transmute::<Buffer<arrow_buffer::i256>, Buffer<i256>>(buffer) };
118-
let validity = nulls(array.nulls(), false);
118+
let validity = nulls(array.nulls(), nullable);
119119
DecimalArray::new(buffer, decimal_type, validity).into_array()
120120
}
121121
}

0 commit comments

Comments
 (0)