Skip to content

Commit b690a52

Browse files
committed
wip
Signed-off-by: Joe Isaacs <[email protected]>
1 parent 5c854ee commit b690a52

File tree

9 files changed

+97
-201
lines changed

9 files changed

+97
-201
lines changed

vortex-array/src/scalar_fns/binary/mod.rs

Lines changed: 51 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,19 @@
22
// SPDX-FileCopyrightText: Copyright the Vortex contributors
33

44
use prost::Message;
5-
use vortex_compute::arithmetic_op;
6-
use vortex_compute::checked_arithmetic_op;
7-
use vortex_compute::compare_op;
5+
use vortex_compute::arithmetic::Add;
6+
use vortex_compute::arithmetic::Arithmetic;
7+
use vortex_compute::arithmetic::CheckedArithmetic;
8+
use vortex_compute::arithmetic::Div;
9+
use vortex_compute::arithmetic::Mul;
10+
use vortex_compute::arithmetic::Sub;
11+
use vortex_compute::comparison::Compare;
12+
use vortex_compute::comparison::Equal;
13+
use vortex_compute::comparison::GreaterThan;
14+
use vortex_compute::comparison::GreaterThanOrEqual;
15+
use vortex_compute::comparison::LessThan;
16+
use vortex_compute::comparison::LessThanOrEqual;
17+
use vortex_compute::comparison::NotEqual;
818
use vortex_compute::logical::KleeneAnd;
919
use vortex_compute::logical::KleeneOr;
1020
use vortex_compute::logical::LogicalOp;
@@ -17,7 +27,6 @@ use vortex_vector::BoolDatum;
1727
use vortex_vector::Datum;
1828
use vortex_vector::PrimitiveDatum;
1929

20-
use crate::compute;
2130
use crate::expr::ChildName;
2231
use crate::expr::Operator;
2332
use crate::expr::functions::ArgName;
@@ -45,7 +54,7 @@ impl VTable for BinaryFn {
4554
}
4655

4756
fn arity(&self, _options: &Operator) -> Arity {
48-
Arity::Fixed(2)
57+
Arity::Exact(2)
4958
}
5059

