Skip to content

Commit 87bfc99

Browse files
committed
Add arithmetic buffer compute
Signed-off-by: Nicholas Gates <[email protected]>
1 parent 80492ef commit 87bfc99

File tree

14 files changed

+557
-369
lines changed

14 files changed

+557
-369
lines changed

Cargo.lock

Lines changed: 2 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

vortex-array/src/compute/arrays/arithmetic.rs

Lines changed: 150 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -4,62 +4,59 @@
44
use std::hash::{Hash, Hasher};
55
use std::sync::LazyLock;
66

7-
use enum_map::{Enum, EnumMap, enum_map};
7+
use enum_map::{enum_map, Enum, EnumMap};
88
use vortex_buffer::ByteBuffer;
9-
use vortex_compute::logical::{
10-
LogicalAnd, LogicalAndKleene, LogicalAndNot, LogicalOr, LogicalOrKleene,
11-
};
12-
use vortex_dtype::DType;
13-
use vortex_error::VortexResult;
14-
use vortex_vector::BoolVector;
9+
use vortex_compute::arithmetic::{Add, Checked, CheckedOperator, Div, Mul, Sub};
10+
use vortex_dtype::{match_each_native_ptype, DType, NativePType, PTypeDowncastExt};
11+
use vortex_error::{vortex_err, VortexExpect, VortexResult};
12+
use vortex_scalar::{PValue, Scalar};
1513

16-
use crate::execution::{BatchKernelRef, BindCtx, kernel};
14+
use crate::arrays::ConstantArray;
15+
use crate::execution::{kernel, BatchKernelRef, BindCtx};
1716
use crate::serde::ArrayChildren;
1817
use crate::stats::{ArrayStats, StatsSetRef};
1918
use crate::vtable::{
2019
ArrayVTable, NotSupported, OperatorVTable, SerdeVTable, VTable, VisitorVTable,
2120
};
2221
use crate::{
23-
Array, ArrayBufferVisitor, ArrayChildVisitor, ArrayEq, ArrayHash, ArrayRef,
24-
DeserializeMetadata, EmptyMetadata, EncodingId, EncodingRef, Precision, vtable,
22+
vtable, Array, ArrayBufferVisitor, ArrayChildVisitor, ArrayEq, ArrayHash,
23+
ArrayRef, DeserializeMetadata, EmptyMetadata, EncodingId, EncodingRef, IntoArray, Precision,
2524
};
2625

