Skip to content

Commit fa08a07

Browse files
authored
feat: add BinaryNumericFn for array arithmetic (#1640)
I did not implement any array-by-array binary numeric functions for the compressed encodings, because it is not clear that there are any cases where we can outrun decompression. Two run-end arrays might be a happy path? Two dictionaries, maybe, if the dictionaries are much smaller than the decompressed arrays? Binary scalar numeric functions are more obviously valuable: ClickBench includes several uses of scalar add or subtract.
1 parent e69bde8 commit fa08a07

File tree

23 files changed

+570
-268
lines changed

23 files changed

+570
-268
lines changed

bench-vortex/src/bin/notimplemented.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,7 @@ fn compute_funcs(encodings: &[ArrayData]) {
174174
"fill_forward",
175175
"filter",
176176
"scalar_at",
177-
"subtract_scalar",
177+
"binary_numeric",
178178
"search_sorted",
179179
"slice",
180180
"take",
@@ -190,7 +190,7 @@ fn compute_funcs(encodings: &[ArrayData]) {
190190
impls.push(bool_to_cell(arr.encoding().fill_forward_fn().is_some()));
191191
impls.push(bool_to_cell(arr.encoding().filter_fn().is_some()));
192192
impls.push(bool_to_cell(arr.encoding().scalar_at_fn().is_some()));
193-
impls.push(bool_to_cell(arr.encoding().subtract_scalar_fn().is_some()));
193+
impls.push(bool_to_cell(arr.encoding().binary_numeric_fn().is_some()));
194194
impls.push(bool_to_cell(arr.encoding().search_sorted_fn().is_some()));
195195
impls.push(bool_to_cell(arr.encoding().slice_fn().is_some()));
196196
impls.push(bool_to_cell(arr.encoding().take_fn().is_some()));

encodings/dict/src/compute/mod.rs

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,20 @@ mod compare;
22
mod like;
33

44
use vortex_array::compute::{
5-
filter, scalar_at, slice, take, CompareFn, ComputeVTable, FilterFn, FilterMask, LikeFn,
6-
ScalarAtFn, SliceFn, TakeFn,
5+
binary_numeric, filter, scalar_at, slice, take, BinaryNumericFn, CompareFn, ComputeVTable,
6+
FilterFn, FilterMask, LikeFn, ScalarAtFn, SliceFn, TakeFn,
77
};
88
use vortex_array::{ArrayData, IntoArrayData};
99
use vortex_error::VortexResult;
10-
use vortex_scalar::Scalar;
10+
use vortex_scalar::{BinaryNumericOperator, Scalar};
1111

1212
use crate::{DictArray, DictEncoding};
1313

1414
impl ComputeVTable for DictEncoding {
15+
fn binary_numeric_fn(&self) -> Option<&dyn BinaryNumericFn<ArrayData>> {
16+
Some(self)
17+
}
18+
1519
fn compare_fn(&self) -> Option<&dyn CompareFn<ArrayData>> {
1620
Some(self)
1721
}
@@ -37,6 +41,23 @@ impl ComputeVTable for DictEncoding {
3741
}
3842
}
3943

44+
impl BinaryNumericFn<DictArray> for DictEncoding {
45+
fn binary_numeric(
46+
&self,
47+
array: &DictArray,
48+
rhs: &ArrayData,
49+
op: BinaryNumericOperator,
50+
) -> VortexResult<Option<ArrayData>> {
51+
if !rhs.is_constant() {
52+
return Ok(None);
53+
}
54+
55+
DictArray::try_new(array.codes(), binary_numeric(&array.values(), rhs, op)?)
56+
.map(IntoArrayData::into_array)
57+
.map(Some)
58+
}
59+
}
60+
4061
impl ScalarAtFn<DictArray> for DictEncoding {
4162
fn scalar_at(&self, array: &DictArray, index: usize) -> VortexResult<Scalar> {
4263
let dict_index: usize = scalar_at(array.codes(), index)?.as_ref().try_into()?;

encodings/runend/src/compute/mod.rs

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,18 +9,22 @@ use std::ops::AddAssign;
99
use num_traits::AsPrimitive;
1010
use vortex_array::array::{BooleanBuffer, PrimitiveArray};
1111
use vortex_array::compute::{
12-
filter, scalar_at, slice, CompareFn, ComputeVTable, FillNullFn, FilterFn, FilterMask, InvertFn,
13-
ScalarAtFn, SliceFn, TakeFn,
12+
binary_numeric, filter, scalar_at, slice, BinaryNumericFn, CompareFn, ComputeVTable,
13+
FillNullFn, FilterFn, FilterMask, InvertFn, ScalarAtFn, SliceFn, TakeFn,
1414
};
1515
use vortex_array::variants::PrimitiveArrayTrait;
1616
use vortex_array::{ArrayData, ArrayLen, IntoArrayData, IntoArrayVariant};
1717
use vortex_dtype::{match_each_unsigned_integer_ptype, NativePType};
1818
use vortex_error::{VortexResult, VortexUnwrap};
19-
use vortex_scalar::Scalar;
19+
use vortex_scalar::{BinaryNumericOperator, Scalar};
2020

2121
use crate::{RunEndArray, RunEndEncoding};
2222

2323
impl ComputeVTable for RunEndEncoding {
24+
fn binary_numeric_fn(&self) -> Option<&dyn BinaryNumericFn<ArrayData>> {
25+
Some(self)
26+
}
27+
2428
fn compare_fn(&self) -> Option<&dyn CompareFn<ArrayData>> {
2529
Some(self)
2630
}
@@ -50,6 +54,28 @@ impl ComputeVTable for RunEndEncoding {
5054
}
5155
}
5256

57+
impl BinaryNumericFn<RunEndArray> for RunEndEncoding {
58+
fn binary_numeric(
59+
&self,
60+
array: &RunEndArray,
61+
rhs: &ArrayData,
62+
op: BinaryNumericOperator,
63+
) -> VortexResult<Option<ArrayData>> {
64+
if !rhs.is_constant() {
65+
return Ok(None);
66+
}
67+
68+
RunEndArray::with_offset_and_length(
69+
array.ends(),
70+
binary_numeric(&array.values(), rhs, op)?,
71+
array.offset(),
72+
array.len(),
73+
)
74+
.map(IntoArrayData::into_array)
75+
.map(Some)
76+
}
77+
}
78+
5379
impl ScalarAtFn<RunEndArray> for RunEndEncoding {
5480
fn scalar_at(&self, array: &RunEndArray, index: usize) -> VortexResult<Scalar> {
5581
scalar_at(array.values(), array.find_physical_index(index)?)

vortex-array/benches/scalar_subtract.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,7 @@ fn scalar_subtract(c: &mut Criterion) {
2828

2929
group.bench_function("vortex", |b| {
3030
b.iter(|| {
31-
let array =
32-
vortex_array::compute::subtract_scalar(&chunked, &to_subtract.into()).unwrap();
31+
let array = vortex_array::compute::sub_scalar(&chunked, to_subtract.into()).unwrap();
3332

3433
let chunked = ChunkedArray::try_from(array).unwrap();
3534
black_box(chunked);

vortex-array/src/array/chunked/compute/mod.rs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@ use vortex_error::VortexResult;
44
use crate::array::chunked::ChunkedArray;
55
use crate::array::ChunkedEncoding;
66
use crate::compute::{
7-
try_cast, BinaryBooleanFn, CastFn, CompareFn, ComputeVTable, FillNullFn, FilterFn, InvertFn,
8-
ScalarAtFn, SliceFn, SubtractScalarFn, TakeFn,
7+
try_cast, BinaryBooleanFn, BinaryNumericFn, CastFn, CompareFn, ComputeVTable, FillNullFn,
8+
FilterFn, InvertFn, ScalarAtFn, SliceFn, TakeFn,
99
};
1010
use crate::{ArrayData, IntoArrayData};
1111

@@ -23,6 +23,10 @@ impl ComputeVTable for ChunkedEncoding {
2323
Some(self)
2424
}
2525

26+
fn binary_numeric_fn(&self) -> Option<&dyn BinaryNumericFn<ArrayData>> {
27+
Some(self)
28+
}
29+
2630
fn cast_fn(&self) -> Option<&dyn CastFn<ArrayData>> {
2731
Some(self)
2832
}
@@ -51,10 +55,6 @@ impl ComputeVTable for ChunkedEncoding {
5155
Some(self)
5256
}
5357

54-
fn subtract_scalar_fn(&self) -> Option<&dyn SubtractScalarFn<ArrayData>> {
55-
Some(self)
56-
}
57-
5858
fn take_fn(&self) -> Option<&dyn TakeFn<ArrayData>> {
5959
Some(self)
6060
}

vortex-array/src/array/chunked/compute/take.rs

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,7 @@ use vortex_scalar::Scalar;
66
use crate::array::chunked::ChunkedArray;
77
use crate::array::ChunkedEncoding;
88
use crate::compute::{
9-
scalar_at, search_sorted_usize, slice, subtract_scalar, take, try_cast, SearchSortedSide,
10-
TakeFn,
9+
scalar_at, search_sorted_usize, slice, sub_scalar, take, try_cast, SearchSortedSide, TakeFn,
1110
};
1211
use crate::stats::ArrayStatistics;
1312
use crate::{ArrayDType, ArrayData, ArrayLen, IntoArrayData, IntoArrayVariant, ToArrayData};
@@ -93,15 +92,15 @@ fn take_strict_sorted(chunked: &ChunkedArray, indices: &ArrayData) -> VortexResu
9392
.max_value_as_u64()
9493
.try_into()?
9594
{
96-
subtract_scalar(
95+
sub_scalar(
9796
&chunk_indices,
98-
&Scalar::from(chunk_begin).cast(chunk_indices.dtype())?,
97+
Scalar::from(chunk_begin).cast(chunk_indices.dtype())?,
9998
)?
10099
} else {
101100
// Note. this try_cast (memory copy) is unnecessary, could instead upcast in the subtract fn.
102101
// and avoid an extra
103102
let u64_chunk_indices = try_cast(&chunk_indices, PType::U64.into())?;
104-
subtract_scalar(&u64_chunk_indices, &chunk_begin.into())?
103+
sub_scalar(&u64_chunk_indices, chunk_begin.into())?
105104
};
106105

107106
indices_by_chunk[chunk_idx] = Some(chunk_indices);

vortex-array/src/array/chunked/mod.rs

Lines changed: 22 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,11 @@ use itertools::Itertools;
99
use serde::{Deserialize, Serialize};
1010
use vortex_dtype::{DType, Nullability, PType};
1111
use vortex_error::{vortex_bail, vortex_panic, VortexExpect as _, VortexResult, VortexUnwrap};
12-
use vortex_scalar::Scalar;
12+
use vortex_scalar::BinaryNumericOperator;
1313

1414
use crate::array::primitive::PrimitiveArray;
1515
use crate::compute::{
16-
scalar_at, search_sorted_usize, subtract_scalar, SearchSortedSide, SubtractScalarFn,
16+
binary_numeric, scalar_at, search_sorted_usize, slice, BinaryNumericFn, SearchSortedSide,
1717
};
1818
use crate::encoding::ids;
1919
use crate::iter::{ArrayIterator, ArrayIteratorAdapter};
@@ -234,17 +234,25 @@ impl ValidityVTable<ChunkedArray> for ChunkedEncoding {
234234
}
235235
}
236236

237-
impl SubtractScalarFn<ChunkedArray> for ChunkedEncoding {
238-
fn subtract_scalar(
237+
impl BinaryNumericFn<ChunkedArray> for ChunkedEncoding {
238+
fn binary_numeric(
239239
&self,
240240
array: &ChunkedArray,
241-
to_subtract: &Scalar,
242-
) -> VortexResult<ArrayData> {
243-
let chunks = array
244-
.chunks()
245-
.map(|chunk| subtract_scalar(&chunk, to_subtract))
246-
.collect::<VortexResult<Vec<_>>>()?;
247-
Ok(ChunkedArray::try_new(chunks, array.dtype().clone())?.into_array())
241+
rhs: &ArrayData,
242+
op: BinaryNumericOperator,
243+
) -> VortexResult<Option<ArrayData>> {
244+
let mut start = 0;
245+
246+
let mut new_chunks = Vec::with_capacity(array.nchunks());
247+
for chunk in array.chunks() {
248+
let end = start + chunk.len();
249+
new_chunks.push(binary_numeric(&chunk, &slice(rhs, start, end)?, op)?);
250+
start = end;
251+
}
252+
253+
ChunkedArray::try_new(new_chunks, array.dtype().clone())
254+
.map(IntoArrayData::into_array)
255+
.map(Some)
248256
}
249257
}
250258

@@ -254,7 +262,7 @@ mod test {
254262
use vortex_error::VortexResult;
255263

256264
use crate::array::chunked::ChunkedArray;
257-
use crate::compute::{scalar_at, subtract_scalar};
265+
use crate::compute::{scalar_at, sub_scalar};
258266
use crate::{assert_arrays_eq, ArrayDType, IntoArrayData, IntoArrayVariant};
259267

260268
fn chunked_array() -> ChunkedArray {
@@ -271,9 +279,9 @@ mod test {
271279

272280
#[test]
273281
fn test_scalar_subtract() {
274-
let chunked = chunked_array();
282+
let chunked = chunked_array().into_array();
275283
let to_subtract = 1u64;
276-
let array = subtract_scalar(&chunked, &to_subtract.into()).unwrap();
284+
let array = sub_scalar(&chunked, to_subtract.into()).unwrap();
277285

278286
let chunked = ChunkedArray::try_from(array).unwrap();
279287
let mut chunks_out = chunked.chunks();
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
use vortex_error::{vortex_err, VortexResult};
2+
use vortex_scalar::BinaryNumericOperator;
3+
4+
use crate::array::{ConstantArray, ConstantEncoding};
5+
use crate::compute::BinaryNumericFn;
6+
use crate::{ArrayData, ArrayLen as _, IntoArrayData as _};
7+
8+
impl BinaryNumericFn<ConstantArray> for ConstantEncoding {
9+
fn binary_numeric(
10+
&self,
11+
array: &ConstantArray,
12+
rhs: &ArrayData,
13+
op: BinaryNumericOperator,
14+
) -> VortexResult<Option<ArrayData>> {
15+
let Some(rhs) = rhs.as_constant() else {
16+
return Ok(None);
17+
};
18+
19+
Ok(Some(
20+
ConstantArray::new(
21+
array
22+
.scalar()
23+
.as_primitive()
24+
.checked_numeric_operator(rhs.as_primitive(), op)?
25+
.ok_or_else(|| vortex_err!("numeric overflow"))?,
26+
array.len(),
27+
)
28+
.into_array(),
29+
))
30+
}
31+
}

vortex-array/src/array/constant/compute/mod.rs

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
mod binary_numeric;
12
mod boolean;
23
mod compare;
34
mod invert;
@@ -9,8 +10,8 @@ use vortex_scalar::Scalar;
910
use crate::array::constant::ConstantArray;
1011
use crate::array::ConstantEncoding;
1112
use crate::compute::{
12-
BinaryBooleanFn, CompareFn, ComputeVTable, FilterFn, FilterMask, InvertFn, ScalarAtFn,
13-
SearchSortedFn, SliceFn, TakeFn,
13+
BinaryBooleanFn, BinaryNumericFn, CompareFn, ComputeVTable, FilterFn, FilterMask, InvertFn,
14+
ScalarAtFn, SearchSortedFn, SliceFn, TakeFn,
1415
};
1516
use crate::{ArrayData, IntoArrayData};
1617

@@ -19,6 +20,10 @@ impl ComputeVTable for ConstantEncoding {
1920
Some(self)
2021
}
2122

23+
fn binary_numeric_fn(&self) -> Option<&dyn BinaryNumericFn<ArrayData>> {
24+
Some(self)
25+
}
26+
2227
fn compare_fn(&self) -> Option<&dyn CompareFn<ArrayData>> {
2328
Some(self)
2429
}

vortex-array/src/array/null/compute.rs

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
use vortex_dtype::{match_each_integer_ptype, DType};
22
use vortex_error::{vortex_bail, VortexResult};
3-
use vortex_scalar::Scalar;
3+
use vortex_scalar::{BinaryNumericOperator, Scalar};
44

55
use crate::array::null::NullArray;
66
use crate::array::NullEncoding;
7-
use crate::compute::{ComputeVTable, ScalarAtFn, SliceFn, TakeFn};
7+
use crate::compute::{BinaryNumericFn, ComputeVTable, ScalarAtFn, SliceFn, TakeFn};
88
use crate::variants::PrimitiveArrayTrait;
99
use crate::{ArrayData, ArrayLen, IntoArrayData, IntoArrayVariant};
1010

@@ -13,6 +13,10 @@ impl ComputeVTable for NullEncoding {
1313
Some(self)
1414
}
1515

16+
fn binary_numeric_fn(&self) -> Option<&dyn BinaryNumericFn<ArrayData>> {
17+
Some(self)
18+
}
19+
1620
fn slice_fn(&self) -> Option<&dyn SliceFn<ArrayData>> {
1721
Some(self)
1822
}
@@ -22,6 +26,18 @@ impl ComputeVTable for NullEncoding {
2226
}
2327
}
2428

29+
impl BinaryNumericFn<NullArray> for NullEncoding {
30+
fn binary_numeric(
31+
&self,
32+
array: &NullArray,
33+
_rhs: &ArrayData,
34+
_op: BinaryNumericOperator,
35+
) -> VortexResult<Option<ArrayData>> {
36+
// for any arithmetic operation, forall X. NULL op X = NULL
37+
Ok(Some(NullArray::new(array.len()).into_array()))
38+
}
39+
}
40+
2541
impl SliceFn<NullArray> for NullEncoding {
2642
fn slice(&self, _array: &NullArray, start: usize, stop: usize) -> VortexResult<ArrayData> {
2743
Ok(NullArray::new(stop - start).into_array())

0 commit comments

Comments
 (0)