5160
fn null_handling(&self, options: &Operator) -> NullHandling {
@@ -85,62 +94,56 @@ impl VTable for BinaryFn {
8594
let lhs: Datum = args.input_datums(0).clone();
8695
let rhs: Datum = args.input_datums(1).clone();
8796

88-
if op.is_arithmetic() {
89-
execute_arithmetic_primitive(&lhs.into_primitive(), &rhs.into_primitive(), *op)
90-
} else if let Some(comp) = op.maybe_cmp_operator() {
91-
let result = compare_op!(
92-
comp,
93-
lhs,
94-
rhs,
95-
compute::Operator::Eq,
96-
compute::Operator::NotEq,
97-
compute::Operator::Lt,
98-
compute::Operator::Lte,
99-
compute::Operator::Gt,
100-
compute::Operator::Gte
101-
);
102-
Ok(result.into())
103-
} else if matches!(op, Operator::And) {
104-
Ok(<BoolDatum as LogicalOp<KleeneAnd>>::op(lhs.into_bool(), rhs.into_bool()).into())
105-
} else if matches!(op, Operator::Or) {
106-
Ok(<BoolDatum as LogicalOp<KleeneOr>>::op(lhs.into_bool(), rhs.into_bool()).into())
107-
} else {
108-
unreachable!("unknown operator type")
97+
match op {
98+
Operator::Eq => Ok(Compare::<Equal>::compare(lhs, rhs).into()),
99+
Operator::NotEq => Ok(Compare::<NotEqual>::compare(lhs, rhs).into()),
100+
Operator::Lt => Ok(Compare::<LessThan>::compare(lhs, rhs).into()),
101+
Operator::Lte => Ok(Compare::<LessThanOrEqual>::compare(lhs, rhs).into()),
102+
Operator::Gt => Ok(Compare::<GreaterThan>::compare(lhs, rhs).into()),
103+
Operator::Gte => Ok(Compare::<GreaterThanOrEqual>::compare(lhs, rhs).into()),
104+
Operator::And => Ok(<BoolDatum as LogicalOp<KleeneAnd>>::op(
105+
lhs.into_bool(),
106+
rhs.into_bool(),
107+
)
108+
.into()),
109+
Operator::Or => {
110+
Ok(<BoolDatum as LogicalOp<KleeneOr>>::op(lhs.into_bool(), rhs.into_bool()).into())
111+
}
112+
Operator::Add | Operator::Sub | Operator::Mul | Operator::Div => {
113+
execute_arithmetic_primitive(lhs.into_primitive(), rhs.into_primitive(), *op)
114+
}
109115
}
110116
}
111117
}
112118

113119
fn execute_arithmetic_primitive(
114-
lhs: &PrimitiveDatum,
115-
rhs: &PrimitiveDatum,
120+
lhs: PrimitiveDatum,
121+
rhs: PrimitiveDatum,
116122
op: Operator,
117123
) -> VortexResult<Datum> {
118124
// Float arithmetic - no overflow checking needed
119125
if lhs.ptype().is_float() && lhs.ptype() == rhs.ptype() {
120-
let result = arithmetic_op!(
121-
op,
122-
lhs,
123-
rhs,
124-
Operator::Add,
125-
Operator::Sub,
126-
Operator::Mul,
127-
Operator::Div
128-
);
126+
let result: PrimitiveDatum = match op {
127+
Operator::Add => Arithmetic::<Add>::eval(lhs, rhs),
128+
Operator::Sub => Arithmetic::<Sub>::eval(lhs, rhs),
129+
Operator::Mul => Arithmetic::<Mul>::eval(lhs, rhs),
130+
Operator::Div => Arithmetic::<Div>::eval(lhs, rhs),
131+
_ => unreachable!("Not an arithmetic operator"),
132+
};
129133
return Ok(result.into());
130134
}
131135

132136
// Integer arithmetic - use checked operations
133-
checked_arithmetic_op!(
134-
op,
135-
lhs,
136-
rhs,
137-
Operator::Add,
138-
Operator::Sub,
139-
Operator::Mul,
140-
Operator::Div
141-
)
142-
.map(|d| d.into())
143-
.ok_or_else(|| vortex_err!("Arithmetic overflow/underflow or type mismatch"))
137+
let result: Option<PrimitiveDatum> = match op {
138+
Operator::Add => CheckedArithmetic::<Add>::checked_eval(lhs, rhs),
139+
Operator::Sub => CheckedArithmetic::<Sub>::checked_eval(lhs, rhs),
140+
Operator::Mul => CheckedArithmetic::<Mul>::checked_eval(lhs, rhs),
141+
Operator::Div => CheckedArithmetic::<Div>::checked_eval(lhs, rhs),
142+
_ => unreachable!("Not an arithmetic operator"),
143+
};
144+
result
145+
.map(|d| d.into())
146+
.ok_or_else(|| vortex_err!("Arithmetic overflow/underflow or type mismatch"))
144147
}
145148

146149
#[cfg(test)]

vortex-compute/src/arithmetic/datum.rs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,30 +13,30 @@ use crate::arithmetic::CheckedArithmetic;
1313

1414
impl<Op> CheckedArithmetic<Op> for PrimitiveDatum
1515
where
16-
PrimitiveScalar: CheckedArithmetic<Op, Output = PrimitiveScalar>,
17-
PrimitiveVector: CheckedArithmetic<Op, Output = PrimitiveVector>,
16+
for<'a> &'a PrimitiveScalar: CheckedArithmetic<Op, Output = PrimitiveScalar>,
17+
for<'a> PrimitiveVector: CheckedArithmetic<Op, &'a PrimitiveVector, Output = PrimitiveVector>,
1818
{
1919
type Output = PrimitiveDatum;
2020

2121
fn checked_eval(self, rhs: PrimitiveDatum) -> Option<Self::Output> {
2222
match (self, rhs) {
2323
(PrimitiveDatum::Scalar(sc1), PrimitiveDatum::Scalar(sc2)) => {
24-
sc1.checked_eval(sc2).map(PrimitiveDatum::Scalar)
24+
(&sc1).checked_eval(&sc2).map(PrimitiveDatum::Scalar)
2525
}
2626
(PrimitiveDatum::Vector(vec1), PrimitiveDatum::Vector(vec2)) => {
27-
vec1.checked_eval(vec2).map(PrimitiveDatum::Vector)
27+
vec1.checked_eval(&vec2).map(PrimitiveDatum::Vector)
2828
}
2929
(PrimitiveDatum::Vector(vec1), PrimitiveDatum::Scalar(sc2)) => {
3030
let len = vec1.len();
31-
vec1.checked_eval(sc2.repeat(len).freeze().into_primitive())
31+
vec1.checked_eval(&sc2.repeat(len).freeze().into_primitive())
3232
.map(PrimitiveDatum::Vector)
3333
}
3434
(PrimitiveDatum::Scalar(sc1), PrimitiveDatum::Vector(vec2)) => {
3535
let len = vec2.len();
3636
sc1.repeat(len)
3737
.freeze()
3838
.into_primitive()
39-
.checked_eval(vec2)
39+
.checked_eval(&vec2)
4040
.map(PrimitiveDatum::Vector)
4141
}
4242
}

vortex-compute/src/arithmetic/mod.rs

Lines changed: 0 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -167,53 +167,3 @@ macro_rules! impl_float {
167167
impl_float!(f16);
168168
impl_float!(f32);
169169
impl_float!(f64);
170-
171-
/// Dispatches a checked arithmetic operation based on a runtime operator value.
172-
///
173-
/// This macro allows you to call `CheckedArithmetic::<Op, _>::checked_eval(lhs, rhs)` where `Op`
174-
/// is determined at runtime from an operator enum variant.
175-
#[macro_export]
176-
macro_rules! checked_arithmetic_op {
177-
($op:expr, $lhs:expr, $rhs:expr, $Add:pat, $Sub:pat, $Mul:pat, $Div:pat) => {
178-
match $op {
179-
$Add => {
180-
$crate::arithmetic::CheckedArithmetic::<$crate::arithmetic::Add, _>::checked_eval(
181-
$lhs, $rhs,
182-
)
183-
}
184-
$Sub => {
185-
$crate::arithmetic::CheckedArithmetic::<$crate::arithmetic::Sub, _>::checked_eval(
186-
$lhs, $rhs,
187-
)
188-
}
189-
$Mul => {
190-
$crate::arithmetic::CheckedArithmetic::<$crate::arithmetic::Mul, _>::checked_eval(
191-
$lhs, $rhs,
192-
)
193-
}
194-
$Div => {
195-
$crate::arithmetic::CheckedArithmetic::<$crate::arithmetic::Div, _>::checked_eval(
196-
$lhs, $rhs,
197-
)
198-
}
199-
_ => unreachable!("Not an arithmetic operator"),
200-
}
201-
};
202-
}
203-
204-
/// Dispatches an arithmetic operation based on a runtime operator value.
205-
///
206-
/// This macro allows you to call `Arithmetic::<Op, _>::eval(lhs, rhs)` where `Op`
207-
/// is determined at runtime from an operator enum variant.
208-
#[macro_export]
209-
macro_rules! arithmetic_op {
210-
($op:expr, $lhs:expr, $rhs:expr, $Add:pat, $Sub:pat, $Mul:pat, $Div:pat) => {
211-
match $op {
212-
$Add => $crate::arithmetic::Arithmetic::<$crate::arithmetic::Add, _>::eval($lhs, $rhs),
213-
$Sub => $crate::arithmetic::Arithmetic::<$crate::arithmetic::Sub, _>::eval($lhs, $rhs),
214-
$Mul => $crate::arithmetic::Arithmetic::<$crate::arithmetic::Mul, _>::eval($lhs, $rhs),
215-
$Div => $crate::arithmetic::Arithmetic::<$crate::arithmetic::Div, _>::eval($lhs, $rhs),
216-
_ => unreachable!("Not an arithmetic operator"),
217-
}
218-
};
219-
}

vortex-compute/src/arithmetic/primitive_scalar.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ use vortex_vector::primitive::PrimitiveVector;
1616
use crate::arithmetic::Arithmetic;
1717
use crate::arithmetic::CheckedArithmetic;
1818

19-
impl<Op> CheckedArithmetic<Op, &PrimitiveScalar> for &PrimitiveScalar
19+
impl<Op> CheckedArithmetic<Op> for &PrimitiveScalar
2020
where
2121
for<'a> &'a PScalar<i8>: CheckedArithmetic<Op, &'a PScalar<i8>, Output = PScalar<i8>>,
2222
for<'a> &'a PScalar<i16>: CheckedArithmetic<Op, &'a PScalar<i16>, Output = PScalar<i16>>,
@@ -31,9 +31,9 @@ where
3131

3232
fn checked_eval(self, rhs: &PrimitiveScalar) -> Option<Self::Output> {
3333
match_each_integer_pscalar_pair!(
34-
(self, rhs),
34+
(&self, &rhs),
3535
|l, r| { CheckedArithmetic::<Op, _>::checked_eval(l, r).map(Into::into) },
36-
{ vortex_panic!("cannot compare primitive scalar of different types") } // Type mismatch or float types
36+
{ vortex_panic!("cannot compare primitive scalar of different types") }
3737
)
3838
}
3939
}

vortex-compute/src/arithmetic/primitive_vector.rs

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -15,22 +15,22 @@ use vortex_vector::primitive::PrimitiveVector;
1515
use crate::arithmetic::Arithmetic;
1616
use crate::arithmetic::CheckedArithmetic;
1717

18-
impl<Op> CheckedArithmetic<Op, &PrimitiveVector> for &PrimitiveVector
18+
impl<Op> CheckedArithmetic<Op, &PrimitiveVector> for PrimitiveVector
1919
where
20-
for<'a> &'a PVector<i8>: CheckedArithmetic<Op, &'a PVector<i8>, Output = PVector<i8>>,
21-
for<'a> &'a PVector<i16>: CheckedArithmetic<Op, &'a PVector<i16>, Output = PVector<i16>>,
22-
for<'a> &'a PVector<i32>: CheckedArithmetic<Op, &'a PVector<i32>, Output = PVector<i32>>,
23-
for<'a> &'a PVector<i64>: CheckedArithmetic<Op, &'a PVector<i64>, Output = PVector<i64>>,
24-
for<'a> &'a PVector<u8>: CheckedArithmetic<Op, &'a PVector<u8>, Output = PVector<u8>>,
25-
for<'a> &'a PVector<u16>: CheckedArithmetic<Op, &'a PVector<u16>, Output = PVector<u16>>,
26-
for<'a> &'a PVector<u32>: CheckedArithmetic<Op, &'a PVector<u32>, Output = PVector<u32>>,
27-
for<'a> &'a PVector<u64>: CheckedArithmetic<Op, &'a PVector<u64>, Output = PVector<u64>>,
20+
for<'a> PVector<i8>: CheckedArithmetic<Op, &'a PVector<i8>, Output = PVector<i8>>,
21+
for<'a> PVector<i16>: CheckedArithmetic<Op, &'a PVector<i16>, Output = PVector<i16>>,
22+
for<'a> PVector<i32>: CheckedArithmetic<Op, &'a PVector<i32>, Output = PVector<i32>>,
23+
for<'a> PVector<i64>: CheckedArithmetic<Op, &'a PVector<i64>, Output = PVector<i64>>,
24+
for<'a> PVector<u8>: CheckedArithmetic<Op, &'a PVector<u8>, Output = PVector<u8>>,
25+
for<'a> PVector<u16>: CheckedArithmetic<Op, &'a PVector<u16>, Output = PVector<u16>>,
26+
for<'a> PVector<u32>: CheckedArithmetic<Op, &'a PVector<u32>, Output = PVector<u32>>,
27+
for<'a> PVector<u64>: CheckedArithmetic<Op, &'a PVector<u64>, Output = PVector<u64>>,
2828
{
2929
type Output = PrimitiveVector;
3030

3131
fn checked_eval(self, rhs: &PrimitiveVector) -> Option<Self::Output> {
3232
match_each_integer_pvector_pair!(
33-
(self, rhs),
33+
(self, &rhs),
3434
|l, r| { CheckedArithmetic::<Op, _>::checked_eval(l, r).map(Into::into) },
3535
{ vortex_panic!("dont use checked arithmetic for floats") }
3636
)
@@ -117,7 +117,7 @@ mod tests {
117117
.freeze()
118118
.into();
119119

120-
let result = CheckedArithmetic::<Add, _>::checked_eval(&left, &right).unwrap();
120+
let result = CheckedArithmetic::<Add, _>::checked_eval(left, &right).unwrap();
121121
if let PrimitiveVector::I32(v) = result {
122122
assert_eq!(v.scalar_at(0).value(), Some(11));
123123
assert_eq!(v.scalar_at(1).value(), Some(22));

vortex-compute/src/comparison/mod.rs

Lines changed: 0 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -131,56 +131,3 @@ macro_rules! impl_float {
131131
impl_float!(f16);
132132
impl_float!(f32);
133133
impl_float!(f64);
134-
135-
/// Dispatches a comparison operation based on a runtime operator value.
136-
///
137-
/// This macro allows you to call `Compare::<Op>::compare(lhs, rhs)` where `Op` is determined
138-
/// at runtime from an operator enum variant.
139-
///
140-
/// # Arguments
141-
///
142-
/// * `$op` - An expression that evaluates to an operator enum (e.g., `Operator::Eq`)
143-
/// * `$lhs` - The left-hand side operand
144-
/// * `$rhs` - The right-hand side operand
145-
/// * `$Eq`, `$NotEq`, etc. - The enum variants to match against
146-
///
147-
/// # Example
148-
///
149-
/// ```ignore
150-
/// use vortex_compute::compare_op;
151-
/// use vortex_compute::comparison::Compare;
152-
///
153-
/// let result = compare_op!(
154-
/// op,
155-
/// lhs,
156-
/// rhs,
157-
/// Operator::Eq,
158-
/// Operator::NotEq,
159-
/// Operator::Lt,
160-
/// Operator::Lte,
161-
/// Operator::Gt,
162-
/// Operator::Gte
163-
/// );
164-
/// ```
165-
#[macro_export]
166-
macro_rules! compare_op {
167-
($op:expr, $lhs:expr, $rhs:expr, $Eq:pat, $NotEq:pat, $Lt:pat, $Lte:pat, $Gt:pat, $Gte:pat) => {
168-
match $op {
169-
$Eq => $crate::comparison::Compare::<$crate::comparison::Equal>::compare($lhs, $rhs),
170-
$NotEq => {
171-
$crate::comparison::Compare::<$crate::comparison::NotEqual>::compare($lhs, $rhs)
172-
}
173-
$Lt => $crate::comparison::Compare::<$crate::comparison::LessThan>::compare($lhs, $rhs),
174-
$Lte => $crate::comparison::Compare::<$crate::comparison::LessThanOrEqual>::compare(
175-
$lhs, $rhs,
176-
),
177-
$Gt => {
178-
$crate::comparison::Compare::<$crate::comparison::GreaterThan>::compare($lhs, $rhs)
179-
}
180-
$Gte => $crate::comparison::Compare::<$crate::comparison::GreaterThanOrEqual>::compare(
181-
$lhs, $rhs,
182-
),
183-
_ => unreachable!("Not a comparison operator"),
184-
}
185-
};
186-
}

vortex-compute/src/comparison/primitive_vector.rs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,17 @@ where
3939
}
4040
}
4141

42+
impl<Op> Compare<Op> for PrimitiveVector
43+
where
44+
for<'a> &'a PrimitiveVector: Compare<Op, Output = BoolVector>,
45+
{
46+
type Output = BoolVector;
47+
48+
fn compare(self, rhs: Self) -> Self::Output {
49+
Compare::<Op>::compare(&self, &rhs)
50+
}
51+
}
52+
4253
#[cfg(test)]
4354
mod tests {
4455
use vortex_mask::Mask;

0 commit comments

Comments
 (0)