44use std:: hash:: { Hash , Hasher } ;
55use std:: sync:: LazyLock ;
66
7- use enum_map:: { Enum , EnumMap , enum_map } ;
7+ use enum_map:: { enum_map , Enum , EnumMap } ;
88use vortex_buffer:: ByteBuffer ;
9- use vortex_compute:: logical:: {
10- LogicalAnd , LogicalAndKleene , LogicalAndNot , LogicalOr , LogicalOrKleene ,
11- } ;
12- use vortex_dtype:: DType ;
13- use vortex_error:: VortexResult ;
14- use vortex_vector:: BoolVector ;
9+ use vortex_compute:: arithmetic:: { Add , Checked , CheckedOperator , Div , Mul , Sub } ;
10+ use vortex_dtype:: { match_each_native_ptype, DType , NativePType , PTypeDowncastExt } ;
11+ use vortex_error:: { vortex_err, VortexExpect , VortexResult } ;
12+ use vortex_scalar:: { PValue , Scalar } ;
1513
16- use crate :: execution:: { BatchKernelRef , BindCtx , kernel} ;
14+ use crate :: arrays:: ConstantArray ;
15+ use crate :: execution:: { kernel, BatchKernelRef , BindCtx } ;
1716use crate :: serde:: ArrayChildren ;
1817use crate :: stats:: { ArrayStats , StatsSetRef } ;
1918use crate :: vtable:: {
2019 ArrayVTable , NotSupported , OperatorVTable , SerdeVTable , VTable , VisitorVTable ,
2120} ;
2221use crate :: {
23- Array , ArrayBufferVisitor , ArrayChildVisitor , ArrayEq , ArrayHash , ArrayRef ,
24- DeserializeMetadata , EmptyMetadata , EncodingId , EncodingRef , Precision , vtable ,
22+ vtable , Array , ArrayBufferVisitor , ArrayChildVisitor , ArrayEq , ArrayHash ,
23+ ArrayRef , DeserializeMetadata , EmptyMetadata , EncodingId , EncodingRef , IntoArray , Precision ,
2524} ;
2625
27- /// The set of operators supported by a logical array.
26+ /// The set of operators supported by an arithmetic array.
2827#[ derive( Debug , Clone , Copy , PartialEq , Eq , Hash , Enum ) ]
29- pub enum LogicalOperator {
30- /// Logical AND
31- And ,
32- /// Logical AND with Kleene logic
33- AndKleene ,
34- /// Logical OR
35- Or ,
36- /// Logical OR with Kleene logic
37- OrKleene ,
38- /// Logical AND NOT
39- AndNot ,
28+ pub enum ArithmeticOperator {
29+ /// Addition
30+ Add ,
31+ /// Subtraction
32+ Sub ,
33+ /// Multiplication
34+ Mul ,
35+ /// Division
36+ Div ,
4037}
4138
42- vtable ! ( Logical ) ;
39+ vtable ! ( Arithmetic ) ;
4340
4441#[ derive( Debug , Clone ) ]
45- pub struct LogicalArray {
42+ pub struct ArithmeticArray {
4643 encoding : EncodingRef ,
4744 lhs : ArrayRef ,
4845 rhs : ArrayRef ,
4946 stats : ArrayStats ,
5047}
5148
52- impl LogicalArray {
49+ impl ArithmeticArray {
5350 /// Create a new logical array.
54- pub fn new ( lhs : ArrayRef , rhs : ArrayRef , operator : LogicalOperator ) -> Self {
51+ pub fn new ( lhs : ArrayRef , rhs : ArrayRef , operator : ArithmeticOperator ) -> Self {
5552 assert_eq ! (
5653 lhs. len( ) ,
5754 rhs. len( ) ,
58- "Logical arrays require lhs and rhs to have the same length"
55+ "Arithmetic arrays require lhs and rhs to have the same length"
5956 ) ;
6057
6158 // TODO(ngates): should we automatically cast non-null to nullable if required?
62- assert ! ( matches!( lhs. dtype( ) , DType :: Bool ( _ ) ) ) ;
59+ assert ! ( matches!( lhs. dtype( ) , DType :: Primitive ( .. ) ) ) ;
6360 assert_eq ! ( lhs. dtype( ) , rhs. dtype( ) ) ;
6461
6562 Self {
@@ -71,29 +68,29 @@ impl LogicalArray {
7168 }
7269
7370 /// Returns the operator of this logical array.
74- pub fn operator ( & self ) -> LogicalOperator {
75- self . encoding . as_ :: < LogicalVTable > ( ) . operator
71+ pub fn operator ( & self ) -> ArithmeticOperator {
72+ self . encoding . as_ :: < ArithmeticVTable > ( ) . operator
7673 }
7774}
7875
7976#[ derive( Debug , Clone ) ]
80- pub struct LogicalEncoding {
77+ pub struct ArithmeticEncoding {
8178 // We include the operator in the encoding so each operator is a different encoding ID.
8279 // This makes it easier for plugins to construct expressions and perform pushdown
8380 // optimizations.
84- operator : LogicalOperator ,
81+ operator : ArithmeticOperator ,
8582}
8683
8784#[ allow( clippy:: mem_forget) ]
88- static ENCODINGS : LazyLock < EnumMap < LogicalOperator , EncodingRef > > = LazyLock :: new ( || {
85+ static ENCODINGS : LazyLock < EnumMap < ArithmeticOperator , EncodingRef > > = LazyLock :: new ( || {
8986 enum_map ! {
90- operator => LogicalEncoding { operator } . to_encoding( ) ,
87+ operator => ArithmeticEncoding { operator } . to_encoding( ) ,
9188 }
9289} ) ;
9390
94- impl VTable for LogicalVTable {
95- type Array = LogicalArray ;
96- type Encoding = LogicalEncoding ;
91+ impl VTable for ArithmeticVTable {
92+ type Array = ArithmeticArray ;
93+ type Encoding = ArithmeticEncoding ;
9794 type ArrayVTable = Self ;
9895 type CanonicalVTable = NotSupported ;
9996 type OperationsVTable = NotSupported ;
@@ -106,11 +103,10 @@ impl VTable for LogicalVTable {
106103
107104 fn id ( encoding : & Self :: Encoding ) -> EncodingId {
108105 match encoding. operator {
109- LogicalOperator :: And => EncodingId :: from ( "vortex.and" ) ,
110- LogicalOperator :: AndKleene => EncodingId :: from ( "vortex.and_kleene" ) ,
111- LogicalOperator :: Or => EncodingId :: from ( "vortex.or" ) ,
112- LogicalOperator :: OrKleene => EncodingId :: from ( "vortex.or_kleene" ) ,
113- LogicalOperator :: AndNot => EncodingId :: from ( "vortex.and_not" ) ,
106+ ArithmeticOperator :: Add => EncodingId :: from ( "vortex.add" ) ,
107+ ArithmeticOperator :: Sub => EncodingId :: from ( "vortex.sub" ) ,
108+ ArithmeticOperator :: Mul => EncodingId :: from ( "vortex.mul" ) ,
109+ ArithmeticOperator :: Div => EncodingId :: from ( "vortex.div" ) ,
114110 }
115111 }
116112
@@ -119,104 +115,187 @@ impl VTable for LogicalVTable {
119115 }
120116}
121117
122- impl ArrayVTable < LogicalVTable > for LogicalVTable {
123- fn len ( array : & LogicalArray ) -> usize {
118+ impl ArrayVTable < ArithmeticVTable > for ArithmeticVTable {
119+ fn len ( array : & ArithmeticArray ) -> usize {
124120 array. lhs . len ( )
125121 }
126122
127- fn dtype ( array : & LogicalArray ) -> & DType {
123+ fn dtype ( array : & ArithmeticArray ) -> & DType {
128124 array. lhs . dtype ( )
129125 }
130126
131- fn stats ( array : & LogicalArray ) -> StatsSetRef < ' _ > {
127+ fn stats ( array : & ArithmeticArray ) -> StatsSetRef < ' _ > {
132128 array. stats . to_ref ( array. as_ref ( ) )
133129 }
134130
135- fn array_hash < H : Hasher > ( array : & LogicalArray , state : & mut H , precision : Precision ) {
131+ fn array_hash < H : Hasher > ( array : & ArithmeticArray , state : & mut H , precision : Precision ) {
136132 array. lhs . array_hash ( state, precision) ;
137133 array. rhs . array_hash ( state, precision) ;
138134 }
139135
140- fn array_eq ( array : & LogicalArray , other : & LogicalArray , precision : Precision ) -> bool {
136+ fn array_eq ( array : & ArithmeticArray , other : & ArithmeticArray , precision : Precision ) -> bool {
141137 array. lhs . array_eq ( & other. lhs , precision) && array. rhs . array_eq ( & other. rhs , precision)
142138 }
143139}
144140
145- impl VisitorVTable < LogicalVTable > for LogicalVTable {
146- fn visit_buffers ( _array : & LogicalArray , _visitor : & mut dyn ArrayBufferVisitor ) {
141+ impl VisitorVTable < ArithmeticVTable > for ArithmeticVTable {
142+ fn visit_buffers ( _array : & ArithmeticArray , _visitor : & mut dyn ArrayBufferVisitor ) {
147143 // No buffers
148144 }
149145
150- fn visit_children ( array : & LogicalArray , visitor : & mut dyn ArrayChildVisitor ) {
146+ fn visit_children ( array : & ArithmeticArray , visitor : & mut dyn ArrayChildVisitor ) {
151147 visitor. visit_child ( "lhs" , array. lhs . as_ref ( ) ) ;
152148 visitor. visit_child ( "rhs" , array. rhs . as_ref ( ) ) ;
153149 }
154150}
155151
156- impl SerdeVTable < LogicalVTable > for LogicalVTable {
152+ impl SerdeVTable < ArithmeticVTable > for ArithmeticVTable {
157153 type Metadata = EmptyMetadata ;
158154
159- fn metadata ( _array : & LogicalArray ) -> VortexResult < Option < Self :: Metadata > > {
155+ fn metadata ( _array : & ArithmeticArray ) -> VortexResult < Option < Self :: Metadata > > {
160156 Ok ( Some ( EmptyMetadata ) )
161157 }
162158
163159 fn build (
164- encoding : & LogicalEncoding ,
160+ encoding : & ArithmeticEncoding ,
165161 dtype : & DType ,
166162 len : usize ,
167163 _metadata : & <Self :: Metadata as DeserializeMetadata >:: Output ,
168164 buffers : & [ ByteBuffer ] ,
169165 children : & dyn ArrayChildren ,
170- ) -> VortexResult < LogicalArray > {
166+ ) -> VortexResult < ArithmeticArray > {
171167 assert ! ( buffers. is_empty( ) ) ;
172- Ok ( LogicalArray :: new (
168+
169+ Ok ( ArithmeticArray :: new (
173170 children. get ( 0 , dtype, len) ?,
174171 children. get ( 1 , dtype, len) ?,
175172 encoding. operator ,
176173 ) )
177174 }
178175}
179176
180- impl OperatorVTable < LogicalVTable > for LogicalVTable {
177+ impl OperatorVTable < ArithmeticVTable > for ArithmeticVTable {
178+ fn reduce_children ( array : & ArithmeticArray ) -> VortexResult < Option < ArrayRef > > {
179+ match ( array. lhs . as_constant ( ) , array. rhs . as_constant ( ) ) {
180+ // If both sides are constant, we compute the value now.
181+ ( Some ( lhs) , Some ( rhs) ) => {
182+ let op: vortex_scalar:: NumericOperator = match array. operator ( ) {
183+ ArithmeticOperator :: Add => vortex_scalar:: NumericOperator :: Add ,
184+ ArithmeticOperator :: Sub => vortex_scalar:: NumericOperator :: Sub ,
185+ ArithmeticOperator :: Mul => vortex_scalar:: NumericOperator :: Mul ,
186+ ArithmeticOperator :: Div => vortex_scalar:: NumericOperator :: Div ,
187+ } ;
188+ let result = lhs
189+ . as_primitive ( )
190+ . checked_binary_numeric ( & rhs. as_primitive ( ) , op)
191+ . ok_or_else ( || {
192+ vortex_err ! ( "Constant arithmetic operation resulted in overflow" )
193+ } ) ?;
194+ return Ok ( Some (
195+ ConstantArray :: new ( Scalar :: from ( result) , array. len ( ) ) . into_array ( ) ,
196+ ) ) ;
197+ }
198+ // If either side is constant null, the result is constant null.
199+ ( Some ( lhs) , _) if lhs. is_null ( ) => {
200+ return Ok ( Some (
201+ ConstantArray :: new ( Scalar :: null ( array. dtype ( ) . clone ( ) ) , array. len ( ) )
202+ . into_array ( ) ,
203+ ) ) ;
204+ }
205+ ( _, Some ( rhs) ) if rhs. is_null ( ) => {
206+ return Ok ( Some (
207+ ConstantArray :: new ( Scalar :: null ( array. dtype ( ) . clone ( ) ) , array. len ( ) )
208+ . into_array ( ) ,
209+ ) ) ;
210+ }
211+ _ => { }
212+ }
213+
214+ Ok ( None )
215+ }
216+
181217 fn bind (
182- array : & LogicalArray ,
218+ array : & ArithmeticArray ,
183219 selection : Option < & ArrayRef > ,
184220 ctx : & mut dyn BindCtx ,
185221 ) -> VortexResult < BatchKernelRef > {
222+ // Optimize for constant RHS
223+ if let Some ( rhs) = array. rhs . as_constant ( ) {
224+ if rhs. is_null ( ) {
225+ // If the RHS is null, the result is always null.
226+ return Ok (
227+ ConstantArray :: new ( Scalar :: null ( array. dtype ( ) . clone ( ) ) , array. len ( ) )
228+ . into_array ( )
229+ . bind ( selection, ctx) ?,
230+ ) ;
231+ }
232+
233+ let lhs = ctx. bind ( & array. lhs , selection) ?;
234+ return match_each_native_ptype ! ( array. dtype( ) . as_ptype( ) , |T | {
235+ let rhs_value: T = rhs
236+ . as_primitive( )
237+ . typed_value:: <T >( )
238+ . vortex_expect( "Already checked for null above" ) ;
239+ Ok ( match array. operator( ) {
240+ ArithmeticOperator :: Add => arithmetic_scalar_kernel:: <Add , _>( lhs, rhs_value) ,
241+ ArithmeticOperator :: Sub => arithmetic_scalar_kernel:: <Sub , _>( lhs, rhs_value) ,
242+ ArithmeticOperator :: Mul => arithmetic_scalar_kernel:: <Mul , _>( lhs, rhs_value) ,
243+ ArithmeticOperator :: Div => arithmetic_scalar_kernel:: <Div , _>( lhs, rhs_value) ,
244+ } )
245+ } ) ;
246+ }
247+
186248 let lhs = ctx. bind ( & array. lhs , selection) ?;
187249 let rhs = ctx. bind ( & array. rhs , selection) ?;
188250
189- Ok ( match array. operator ( ) {
190- LogicalOperator :: And => logical_kernel ( lhs, rhs, |l, r| l. and ( & r) ) ,
191- LogicalOperator :: AndKleene => logical_kernel ( lhs, rhs, |l, r| l. and_kleene ( & r) ) ,
192- LogicalOperator :: Or => logical_kernel ( lhs, rhs, |l, r| l. or ( & r) ) ,
193- LogicalOperator :: OrKleene => logical_kernel ( lhs, rhs, |l, r| l. or_kleene ( & r) ) ,
194- LogicalOperator :: AndNot => logical_kernel ( lhs, rhs, |l, r| l. and_not ( & r) ) ,
251+ match_each_native_ptype ! ( array. dtype( ) . as_ptype( ) , |T | {
252+ Ok ( match array. operator( ) {
253+ ArithmeticOperator :: Add => arithmetic_kernel:: <Add , T >( lhs, rhs) ,
254+ ArithmeticOperator :: Sub => arithmetic_kernel:: <Sub , T >( lhs, rhs) ,
255+ ArithmeticOperator :: Mul => arithmetic_kernel:: <Mul , T >( lhs, rhs) ,
256+ ArithmeticOperator :: Div => arithmetic_kernel:: <Div , T >( lhs, rhs) ,
257+ } )
195258 } )
196259 }
197260}
198261
199262/// Batch execution kernel for logical operations.
200- fn logical_kernel < O > ( lhs : BatchKernelRef , rhs : BatchKernelRef , op : O ) -> BatchKernelRef
263+ fn arithmetic_kernel < Op , T > ( lhs : BatchKernelRef , rhs : BatchKernelRef ) -> BatchKernelRef
264+ where
265+ T : NativePType ,
266+ Op : CheckedOperator < T > ,
267+ {
268+ kernel ( move || {
269+ let lhs = lhs. execute ( ) ?. into_primitive ( ) . downcast :: < T > ( ) ;
270+ let rhs = rhs. execute ( ) ?. into_primitive ( ) . downcast :: < T > ( ) ;
271+ let result = Checked :: < Op , _ > :: checked_op ( lhs, & rhs)
272+ . ok_or_else ( || vortex_err ! ( "Arithmetic operation resulted in overflow" ) ) ?;
273+ Ok ( result. into ( ) )
274+ } )
275+ }
276+
277+ fn arithmetic_scalar_kernel < Op , T > ( lhs : BatchKernelRef , rhs : T ) -> BatchKernelRef
201278where
202- O : Fn ( BoolVector , BoolVector ) -> BoolVector + Send + ' static ,
279+ T : NativePType + TryFrom < PValue > ,
280+ Op : CheckedOperator < T > ,
203281{
204282 kernel ( move || {
205- let lhs = lhs. execute ( ) ?. into_bool ( ) ;
206- let rhs = rhs. execute ( ) ?. into_bool ( ) ;
207- Ok ( op ( lhs, rhs) . into ( ) )
283+ let lhs = lhs. execute ( ) ?. into_primitive ( ) . downcast :: < T > ( ) ;
284+ let result = Checked :: < Op , _ > :: checked_op ( lhs, & rhs)
285+ . ok_or_else ( || vortex_err ! ( "Arithmetic operation resulted in overflow" ) ) ?;
286+ Ok ( result. into ( ) )
208287 } )
209288}
210289
211290#[ cfg( test) ]
212291mod tests {
213292 use vortex_buffer:: bitbuffer;
214293
215- use crate :: compute:: arrays:: logical:: { LogicalArray , LogicalOperator } ;
294+ use crate :: compute:: arrays:: logical:: ArithmeticOperator ;
216295 use crate :: { ArrayOperator , ArrayRef , IntoArray } ;
217296
218297 fn and_ ( lhs : ArrayRef , rhs : ArrayRef ) -> ArrayRef {
219- LogicalArray :: new ( lhs, rhs, LogicalOperator :: And ) . into_array ( )
298+ ArithmeticArray :: new ( lhs, rhs, ArithmeticOperator :: And ) . into_array ( )
220299 }
221300
222301 #[ test]
0 commit comments