Skip to content

Commit 80492ef

Browse files
committed
Add arithmetic buffer compute
Signed-off-by: Nicholas Gates <[email protected]>
1 parent 110dd75 commit 80492ef

File tree

2 files changed

+258
-239
lines changed

2 files changed

+258
-239
lines changed

vortex-compute/src/arithmetic/buffer.rs

Lines changed: 105 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -2,74 +2,126 @@
22
// SPDX-FileCopyrightText: Copyright the Vortex contributors
33

44
use vortex_buffer::{Buffer, BufferMut};
5+
use vortex_dtype::half::f16;
56

67
use crate::arithmetic::{CheckedAdd, CheckedDiv, CheckedMul, CheckedSub};
78

8-
impl<T: Copy + num_traits::CheckedAdd + num_traits::Zero> CheckedAdd<&Buffer<T>> for Buffer<T> {
9-
type Output = Self;
9+
macro_rules! checked_op {
10+
($Trait:ident, $name:ident, $T:ty, $op:expr) => {
11+
impl $Trait<&Buffer<$T>> for &Buffer<$T> {
12+
type Output = Buffer<$T>;
1013

11-
fn checked_add(self, other: &Buffer<T>) -> Option<Self::Output> {
12-
buffer_op_inplace(self, other, |a, b| a.checked_add(b))
13-
}
14-
}
14+
fn $name(self, other: &Buffer<$T>) -> Option<Self::Output> {
15+
buffer_op(self, other, $op)
16+
}
17+
}
1518

16-
impl<T: Copy + num_traits::CheckedAdd + num_traits::Zero> CheckedAdd<&T> for Buffer<T> {
17-
type Output = Self;
19+
impl $Trait<&$T> for &Buffer<$T> {
20+
type Output = Buffer<$T>;
1821

19-
fn checked_add(self, other: &T) -> Option<Self::Output> {
20-
buffer_op_inplace_scalar(self, other, |a, b| a.checked_add(b))
21-
}
22-
}
22+
fn $name(self, other: &$T) -> Option<Self::Output> {
23+
buffer_op_scalar(self, other, $op)
24+
}
25+
}
2326

24-
impl<T: Copy + num_traits::CheckedSub + num_traits::Zero> CheckedSub<&Buffer<T>> for Buffer<T> {
25-
type Output = Self;
27+
impl $Trait<&Buffer<$T>> for Buffer<$T> {
28+
type Output = Buffer<$T>;
2629

27-
fn checked_sub(self, other: &Buffer<T>) -> Option<Self::Output> {
28-
buffer_op_inplace(self, other, |a, b| a.checked_sub(b))
29-
}
30-
}
30+
fn $name(self, other: &Buffer<$T>) -> Option<Self::Output> {
31+
buffer_op_inplace(self, other, $op)
32+
}
33+
}
3134

32-
impl<T: Copy + num_traits::CheckedSub + num_traits::Zero> CheckedSub<&T> for Buffer<T> {
33-
type Output = Self;
35+
impl $Trait<&$T> for Buffer<$T> {
36+
type Output = Buffer<$T>;
3437

35-
fn checked_sub(self, other: &T) -> Option<Self::Output> {
36-
buffer_op_inplace_scalar(self, other, |a, b| a.checked_sub(b))
37-
}
38-
}
38+
fn $name(self, other: &$T) -> Option<Self::Output> {
39+
buffer_op_inplace_scalar(self, other, $op)
40+
}
41+
}
3942

40-
impl<T: Copy + num_traits::CheckedMul + num_traits::Zero> CheckedMul<&Buffer<T>> for Buffer<T> {
41-
type Output = Self;
43+
impl $Trait<&Buffer<$T>> for BufferMut<$T> {
44+
type Output = Buffer<$T>;
4245

43-
fn checked_mul(self, other: &Buffer<T>) -> Option<Self::Output> {
44-
buffer_op_inplace(self, other, |a, b| a.checked_mul(b))
45-
}
46-
}
46+
fn $name(self, other: &Buffer<$T>) -> Option<Self::Output> {
47+
buffer_op_mut(self, other, $op)
48+
}
49+
}
4750

48-
impl<T: Copy + num_traits::CheckedMul + num_traits::Zero> CheckedMul<&T> for Buffer<T> {
49-
type Output = Self;
51+
impl $Trait<&$T> for BufferMut<$T> {
52+
type Output = Buffer<$T>;
5053

51-
fn checked_mul(self, other: &T) -> Option<Self::Output> {
52-
buffer_op_inplace_scalar(self, other, |a, b| a.checked_mul(b))
53-
}
54+
fn $name(self, other: &$T) -> Option<Self::Output> {
55+
buffer_op_mut_scalar(self, other, $op)
56+
}
57+
}
58+
};
5459
}
5560

56-
impl<T: Copy + num_traits::CheckedDiv + num_traits::Zero> CheckedDiv<&Buffer<T>> for Buffer<T> {
57-
type Output = Self;
58-
59-
fn checked_div(self, other: &Buffer<T>) -> Option<Self::Output> {
60-
buffer_op_inplace(self, other, |a, b| a.checked_div(b))
61-
}
61+
/// For integers, we can delegate to the num_traits wrapping operations.
62+
macro_rules! integer_ops {
63+
($T:ty) => {
64+
checked_op!(
65+
CheckedAdd,
66+
checked_add,
67+
$T,
68+
num_traits::CheckedAdd::checked_add
69+
);
70+
checked_op!(
71+
CheckedSub,
72+
checked_sub,
73+
$T,
74+
num_traits::CheckedSub::checked_sub
75+
);
76+
checked_op!(
77+
CheckedMul,
78+
checked_mul,
79+
$T,
80+
num_traits::CheckedMul::checked_mul
81+
);
82+
checked_op!(
83+
CheckedDiv,
84+
checked_div,
85+
$T,
86+
num_traits::CheckedDiv::checked_div
87+
);
88+
};
6289
}
6390

64-
impl<T: Copy + num_traits::CheckedDiv + num_traits::Zero> CheckedDiv<&T> for Buffer<T> {
65-
type Output = Self;
66-
67-
fn checked_div(self, other: &T) -> Option<Self::Output> {
68-
buffer_op_inplace_scalar(self, other, |a, b| a.checked_div(b))
69-
}
91+
/// For floats, there are no checked operations. So we use regular operations that never fail.
92+
macro_rules! float_ops {
93+
($T:ty) => {
94+
checked_op!(CheckedAdd, checked_add, $T, |a, b| Some(
95+
std::ops::Add::add(a, b)
96+
));
97+
checked_op!(CheckedSub, checked_sub, $T, |a, b| Some(
98+
std::ops::Sub::sub(a, b)
99+
));
100+
checked_op!(CheckedMul, checked_mul, $T, |a, b| Some(
101+
std::ops::Mul::mul(a, b)
102+
));
103+
checked_op!(CheckedDiv, checked_div, $T, |a, b| Some(
104+
std::ops::Div::div(a, b)
105+
));
106+
};
70107
}
71108

72-
fn buffer_op_inplace<O, T>(lhs: Buffer<T>, rhs: &Buffer<T>, op: O) -> Option<Buffer<T>>
109+
integer_ops!(u8);
110+
integer_ops!(u16);
111+
integer_ops!(u32);
112+
integer_ops!(u64);
113+
integer_ops!(u128);
114+
integer_ops!(i8);
115+
integer_ops!(i16);
116+
integer_ops!(i32);
117+
integer_ops!(i64);
118+
integer_ops!(i128);
119+
120+
float_ops!(f16);
121+
float_ops!(f32);
122+
float_ops!(f64);
123+
124+
pub(super) fn buffer_op_inplace<O, T>(lhs: Buffer<T>, rhs: &Buffer<T>, op: O) -> Option<Buffer<T>>
73125
where
74126
O: Fn(&T, &T) -> Option<T>,
75127
T: Copy + num_traits::Zero,
@@ -80,7 +132,7 @@ where
80132
}
81133
}
82134

83-
fn buffer_op_mut<O, T>(lhs: BufferMut<T>, rhs: &Buffer<T>, op: O) -> Option<Buffer<T>>
135+
pub(super) fn buffer_op_mut<O, T>(lhs: BufferMut<T>, rhs: &Buffer<T>, op: O) -> Option<Buffer<T>>
84136
where
85137
O: Fn(&T, &T) -> Option<T>,
86138
T: Copy + num_traits::Zero,
@@ -108,7 +160,7 @@ where
108160
(!overflow).then_some(buffer)
109161
}
110162

