Skip to content

Commit cd30116

Browse files
committed
feat[scalar_fn]: binary ops
Signed-off-by: Joe Isaacs <[email protected]>
1 parent 80e01eb commit cd30116

File tree

13 files changed

+1328
-0
lines changed

13 files changed

+1328
-0
lines changed

vortex-array/src/expr/exprs/binary.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,15 @@ use crate::compute::sub;
2222
use crate::expr::ChildName;
2323
use crate::expr::ExprId;
2424
use crate::expr::ExpressionView;
25+
use crate::expr::ScalarFnExprExt;
2526
use crate::expr::StatsCatalog;
2627
use crate::expr::VTable;
2728
use crate::expr::VTableExt;
2829
use crate::expr::expression::Expression;
2930
use crate::expr::exprs::literal::lit;
3031
use crate::expr::exprs::operators::Operator;
3132
use crate::expr::stats::Stat;
33+
use crate::scalar_fns::binary;
3234

3335
pub struct Binary;
3436

@@ -274,6 +276,10 @@ impl VTable for Binary {
274276

275277
!infallible
276278
}
279+
280+
fn expr_v2(&self, view: &ExpressionView<Self>) -> VortexResult<Expression> {
281+
ScalarFnExprExt::try_new_expr(&binary::BinaryFn, view.operator(), view.children().clone())
282+
}
277283
}
278284

