Skip to content

Commit 07ca06c

Browse files
committed
u
Signed-off-by: Joe Isaacs <[email protected]>
1 parent 18dcc75 commit 07ca06c

File tree

1 file changed

+31
-39
lines changed

1 file changed

+31
-39
lines changed

vortex-array/src/expr/exprs/binary.rs

Lines changed: 31 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,19 @@
33

44
use std::fmt::Formatter;
55

6-
use arrow_ord::cmp;
76
use prost::Message;
8-
use vortex_compute::arrow::IntoArrow;
9-
use vortex_compute::arrow::IntoVector;
7+
use vortex_compute::arithmetic::Add as AddOp;
8+
use vortex_compute::arithmetic::Arithmetic;
9+
use vortex_compute::arithmetic::Div as DivOp;
10+
use vortex_compute::arithmetic::Mul as MulOp;
11+
use vortex_compute::arithmetic::Sub as SubOp;
12+
use vortex_compute::comparison::Compare;
13+
use vortex_compute::comparison::Equal;
14+
use vortex_compute::comparison::GreaterThan;
15+
use vortex_compute::comparison::GreaterThanOrEqual;
16+
use vortex_compute::comparison::LessThan;
17+
use vortex_compute::comparison::LessThanOrEqual;
18+
use vortex_compute::comparison::NotEqual;
1019
use vortex_compute::logical::LogicalAndKleene;
1120
use vortex_compute::logical::LogicalOrKleene;
1221
use vortex_dtype::DType;
@@ -16,7 +25,6 @@ use vortex_error::vortex_bail;
1625
use vortex_error::vortex_err;
1726
use vortex_proto::expr as pb;
1827
use vortex_vector::Datum;
19-
use vortex_vector::VectorOps;
2028

2129
use crate::ArrayRef;
2230
use crate::compute;
@@ -136,54 +144,38 @@ impl VTable for Binary {
136144
.try_into()
137145
.map_err(|_| vortex_err!("Wrong arg count"))?;
138146

139-
match op {
147+
// Use native Vortex compute operations directly on Datums
148+
let result: Datum = match op {
149+
// Comparison operations
150+
Operator::Eq => Compare::<Equal>::compare(lhs, rhs).into(),
151+
Operator::NotEq => Compare::<NotEqual>::compare(lhs, rhs).into(),
152+
Operator::Lt => Compare::<LessThan>::compare(lhs, rhs).into(),
153+
Operator::Lte => Compare::<LessThanOrEqual>::compare(lhs, rhs).into(),
154+
Operator::Gt => Compare::<GreaterThan>::compare(lhs, rhs).into(),
155+
Operator::Gte => Compare::<GreaterThanOrEqual>::compare(lhs, rhs).into(),
156+
157+
// Logical operations
140158
Operator::And => {
141-
return Ok(LogicalAndKleene::and_kleene(&lhs.into_bool(), &rhs.into_bool()).into());
159+
LogicalAndKleene::and_kleene(&lhs.into_bool(), &rhs.into_bool()).into()
142160
}
143-
Operator::Or => {
144-
return Ok(LogicalOrKleene::or_kleene(&lhs.into_bool(), &rhs.into_bool()).into());
145-
}
146-
_ => {}
147-
}
148-
149-
let lhs = lhs.into_arrow()?;
150-
let rhs = rhs.into_arrow()?;
151-
152-
let vector = match op {
153-
Operator::Eq => cmp::eq(lhs.as_ref(), rhs.as_ref())?.into_vector()?.into(),
154-
Operator::NotEq => cmp::neq(lhs.as_ref(), rhs.as_ref())?.into_vector()?.into(),
155-
Operator::Gt => cmp::gt(lhs.as_ref(), rhs.as_ref())?.into_vector()?.into(),
156-
Operator::Gte => cmp::gt_eq(lhs.as_ref(), rhs.as_ref())?
157-
.into_vector()?
158-
.into(),
159-
Operator::Lt => cmp::lt(lhs.as_ref(), rhs.as_ref())?.into_vector()?.into(),
160-
Operator::Lte => cmp::lt_eq(lhs.as_ref(), rhs.as_ref())?
161-
.into_vector()?
162-
.into(),
161+
Operator::Or => LogicalOrKleene::or_kleene(&lhs.into_bool(), &rhs.into_bool()).into(),
163162

163+
// Arithmetic operations
164164
Operator::Add => {
165-
arrow_arith::numeric::add(lhs.as_ref(), rhs.as_ref())?.into_vector()?
165+
Arithmetic::<AddOp, _>::eval(lhs.into_primitive(), rhs.into_primitive()).into()
166166
}
167167
Operator::Sub => {
168-
arrow_arith::numeric::sub(lhs.as_ref(), rhs.as_ref())?.into_vector()?
168+
Arithmetic::<SubOp, _>::eval(lhs.into_primitive(), rhs.into_primitive()).into()
169169
}
170170
Operator::Mul => {
171-
arrow_arith::numeric::mul(lhs.as_ref(), rhs.as_ref())?.into_vector()?
171+
Arithmetic::<MulOp, _>::eval(lhs.into_primitive(), rhs.into_primitive()).into()
172172
}
173173
Operator::Div => {
174-
arrow_arith::numeric::div(lhs.as_ref(), rhs.as_ref())?.into_vector()?
175-
}
176-
Operator::And | Operator::Or => {
177-
unreachable!("Already dealt with above")
174+
Arithmetic::<DivOp, _>::eval(lhs.into_primitive(), rhs.into_primitive()).into()
178175
}
179176
};
180177

181-
// Arrow computed over scalar datums
182-
if vector.len() == 1 && args.row_count != 1 {
183-
return Ok(Datum::Scalar(vector.scalar_at(0)));
184-
}
185-
186-
Ok(Datum::Vector(vector))
178+
Ok(result)
187179
}
188180

189181
fn stat_falsification(

0 commit comments

Comments
 (0)