Skip to content

Commit dd22838

Browse files
fix[decimal]: add sum support for decimals to chunked and constant (#5023)
The fuzzer found these. --------- Signed-off-by: Joe Isaacs <[email protected]>
1 parent acb0120 commit dd22838

File tree

8 files changed

+874
-30
lines changed

8 files changed

+874
-30
lines changed

vortex-array/src/arrays/chunked/compute/sum.rs

Lines changed: 168 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,10 @@
22
// SPDX-FileCopyrightText: Copyright the Vortex contributors
33

44
use num_traits::PrimInt;
5-
use vortex_dtype::{NativePType, PType, match_each_native_ptype};
6-
use vortex_error::{VortexExpect, VortexResult, vortex_err};
7-
use vortex_scalar::{FromPrimitiveOrF16, Scalar};
5+
use vortex_dtype::Nullability::Nullable;
6+
use vortex_dtype::{DType, DecimalDType, NativePType, match_each_native_ptype};
7+
use vortex_error::{VortexResult, vortex_bail, vortex_err};
8+
use vortex_scalar::{DecimalScalar, DecimalValue, FromPrimitiveOrF16, Scalar, i256};
89

910
use crate::arrays::{ChunkedArray, ChunkedVTable};
1011
use crate::compute::{SumKernel, SumKernelAdapter, sum};
@@ -16,16 +17,23 @@ impl SumKernel for ChunkedVTable {
1617
let sum_dtype = Stat::Sum
1718
.dtype(array.dtype())
1819
.ok_or_else(|| vortex_err!("Sum not supported for dtype {}", array.dtype()))?;
19-
let sum_ptype = PType::try_from(&sum_dtype).vortex_expect("sum dtype must be primitive");
2020

21-
let scalar_value = match_each_native_ptype!(
22-
sum_ptype,
23-
unsigned: |T| { sum_int::<u64>(array.chunks())?.into() },
24-
signed: |T| { sum_int::<i64>(array.chunks())?.into() },
25-
floating: |T| { sum_float(array.chunks())?.into() }
26-
);
21+
match sum_dtype {
22+
DType::Decimal(decimal_dtype, _) => sum_decimal(array.chunks(), decimal_dtype),
23+
DType::Primitive(sum_ptype, _) => {
24+
let scalar_value = match_each_native_ptype!(
25+
sum_ptype,
26+
unsigned: |T| { sum_int::<u64>(array.chunks())?.into() },
27+
signed: |T| { sum_int::<i64>(array.chunks())?.into() },
28+
floating: |T| { sum_float(array.chunks())?.into() }
29+
);
2730

28-
Ok(Scalar::new(sum_dtype, scalar_value))
31+
Ok(Scalar::new(sum_dtype, scalar_value))
32+
}
33+
_ => {
34+
vortex_bail!("Sum not supported for dtype {}", sum_dtype);
35+
}
36+
}
2937
}
3038
}
3139