111-
fn buffer_op<O, T>(lhs: &Buffer<T>, rhs: &Buffer<T>, op: O) -> Option<Buffer<T>>
163+
pub(super) fn buffer_op<O, T>(lhs: &Buffer<T>, rhs: &Buffer<T>, op: O) -> Option<Buffer<T>>
112164
where
113165
O: Fn(&T, &T) -> Option<T>,
114166
T: Copy + num_traits::Zero,
@@ -128,7 +180,7 @@ where
128180
(!overflow).then_some(buffer)
129181
}
130182

131-
fn buffer_op_inplace_scalar<O, T>(lhs: Buffer<T>, rhs: &T, op: O) -> Option<Buffer<T>>
183+
pub(super) fn buffer_op_inplace_scalar<O, T>(lhs: Buffer<T>, rhs: &T, op: O) -> Option<Buffer<T>>
132184
where
133185
O: Fn(&T, &T) -> Option<T>,
134186
T: Copy + num_traits::Zero,
@@ -139,7 +191,7 @@ where
139191
}
140192
}
141193

142-
fn buffer_op_mut_scalar<O, T>(lhs: BufferMut<T>, rhs: &T, op: O) -> Option<Buffer<T>>
194+
pub(super) fn buffer_op_mut_scalar<O, T>(lhs: BufferMut<T>, rhs: &T, op: O) -> Option<Buffer<T>>
143195
where
144196
O: Fn(&T, &T) -> Option<T>,
145197
T: Copy + num_traits::Zero,
@@ -157,7 +209,7 @@ where
157209
(!overflow).then_some(buffer)
158210
}
159211

160-
fn buffer_op_scalar<O, T>(lhs: &Buffer<T>, rhs: &T, op: O) -> Option<Buffer<T>>
212+
pub(super) fn buffer_op_scalar<O, T>(lhs: &Buffer<T>, rhs: &T, op: O) -> Option<Buffer<T>>
161213
where
162214
O: Fn(&T, &T) -> Option<T>,
163215
T: Copy + num_traits::Zero,

0 commit comments

Comments
 (0)