279285
impl ExpressionView<'_, Binary> {
Lines changed: 317 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,317 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3+
4+
use prost::Message;
5+
use vortex_compute::arithmetic_op;
6+
use vortex_compute::checked_arithmetic_op;
7+
use vortex_compute::compare_op;
8+
use vortex_compute::logical::LogicalAndKleene;
9+
use vortex_compute::logical::LogicalOrKleene;
10+
use vortex_dtype::DType;
11+
use vortex_error::VortexResult;
12+
use vortex_error::vortex_bail;
13+
use vortex_error::vortex_err;
14+
use vortex_proto::expr as pb;
15+
use vortex_vector::Datum;
16+
use vortex_vector::ScalarOps;
17+
use vortex_vector::Vector;
18+
use vortex_vector::VectorMutOps;
19+
use vortex_vector::VectorOps;
20+
use vortex_vector::bool::BoolScalar;
21+
use vortex_vector::primitive::PrimitiveScalar;
22+
use vortex_vector::primitive::PrimitiveVector;
23+
24+
use crate::expr::ChildName;
25+
use crate::expr::Operator;
26+
use crate::expr::functions::ArgName;
27+
use crate::expr::functions::Arity;
28+
use crate::expr::functions::ExecutionArgs;
29+
use crate::expr::functions::FunctionId;
30+
use crate::expr::functions::NullHandling;
31+
use crate::expr::functions::VTable;
32+
33+
pub struct BinaryFn;
34+
impl VTable for BinaryFn {
35+
type Options = Operator;
36+
37+
fn id(&self) -> FunctionId {
38+
FunctionId::from("vortex.binary")
39+
}
40+
41+
fn serialize(&self, op: &Operator) -> VortexResult<Option<Vec<u8>>> {
42+
Ok(Some(pb::BinaryOpts { op: (*op).into() }.encode_to_vec()))
43+
}
44+
45+
fn deserialize(&self, bytes: &[u8]) -> VortexResult<Operator> {
46+
let opts = pb::BinaryOpts::decode(bytes)?;
47+
Operator::try_from(opts.op)
48+
}
49+
50+
fn arity(&self, _options: &Operator) -> Arity {
51+
Arity::Fixed(2)
52+
}
53+
54+
fn null_handling(&self, options: &Operator) -> NullHandling {
55+
match options {
56+
// Kleene logic for AND/OR - they have special null semantics
57+
Operator::And | Operator::Or => NullHandling::AbsorbsNull,
58+
// All other operators propagate nulls
59+
_ => NullHandling::Propagate,
60+
}
61+
}
62+
63+
fn arg_name(&self, _options: &Operator, arg_idx: usize) -> ArgName {
64+
match arg_idx {
65+
0 => ChildName::from("lhs"),
66+
1 => ChildName::from("rhs"),
67+
_ => unreachable!("Binary has only two arguments"),
68+
}
69+
}
70+
71+
fn return_dtype(&self, options: &Operator, arg_types: &[DType]) -> VortexResult<DType> {
72+
let lhs = &arg_types[0];
73+
let rhs = &arg_types[1];
74+
75+
if options.is_arithmetic() {
76+
if lhs.is_primitive() && lhs.eq_ignore_nullability(rhs) {
77+
return Ok(lhs.with_nullability(lhs.nullability() | rhs.nullability()));
78+
}
79+
vortex_bail!(
80+
"incompatible types for arithmetic operation: {} {}",
81+
lhs,
82+
rhs
83+
);
84+
}
85+
86+
Ok(DType::Bool((lhs.is_nullable() || rhs.is_nullable()).into()))
87+
}
88+
89+
fn execute(&self, op: &Operator, args: &ExecutionArgs) -> VortexResult<Datum> {
90+
let lhs = args.input_datums(0);
91+
let rhs = args.input_datums(1);
92+
93+
match (lhs, rhs) {
94+
(Datum::Vector(lhs_vec), Datum::Vector(rhs_vec)) => {
95+
execute_vector_vector(lhs_vec, rhs_vec, *op)
96+
}
97+
(Datum::Scalar(lhs_sc), Datum::Scalar(rhs_sc)) => {
98+
execute_scalar_scalar(lhs_sc, rhs_sc, *op)
99+
}
100+
// TODO: remove repeat
101+
(Datum::Scalar(lhs_sc), Datum::Vector(rhs_vec)) => {
102+
execute_vector_vector(&lhs_sc.repeat(rhs_vec.len()).freeze(), rhs_vec, *op)
103+
}
104+
(Datum::Vector(lhs_vec), Datum::Scalar(rhs_sc)) => {
105+
execute_vector_vector(lhs_vec, &rhs_sc.repeat(lhs_vec.len()).freeze(), *op)
106+
}
107+
}
108+
}
109+
}
110+
111+
fn execute_vector_vector(lhs: &Vector, rhs: &Vector, op: Operator) -> VortexResult<Datum> {
112+
match (lhs, rhs, op) {
113+
// Logical operations (AND/OR) - only for Bool vectors
114+
(Vector::Bool(l), Vector::Bool(r), Operator::And) => {
115+
Ok(Datum::Vector(l.and_kleene(r).into()))
116+
}
117+
(Vector::Bool(l), Vector::Bool(r), Operator::Or) => {
118+
Ok(Datum::Vector(l.or_kleene(r).into()))
119+
}
120+
121+
// Comparison operations - Bool vectors
122+
(Vector::Bool(l), Vector::Bool(r), op) if op.maybe_cmp_operator().is_some() => {
123+
let result = compare_op!(
124+
op,
125+
l,
126+
r,
127+
Operator::Eq,
128+
Operator::NotEq,
129+
Operator::Lt,
130+
Operator::Lte,
131+
Operator::Gt,
132+
Operator::Gte
133+
);
134+
Ok(Datum::Vector(result.into()))
135+
}
136+
137+
// Comparison operations - Primitive vectors
138+
(Vector::Primitive(l), Vector::Primitive(r), op) if op.maybe_cmp_operator().is_some() => {
139+
let result = compare_op!(
140+
op,
141+
l,
142+
r,
143+
Operator::Eq,
144+
Operator::NotEq,
145+
Operator::Lt,
146+
Operator::Lte,
147+
Operator::Gt,
148+
Operator::Gte
149+
);
150+
Ok(Datum::Vector(result.into()))
151+
}
152+
153+
// Arithmetic operations - Primitive vectors
154+
(Vector::Primitive(l), Vector::Primitive(r), op) if op.is_arithmetic() => {
155+
execute_arithmetic_primitive(l, r, op)
156+
}
157+
158+
_ => vortex_bail!(
159+
"Binary operation {:?} not supported for vector types {:?} and {:?}",
160+
op,
161+
lhs,
162+
rhs
163+
),
164+
}
165+
}
166+
167+
fn execute_arithmetic_primitive(
168+
lhs: &PrimitiveVector,
169+
rhs: &PrimitiveVector,
170+
op: Operator,
171+
) -> VortexResult<Datum> {
172+
// Float arithmetic - no overflow checking needed
173+
if lhs.ptype().is_float() && lhs.ptype() == rhs.ptype() {
174+
let result = arithmetic_op!(
175+
op,
176+
lhs,
177+
rhs,
178+
Operator::Add,
179+
Operator::Sub,
180+
Operator::Mul,
181+
Operator::Div
182+
);
183+
return Ok(Datum::Vector(result.into()));
184+
}
185+
186+
// Integer arithmetic - use checked operations
187+
let result: Option<PrimitiveVector> = checked_arithmetic_op!(
188+
op,
189+
lhs,
190+
rhs,
191+
Operator::Add,
192+
Operator::Sub,
193+
Operator::Mul,
194+
Operator::Div
195+
);
196+
197+
match result {
198+
Some(v) => Ok(Datum::Vector(v.into())),
199+
None => Err(vortex_err!(
200+
"Arithmetic overflow/underflow or type mismatch"
201+
)),
202+
}
203+
}
204+
205+
fn execute_scalar_scalar(
206+
lhs: &vortex_vector::Scalar,
207+
rhs: &vortex_vector::Scalar,
208+
op: Operator,
209+
) -> VortexResult<Datum> {
210+
use vortex_vector::Scalar;
211+
use vortex_vector::ScalarOps;
212+
213+
// Handle null propagation for non-kleene operators
214+
if !matches!(op, Operator::And | Operator::Or) && (!lhs.is_valid() || !rhs.is_valid()) {
215+
return match op {
216+
Operator::Add | Operator::Sub | Operator::Mul | Operator::Div => {
217+
// Return null primitive - we'd need to know the type
218+
// For now, bail
219+
vortex_bail!("Null scalar arithmetic not yet supported")
220+
}
221+
_ => Ok(Datum::Scalar(BoolScalar::new(None).into())),
222+
};
223+
}
224+
225+
match (lhs, rhs) {
226+
(Scalar::Bool(l), Scalar::Bool(r)) => {
227+
let result: BoolScalar = match op {
228+
Operator::And => l.and_kleene(r),
229+
Operator::Or => l.or_kleene(r),
230+
op if op.maybe_cmp_operator().is_some() => compare_op!(
231+
op,
232+
l,
233+
r,
234+
Operator::Eq,
235+
Operator::NotEq,
236+
Operator::Lt,
237+
Operator::Lte,
238+
Operator::Gt,
239+
Operator::Gte
240+
),
241+
_ => vortex_bail!("Arithmetic not supported for bool scalars"),
242+
};
243+
Ok(Datum::Scalar(result.into()))
244+
}
245+
(Scalar::Primitive(l), Scalar::Primitive(r)) => execute_scalar_scalar_primitive(l, r, op),
246+
_ => vortex_bail!(
247+
"Binary operation not supported for scalar types {:?} and {:?}",
248+
lhs,
249+
rhs
250+
),
251+
}
252+
}
253+
254+
fn execute_scalar_scalar_primitive(
255+
lhs: &PrimitiveScalar,
256+
rhs: &PrimitiveScalar,
257+
op: Operator,
258+
) -> VortexResult<Datum> {
259+
if op.maybe_cmp_operator().is_some() {
260+
let result = compare_op!(
261+
op,
262+
lhs,
263+
rhs,
264+
Operator::Eq,
265+
Operator::NotEq,
266+
Operator::Lt,
267+
Operator::Lte,
268+
Operator::Gt,
269+
Operator::Gte
270+
);
271+
return Ok(Datum::Scalar(result.into()));
272+
}
273+
274+
if op.is_arithmetic() {
275+
return execute_scalar_arithmetic_primitive(lhs, rhs, op);
276+
}
277+
278+
vortex_bail!("Operation {:?} not supported for primitive scalars", op)
279+
}
280+
281+
fn execute_scalar_arithmetic_primitive(
282+
lhs: &PrimitiveScalar,
283+
rhs: &PrimitiveScalar,
284+
op: Operator,
285+
) -> VortexResult<Datum> {
286+
// Float arithmetic - no overflow checking needed
287+
if lhs.ptype().is_float() && lhs.ptype() == rhs.ptype() {
288+
let result = arithmetic_op!(
289+
op,
290+
lhs,
291+
rhs,
292+
Operator::Add,
293+
Operator::Sub,
294+
Operator::Mul,
295+
Operator::Div
296+
);
297+
return Ok(Datum::Scalar(result.into()));
298+
}
299+
300+
// Integer arithmetic - use checked operations
301+
let result: Option<PrimitiveScalar> = checked_arithmetic_op!(
302+
op,
303+
lhs,
304+
rhs,
305+
Operator::Add,
306+
Operator::Sub,
307+
Operator::Mul,
308+
Operator::Div
309+
);
310+
311+
match result {
312+
Some(v) => Ok(Datum::Scalar(v.into())),
313+
None => Err(vortex_err!(
314+
"Arithmetic overflow/underflow or type mismatch"
315+
)),
316+
}
317+
}

vortex-array/src/scalar_fns/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ use crate::arrays::ScalarFnArrayExt;
1818
use crate::expr::Expression;
1919
use crate::expr::ScalarFnExprExt;
2020

21+
pub mod binary;
2122
pub mod cast;
2223
pub mod is_null;
2324
pub mod not;

0 commit comments

Comments
 (0)