@@ -39,7 +47,7 @@ fn sum_int<T: NativePType + PrimInt + FromPrimitiveOrF16>(
3947
let chunk_sum = sum(chunk)?;
4048

4149
let Some(chunk_sum) = chunk_sum.as_primitive().as_::<T>() else {
42-
// Bail out on overflow
50+
// Bail out missing statistic
4351
return Ok(None);
4452
};
4553

@@ -63,14 +71,46 @@ fn sum_float(chunks: &[ArrayRef]) -> VortexResult<f64> {
6371
Ok(result)
6472
}
6573

74+
fn sum_decimal(chunks: &[ArrayRef], result_decimal_type: DecimalDType) -> VortexResult<Scalar> {
75+
let mut result = DecimalValue::I256(i256::ZERO);
76+
77+
let null = || Scalar::null(DType::Decimal(result_decimal_type, Nullable));
78+
79+
for chunk in chunks {
80+
let chunk_sum = sum(chunk)?;
81+
82+
let chunk_decimal = DecimalScalar::try_from(&chunk_sum)?;
83+
let Some(chunk_value) = chunk_decimal.decimal_value() else {
84+
// skips all null chunks
85+
continue;
86+
};
87+
88+
// Perform checked addition with current result
89+
let Some(r) = result.checked_add(&chunk_value).filter(|sum_value| {
90+
sum_value
91+
.fits_in_precision(result_decimal_type)
92+
.unwrap_or(false)
93+
}) else {
94+
// Overflow
95+
return Ok(null());
96+
};
97+
98+
result = r;
99+
}
100+
101+
Ok(Scalar::decimal(result, result_decimal_type, Nullable))
102+
}
103+
66104
#[cfg(test)]
67105
mod tests {
68-
use vortex_dtype::Nullability;
69-
use vortex_scalar::Scalar;
106+
use vortex_buffer::buffer;
107+
use vortex_dtype::{DType, DecimalDType, Nullability};
108+
use vortex_scalar::{DecimalValue, Scalar, i256};
70109

71110
use crate::array::IntoArray;
72-
use crate::arrays::{ChunkedArray, ConstantArray, PrimitiveArray};
111+
use crate::arrays::{ChunkedArray, ConstantArray, DecimalArray, PrimitiveArray};
73112
use crate::compute::sum;
113+
use crate::validity::Validity;
74114

75115
#[test]
76116
fn test_sum_chunked_floats_with_nulls() {
@@ -138,4 +178,117 @@ mod tests {
138178
let result = sum(chunked.as_ref()).unwrap();
139179
assert_eq!(result.as_primitive().as_::<f64>(), Some(36.0));
140180
}
181+
182+
#[test]
183+
fn test_sum_chunked_decimals() {
184+
// Create decimal chunks with precision=10, scale=2
185+
let decimal_dtype = DecimalDType::new(10, 2);
186+
let chunk1 = DecimalArray::new(
187+
buffer![100i32, 100i32, 100i32, 100i32, 100i32],
188+
decimal_dtype,
189+
Validity::AllValid,
190+
);
191+
let chunk2 = DecimalArray::new(
192+
buffer![200i32, 200i32, 200i32],
193+
decimal_dtype,
194+
Validity::AllValid,
195+
);
196+
let chunk3 = DecimalArray::new(buffer![300i32, 300i32], decimal_dtype, Validity::AllValid);
197+
198+
let dtype = chunk1.dtype().clone();
199+
let chunked = ChunkedArray::try_new(
200+
vec![
201+
chunk1.into_array(),
202+
chunk2.into_array(),
203+
chunk3.into_array(),
204+
],
205+
dtype,
206+
)
207+
.unwrap();
208+
209+
// Compute sum: 5*100 + 3*200 + 2*300 = 500 + 600 + 600 = 1700 (represents 17.00)
210+
let result = sum(chunked.as_ref()).unwrap();
211+
let decimal_result = result.as_decimal();
212+
assert_eq!(
213+
decimal_result.decimal_value(),
214+
Some(DecimalValue::I256(i256::from_i128(1700)))
215+
);
216+
}
217+
218+
#[test]
219+
fn test_sum_chunked_decimals_with_nulls() {
220+
let decimal_dtype = DecimalDType::new(10, 2);
221+
222+
// Create chunks with some nulls - all must have same nullability
223+
let chunk1 = DecimalArray::new(
224+
buffer![100i32, 100i32, 100i32],
225+
decimal_dtype,
226+
Validity::AllValid,
227+
);
228+
let chunk2 = DecimalArray::new(
229+
buffer![0i32, 0i32],
230+
decimal_dtype,
231+
Validity::from_iter([false, false]),
232+
);
233+
let chunk3 = DecimalArray::new(buffer![200i32, 200i32], decimal_dtype, Validity::AllValid);
234+
235+
let dtype = chunk1.dtype().clone();
236+
let chunked = ChunkedArray::try_new(
237+
vec![
238+
chunk1.into_array(),
239+
chunk2.into_array(),
240+
chunk3.into_array(),
241+
],
242+
dtype,
243+
)
244+
.unwrap();
245+
246+
// Compute sum: 3*100 + 2*200 = 300 + 400 = 700 (nulls ignored)
247+
let result = sum(chunked.as_ref()).unwrap();
248+
let decimal_result = result.as_decimal();
249+
assert_eq!(
250+
decimal_result.decimal_value(),
251+
Some(DecimalValue::I256(i256::from_i128(700)))
252+
);
253+
}
254+
255+
#[test]
256+
fn test_sum_chunked_decimals_large() {
257+
// Create decimals with precision 3 (max value 999)
258+
// Sum will be 500 + 600 = 1100, which fits in result precision 13 (3+10)
259+
let decimal_dtype = DecimalDType::new(3, 0);
260+
let chunk1 = ConstantArray::new(
261+
Scalar::decimal(
262+
DecimalValue::I16(500),
263+
decimal_dtype,
264+
Nullability::NonNullable,
265+
),
266+
1,
267+
);
268+
let chunk2 = ConstantArray::new(
269+
Scalar::decimal(
270+
DecimalValue::I16(600),
271+
decimal_dtype,
272+
Nullability::NonNullable,
273+
),
274+
1,
275+
);
276+
277+
let dtype = chunk1.dtype().clone();
278+
let chunked =
279+
ChunkedArray::try_new(vec![chunk1.into_array(), chunk2.into_array()], dtype).unwrap();
280+
281+
// Compute sum: 500 + 600 = 1100
282+
// Result should have precision 13 (3+10), scale 0
283+
let result = sum(chunked.as_ref()).unwrap();
284+
let decimal_result = result.as_decimal();
285+
assert_eq!(
286+
decimal_result.decimal_value(),
287+
Some(DecimalValue::I256(i256::from_i128(1100)))
288+
);
289+
assert_eq!(
290+
result.dtype(),
291+
&DType::Decimal(DecimalDType::new(13, 0), Nullability::Nullable)
292+
);
293+
}
141294
}

vortex-array/src/arrays/constant/compute/sum.rs

Lines changed: 102 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,11 @@
22
// SPDX-FileCopyrightText: Copyright the Vortex contributors
33

44
use num_traits::{CheckedMul, ToPrimitive};
5-
use vortex_dtype::{DType, NativePType, match_each_native_ptype};
6-
use vortex_error::{VortexResult, vortex_bail, vortex_err};
7-
use vortex_scalar::{FromPrimitiveOrF16, PrimitiveScalar, Scalar, ScalarValue};
5+
use vortex_dtype::{DType, DecimalDType, NativePType, Nullability, match_each_native_ptype};
6+
use vortex_error::{VortexExpect, VortexResult, vortex_bail, vortex_err};
7+
use vortex_scalar::{
8+
DecimalScalar, DecimalValue, FromPrimitiveOrF16, PrimitiveScalar, Scalar, ScalarValue, i256,
9+
};
810

911
use crate::arrays::{ConstantArray, ConstantVTable};
1012
use crate::compute::{SumKernel, SumKernelAdapter};
@@ -36,11 +38,47 @@ fn sum_scalar(scalar: &Scalar, len: usize) -> VortexResult<ScalarValue> {
3638
signed: |T| { sum_integral::<i64>(scalar.as_primitive(), len)?.into() },
3739
floating: |T| { sum_float(scalar.as_primitive(), len)?.into() }
3840
)),
41+
DType::Decimal(decimal_dtype, _) => sum_decimal(scalar.as_decimal(), len, *decimal_dtype),
3942
DType::Extension(_) => sum_scalar(&scalar.as_extension().storage(), len),
4043
dtype => vortex_bail!("Unsupported dtype for sum: {}", dtype),
4144
}
4245
}
4346