27-
/// The set of operators supported by a logical array.
26+
/// The set of operators supported by an arithmetic array.
2827
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Enum)]
29-
pub enum LogicalOperator {
30-
/// Logical AND
31-
And,
32-
/// Logical AND with Kleene logic
33-
AndKleene,
34-
/// Logical OR
35-
Or,
36-
/// Logical OR with Kleene logic
37-
OrKleene,
38-
/// Logical AND NOT
39-
AndNot,
28+
pub enum ArithmeticOperator {
29+
/// Addition
30+
Add,
31+
/// Subtraction
32+
Sub,
33+
/// Multiplication
34+
Mul,
35+
/// Division
36+
Div,
4037
}
4138

42-
vtable!(Logical);
39+
vtable!(Arithmetic);
4340

4441
#[derive(Debug, Clone)]
45-
pub struct LogicalArray {
42+
pub struct ArithmeticArray {
4643
encoding: EncodingRef,
4744
lhs: ArrayRef,
4845
rhs: ArrayRef,
4946
stats: ArrayStats,
5047
}
5148

52-
impl LogicalArray {
49+
impl ArithmeticArray {
5350
/// Create a new logical array.
54-
pub fn new(lhs: ArrayRef, rhs: ArrayRef, operator: LogicalOperator) -> Self {
51+
pub fn new(lhs: ArrayRef, rhs: ArrayRef, operator: ArithmeticOperator) -> Self {
5552
assert_eq!(
5653
lhs.len(),
5754
rhs.len(),
58-
"Logical arrays require lhs and rhs to have the same length"
55+
"Arithmetic arrays require lhs and rhs to have the same length"
5956
);
6057

6158
// TODO(ngates): should we automatically cast non-null to nullable if required?
62-
assert!(matches!(lhs.dtype(), DType::Bool(_)));
59+
assert!(matches!(lhs.dtype(), DType::Primitive(..)));
6360
assert_eq!(lhs.dtype(), rhs.dtype());
6461

6562
Self {
@@ -71,29 +68,29 @@ impl LogicalArray {
7168
}
7269

7370
/// Returns the operator of this logical array.
74-
pub fn operator(&self) -> LogicalOperator {
75-
self.encoding.as_::<LogicalVTable>().operator
71+
pub fn operator(&self) -> ArithmeticOperator {
72+
self.encoding.as_::<ArithmeticVTable>().operator
7673
}
7774
}
7875

7976
#[derive(Debug, Clone)]
80-
pub struct LogicalEncoding {
77+
pub struct ArithmeticEncoding {
8178
// We include the operator in the encoding so each operator is a different encoding ID.
8279
// This makes it easier for plugins to construct expressions and perform pushdown
8380
// optimizations.
84-
operator: LogicalOperator,
81+
operator: ArithmeticOperator,
8582
}
8683

8784
#[allow(clippy::mem_forget)]
88-
static ENCODINGS: LazyLock<EnumMap<LogicalOperator, EncodingRef>> = LazyLock::new(|| {
85+
static ENCODINGS: LazyLock<EnumMap<ArithmeticOperator, EncodingRef>> = LazyLock::new(|| {
8986
enum_map! {
90-
operator => LogicalEncoding { operator }.to_encoding(),
87+
operator => ArithmeticEncoding { operator }.to_encoding(),
9188
}
9289
});
9390

94-
impl VTable for LogicalVTable {
95-
type Array = LogicalArray;
96-
type Encoding = LogicalEncoding;
91+
impl VTable for ArithmeticVTable {
92+
type Array = ArithmeticArray;
93+
type Encoding = ArithmeticEncoding;
9794
type ArrayVTable = Self;
9895
type CanonicalVTable = NotSupported;
9996
type OperationsVTable = NotSupported;
@@ -106,11 +103,10 @@ impl VTable for LogicalVTable {
106103

107104
fn id(encoding: &Self::Encoding) -> EncodingId {
108105
match encoding.operator {
109-
LogicalOperator::And => EncodingId::from("vortex.and"),
110-
LogicalOperator::AndKleene => EncodingId::from("vortex.and_kleene"),
111-
LogicalOperator::Or => EncodingId::from("vortex.or"),
112-
LogicalOperator::OrKleene => EncodingId::from("vortex.or_kleene"),
113-
LogicalOperator::AndNot => EncodingId::from("vortex.and_not"),
106+
ArithmeticOperator::Add => EncodingId::from("vortex.add"),
107+
ArithmeticOperator::Sub => EncodingId::from("vortex.sub"),
108+
ArithmeticOperator::Mul => EncodingId::from("vortex.mul"),
109+
ArithmeticOperator::Div => EncodingId::from("vortex.div"),
114110
}
115111
}
116112

@@ -119,104 +115,187 @@ impl VTable for LogicalVTable {
119115
}
120116
}
121117

122-
impl ArrayVTable<LogicalVTable> for LogicalVTable {
123-
fn len(array: &LogicalArray) -> usize {
118+
impl ArrayVTable<ArithmeticVTable> for ArithmeticVTable {
119+
fn len(array: &ArithmeticArray) -> usize {
124120
array.lhs.len()
125121
}
126122

127-
fn dtype(array: &LogicalArray) -> &DType {
123+
fn dtype(array: &ArithmeticArray) -> &DType {
128124
array.lhs.dtype()
129125
}
130126

131-
fn stats(array: &LogicalArray) -> StatsSetRef<'_> {
127+
fn stats(array: &ArithmeticArray) -> StatsSetRef<'_> {
132128
array.stats.to_ref(array.as_ref())
133129
}
134130

135-
fn array_hash<H: Hasher>(array: &LogicalArray, state: &mut H, precision: Precision) {
131+
fn array_hash<H: Hasher>(array: &ArithmeticArray, state: &mut H, precision: Precision) {
136132
array.lhs.array_hash(state, precision);
137133
array.rhs.array_hash(state, precision);
138134
}
139135

140-
fn array_eq(array: &LogicalArray, other: &LogicalArray, precision: Precision) -> bool {
136+
fn array_eq(array: &ArithmeticArray, other: &ArithmeticArray, precision: Precision) -> bool {
141137
array.lhs.array_eq(&other.lhs, precision) && array.rhs.array_eq(&other.rhs, precision)
142138
}
143139
}
144140

145-
impl VisitorVTable<LogicalVTable> for LogicalVTable {
146-
fn visit_buffers(_array: &LogicalArray, _visitor: &mut dyn ArrayBufferVisitor) {
141+
impl VisitorVTable<ArithmeticVTable> for ArithmeticVTable {
142+
fn visit_buffers(_array: &ArithmeticArray, _visitor: &mut dyn ArrayBufferVisitor) {
147143
// No buffers
148144
}
149145

150-
fn visit_children(array: &LogicalArray, visitor: &mut dyn ArrayChildVisitor) {
146+
fn visit_children(array: &ArithmeticArray, visitor: &mut dyn ArrayChildVisitor) {
151147
visitor.visit_child("lhs", array.lhs.as_ref());
152148
visitor.visit_child("rhs", array.rhs.as_ref());
153149
}
154150
}
155151

156-
impl SerdeVTable<LogicalVTable> for LogicalVTable {
152+
impl SerdeVTable<ArithmeticVTable> for ArithmeticVTable {
157153
type Metadata = EmptyMetadata;
158154

159-
fn metadata(_array: &LogicalArray) -> VortexResult<Option<Self::Metadata>> {
155+
fn metadata(_array: &ArithmeticArray) -> VortexResult<Option<Self::Metadata>> {
160156
Ok(Some(EmptyMetadata))
161157
}
162158

163159
fn build(
164-
encoding: &LogicalEncoding,
160+
encoding: &ArithmeticEncoding,
165161
dtype: &DType,
166162
len: usize,
167163
_metadata: &<Self::Metadata as DeserializeMetadata>::Output,
168164
buffers: &[ByteBuffer],
169165
children: &dyn ArrayChildren,
170-
) -> VortexResult<LogicalArray> {
166+
) -> VortexResult<ArithmeticArray> {
171167
assert!(buffers.is_empty());
172-
Ok(LogicalArray::new(
168+
169+
Ok(ArithmeticArray::new(
173170
children.get(0, dtype, len)?,
174171
children.get(1, dtype, len)?,
175172
encoding.operator,
176173
))
177174
}
178175
}
179176

180-
impl OperatorVTable<LogicalVTable> for LogicalVTable {
177+
impl OperatorVTable<ArithmeticVTable> for ArithmeticVTable {
178+
fn reduce_children(array: &ArithmeticArray) -> VortexResult<Option<ArrayRef>> {
179+
match (array.lhs.as_constant(), array.rhs.as_constant()) {
180+
// If both sides are constant, we compute the value now.
181+
(Some(lhs), Some(rhs)) => {
182+
let op: vortex_scalar::NumericOperator = match array.operator() {
183+
ArithmeticOperator::Add => vortex_scalar::NumericOperator::Add,
184+
ArithmeticOperator::Sub => vortex_scalar::NumericOperator::Sub,
185+
ArithmeticOperator::Mul => vortex_scalar::NumericOperator::Mul,
186+
ArithmeticOperator::Div => vortex_scalar::NumericOperator::Div,
187+
};
188+
let result = lhs
189+
.as_primitive()
190+
.checked_binary_numeric(&rhs.as_primitive(), op)
191+
.ok_or_else(|| {
192+
vortex_err!("Constant arithmetic operation resulted in overflow")
193+
})?;
194+
return Ok(Some(
195+
ConstantArray::new(Scalar::from(result), array.len()).into_array(),
196+
));
197+
}
198+
// If either side is constant null, the result is constant null.
199+
(Some(lhs), _) if lhs.is_null() => {
200+
return Ok(Some(
201+
ConstantArray::new(Scalar::null(array.dtype().clone()), array.len())
202+
.into_array(),
203+
));
204+
}
205+
(_, Some(rhs)) if rhs.is_null() => {
206+
return Ok(Some(
207+
ConstantArray::new(Scalar::null(array.dtype().clone()), array.len())
208+
.into_array(),
209+
));
210+
}
211+
_ => {}
212+
}
213+
214+
Ok(None)
215+
}
216+
181217
fn bind(
182-
array: &LogicalArray,
218+
array: &ArithmeticArray,
183219
selection: Option<&ArrayRef>,
184220
ctx: &mut dyn BindCtx,
185221
) -> VortexResult<BatchKernelRef> {
222+
// Optimize for constant RHS
223+
if let Some(rhs) = array.rhs.as_constant() {
224+
if rhs.is_null() {
225+
// If the RHS is null, the result is always null.
226+
return Ok(
227+
ConstantArray::new(Scalar::null(array.dtype().clone()), array.len())
228+
.into_array()
229+
.bind(selection, ctx)?,
230+
);
231+
}
232+
233+
let lhs = ctx.bind(&array.lhs, selection)?;
234+
return match_each_native_ptype!(array.dtype().as_ptype(), |T| {
235+
let rhs_value: T = rhs
236+
.as_primitive()
237+
.typed_value::<T>()
238+
.vortex_expect("Already checked for null above");
239+
Ok(match array.operator() {
240+
ArithmeticOperator::Add => arithmetic_scalar_kernel::<Add, _>(lhs, rhs_value),
241+
ArithmeticOperator::Sub => arithmetic_scalar_kernel::<Sub, _>(lhs, rhs_value),
242+
ArithmeticOperator::Mul => arithmetic_scalar_kernel::<Mul, _>(lhs, rhs_value),
243+
ArithmeticOperator::Div => arithmetic_scalar_kernel::<Div, _>(lhs, rhs_value),
244+
})
245+
});
246+
}
247+
186248
let lhs = ctx.bind(&array.lhs, selection)?;
187249
let rhs = ctx.bind(&array.rhs, selection)?;
188250

189-
Ok(match array.operator() {
190-
LogicalOperator::And => logical_kernel(lhs, rhs, |l, r| l.and(&r)),
191-
LogicalOperator::AndKleene => logical_kernel(lhs, rhs, |l, r| l.and_kleene(&r)),
192-
LogicalOperator::Or => logical_kernel(lhs, rhs, |l, r| l.or(&r)),
193-
LogicalOperator::OrKleene => logical_kernel(lhs, rhs, |l, r| l.or_kleene(&r)),
194-
LogicalOperator::AndNot => logical_kernel(lhs, rhs, |l, r| l.and_not(&r)),
251+
match_each_native_ptype!(array.dtype().as_ptype(), |T| {
252+
Ok(match array.operator() {
253+
ArithmeticOperator::Add => arithmetic_kernel::<Add, T>(lhs, rhs),
254+
ArithmeticOperator::Sub => arithmetic_kernel::<Sub, T>(lhs, rhs),
255+
ArithmeticOperator::Mul => arithmetic_kernel::<Mul, T>(lhs, rhs),
256+
ArithmeticOperator::Div => arithmetic_kernel::<Div, T>(lhs, rhs),
257+
})
195258
})
196259
}
197260
}
198261

199262
/// Batch execution kernel for logical operations.
200-
fn logical_kernel<O>(lhs: BatchKernelRef, rhs: BatchKernelRef, op: O) -> BatchKernelRef
263+
fn arithmetic_kernel<Op, T>(lhs: BatchKernelRef, rhs: BatchKernelRef) -> BatchKernelRef
264+
where
265+
T: NativePType,
266+
Op: CheckedOperator<T>,
267+
{
268+
kernel(move || {
269+
let lhs = lhs.execute()?.into_primitive().downcast::<T>();
270+
let rhs = rhs.execute()?.into_primitive().downcast::<T>();
271+
let result = Checked::<Op, _>::checked_op(lhs, &rhs)
272+
.ok_or_else(|| vortex_err!("Arithmetic operation resulted in overflow"))?;
273+
Ok(result.into())
274+
})
275+
}
276+
277+
fn arithmetic_scalar_kernel<Op, T>(lhs: BatchKernelRef, rhs: T) -> BatchKernelRef
201278
where
202-
O: Fn(BoolVector, BoolVector) -> BoolVector + Send + 'static,
279+
T: NativePType + TryFrom<PValue>,
280+
Op: CheckedOperator<T>,
203281
{
204282
kernel(move || {
205-
let lhs = lhs.execute()?.into_bool();
206-
let rhs = rhs.execute()?.into_bool();
207-
Ok(op(lhs, rhs).into())
283+
let lhs = lhs.execute()?.into_primitive().downcast::<T>();
284+
let result = Checked::<Op, _>::checked_op(lhs, &rhs)
285+
.ok_or_else(|| vortex_err!("Arithmetic operation resulted in overflow"))?;
286+
Ok(result.into())
208287
})
209288
}
210289

211290
#[cfg(test)]
212291
mod tests {
213292
use vortex_buffer::bitbuffer;
214293

215-
use crate::compute::arrays::logical::{LogicalArray, LogicalOperator};
294+
use crate::compute::arrays::logical::ArithmeticOperator;
216295
use crate::{ArrayOperator, ArrayRef, IntoArray};
217296

218297
fn and_(lhs: ArrayRef, rhs: ArrayRef) -> ArrayRef {
219-
LogicalArray::new(lhs, rhs, LogicalOperator::And).into_array()
298+
ArithmeticArray::new(lhs, rhs, ArithmeticOperator::And).into_array()
220299
}
221300

222301
#[test]
Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
// SPDX-License-Identifier: Apache-2.0
22
// SPDX-FileCopyrightText: Copyright the Vortex contributors
33

4-
mod logical;
4+
pub mod arithmetic;
5+
pub mod logical;

vortex-array/src/compute/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ use crate::{Array, ArrayRef};
4343

4444
#[cfg(feature = "arbitrary")]
4545
mod arbitrary;
46-
mod arrays;
46+
pub mod arrays;
4747
mod between;
4848
mod boolean;
4949
mod cast;

vortex-array/src/execution/batch.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ impl<F: FnOnce() -> VortexResult<Vector> + Send + 'static> BatchKernel for Batch
2323
}
2424

2525
/// Create a batch execution kernel from the given closure.
26+
#[inline(always)]
2627
pub fn kernel<F: FnOnce() -> VortexResult<Vector> + Send + 'static>(f: F) -> BatchKernelRef {
2728
Box::new(BatchKernelAdapter(f))
2829
}

0 commit comments

Comments
 (0)