47+
fn sum_decimal(
48+
decimal_scalar: DecimalScalar,
49+
array_len: usize,
50+
decimal_dtype: DecimalDType,
51+
) -> VortexResult<ScalarValue> {
52+
let result_dtype = Stat::Sum
53+
.dtype(&DType::Decimal(decimal_dtype, Nullability::Nullable))
54+
.vortex_expect("decimal supports sum");
55+
let result_decimal_type = result_dtype
56+
.as_decimal_opt()
57+
.vortex_expect("must be decimal");
58+
59+
let Some(value) = decimal_scalar.decimal_value() else {
60+
// Null value: return null
61+
return Ok(ScalarValue::null());
62+
};
63+
64+
// Convert array_len to DecimalValue for multiplication
65+
let len_value = DecimalValue::I256(i256::from_i128(array_len as i128));
66+
67+
// Multiply value * len
68+
let sum = value.checked_mul(&len_value).and_then(|result| {
69+
// Check if result fits in the precision
70+
result
71+
.fits_in_precision(*result_decimal_type)
72+
.unwrap_or(false)
73+
.then_some(result)
74+
});
75+
76+
match sum {
77+
Some(result_value) => Ok(ScalarValue::from(result_value)),
78+
None => Ok(ScalarValue::null()), // Overflow
79+
}
80+
}
81+
4482
fn sum_integral<T>(
4583
primitive_scalar: PrimitiveScalar<'_>,
4684
array_len: usize,
@@ -70,12 +108,13 @@ register_kernel!(SumKernelAdapter(ConstantVTable).lift());
70108

71109
#[cfg(test)]
72110
mod tests {
73-
use vortex_dtype::{DType, Nullability, PType};
74-
use vortex_scalar::Scalar;
111+
use vortex_dtype::{DType, DecimalDType, Nullability, PType};
112+
use vortex_scalar::{DecimalValue, Scalar};
75113

76-
use crate::IntoArray;
77114
use crate::arrays::ConstantArray;
78115
use crate::compute::sum;
116+
use crate::stats::Stat;
117+
use crate::{Array, IntoArray};
79118

80119
#[test]
81120
fn test_sum_unsigned() {
@@ -123,4 +162,61 @@ mod tests {
123162
let result = sum(&array).unwrap();
124163
assert!(result.is_null());
125164
}
165+
166+
#[test]
167+
fn test_sum_decimal() {
168+
let decimal_dtype = DecimalDType::new(10, 2);
169+
let array = ConstantArray::new(
170+
Scalar::decimal(
171+
DecimalValue::I64(100),
172+
decimal_dtype,
173+
Nullability::NonNullable,
174+
),
175+
5,
176+
)
177+
.into_array();
178+
179+
let result = sum(&array).unwrap();
180+
181+
assert_eq!(
182+
result.as_decimal().decimal_value(),
183+
Some(DecimalValue::I256(vortex_scalar::i256::from_i128(500)))
184+
);
185+
assert_eq!(result.dtype(), &Stat::Sum.dtype(array.dtype()).unwrap());
186+
}
187+
188+
#[test]
189+
fn test_sum_decimal_null() {
190+
let decimal_dtype = DecimalDType::new(10, 2);
191+
let array = ConstantArray::new(
192+
Scalar::null(DType::Decimal(decimal_dtype, Nullability::Nullable)),
193+
10,
194+
)
195+
.into_array();
196+
197+
let result = sum(&array).unwrap();
198+
assert!(result.is_null());
199+
}
200+
201+
#[test]
202+
fn test_sum_decimal_large_value() {
203+
let decimal_dtype = DecimalDType::new(10, 2);
204+
let array = ConstantArray::new(
205+
Scalar::decimal(
206+
DecimalValue::I64(999_999_999),
207+
decimal_dtype,
208+
Nullability::NonNullable,
209+
),
210+
100,
211+
)
212+
.into_array();
213+
214+
let result = sum(&array).unwrap();
215+
assert_eq!(
216+
result.as_decimal().decimal_value(),
217+
Some(DecimalValue::I256(vortex_scalar::i256::from_i128(
218+
99_999_999_900
219+
)))
220+
);
221+
}
126222
}

vortex-array/src/arrays/decimal/compute/sum.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
use arrow_schema::DECIMAL256_MAX_PRECISION;
55
use num_traits::AsPrimitive;
66
use vortex_dtype::DecimalDType;
7+
use vortex_dtype::Nullability::Nullable;
78
use vortex_error::{VortexResult, vortex_bail};
89
use vortex_mask::Mask;
910
use vortex_scalar::{DecimalValue, Scalar, match_each_decimal_value_type};
@@ -40,7 +41,6 @@ impl SumKernel for DecimalVTable {
4041
#[allow(clippy::cognitive_complexity)]
4142
fn sum(&self, array: &DecimalArray) -> VortexResult<Scalar> {
4243
let decimal_dtype = array.decimal_dtype();
43-
let nullability = array.dtype().nullability();
4444

4545
// Both Spark and DataFusion use this heuristic.
4646
// - https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66
@@ -60,7 +60,7 @@ impl SumKernel for DecimalVTable {
6060
Ok(Scalar::decimal(
6161
DecimalValue::from(sum_decimal!(O, array.buffer::<I>())),
6262
return_dtype,
63-
nullability,
63+
Nullable,
6464
))
6565
})
6666
})
@@ -76,7 +76,7 @@ impl SumKernel for DecimalVTable {
7676
mask_values.boolean_buffer()
7777
)),
7878
return_dtype,
79-
nullability,
79+
Nullable,
8080
))
8181
})
8282
})

0 commit comments

Comments
 (0)