Skip to content

Commit a2ba01a

Browse files
authored
feat: add Trunc op (#956)
* feat: add `Trunc` op * tests
1 parent 07cba4b commit a2ba01a

File tree

17 files changed

+87
-1
lines changed

17 files changed

+87
-1
lines changed

crates/cubecl-core/src/frontend/container/line/ops.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ use num_traits::{NumCast, ToPrimitive};
44

55
use crate::{
66
self as cubecl,
7-
prelude::{IsInf, IsNan, Powi, SaturatingAdd, SaturatingSub},
7+
prelude::{IsInf, IsNan, Powi, SaturatingAdd, SaturatingSub, Trunc},
88
};
99
use crate::{
1010
frontend::{
@@ -255,6 +255,7 @@ impl<P: CubePrimitive + Remainder> Remainder for Line<P> {}
255255
impl<P: CubePrimitive + Round> Round for Line<P> {}
256256
impl<P: CubePrimitive + Floor> Floor for Line<P> {}
257257
impl<P: CubePrimitive + Ceil> Ceil for Line<P> {}
258+
impl<P: CubePrimitive + Trunc> Trunc for Line<P> {}
258259
impl<P: CubePrimitive + ReverseBits> ReverseBits for Line<P> {}
259260
impl<P: CubePrimitive + BitwiseNot> BitwiseNot for Line<P> {}
260261
impl<P: CubePrimitive + SaturatingAdd> SaturatingAdd for Line<P> {}

crates/cubecl-core/src/frontend/element/float.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ pub trait Float:
3232
+ Round
3333
+ Floor
3434
+ Ceil
35+
+ Trunc
3536
+ Erf
3637
+ Recip
3738
+ Magnitude

crates/cubecl-core/src/frontend/element/float/typemap.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -254,6 +254,7 @@ impl<const POS: u8> Sqrt for ElemExpand<POS> {}
254254
impl<const POS: u8> Round for ElemExpand<POS> {}
255255
impl<const POS: u8> Floor for ElemExpand<POS> {}
256256
impl<const POS: u8> Ceil for ElemExpand<POS> {}
257+
impl<const POS: u8> Trunc for ElemExpand<POS> {}
257258
impl<const POS: u8> IsNan for ElemExpand<POS> {}
258259
impl<const POS: u8> IsInf for ElemExpand<POS> {}
259260

crates/cubecl-core/src/frontend/operation/unary.rs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,18 @@ impl_unary_func!(
235235
f32,
236236
f64
237237
);
238+
impl_unary_func!(
239+
Trunc,
240+
trunc,
241+
__expand_trunc,
242+
Arithmetic::Trunc,
243+
f16,
244+
bf16,
245+
flex32,
246+
tf32,
247+
f32,
248+
f64
249+
);
238250
impl_unary_func!(
239251
Erf,
240252
erf,

crates/cubecl-core/src/runtime_tests/unary.rs

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -305,6 +305,29 @@ test_unary_impl!(
305305
]
306306
);
307307

308+
test_unary_impl!(
309+
test_trunc,
310+
F,
311+
F::trunc,
312+
[{
313+
input_vectorization: 1,
314+
out_vectorization: 1,
315+
input: as_type![F: -1.2, -1., -0., 0.],
316+
expected: as_type![F: -1., -1., -0., 0.]
317+
},
318+
{
319+
input_vectorization: 2,
320+
out_vectorization: 2,
321+
input: as_type![F: f32::NAN, 1., 1.2, 1.9],
322+
expected: as_type![F: f32::NAN, 1., 1., 1.0]
323+
},{
324+
input_vectorization: 4,
325+
out_vectorization: 4,
326+
input: as_type![F: -0.9, 0.2, f32::NAN, 1.99],
327+
expected: as_type![F: -0., 0., f32::NAN, 1.]
328+
}]
329+
);
330+
308331
test_unary_impl_fixed!(
309332
test_is_nan,
310333
F,
@@ -479,6 +502,7 @@ macro_rules! testgen_unary {
479502
add_test!(test_normalize);
480503
add_test!(test_magnitude);
481504
add_test!(test_abs);
505+
add_test!(test_trunc);
482506
add_test!(test_is_nan);
483507
add_test!(test_is_inf);
484508
}

crates/cubecl-cpp/src/shared/base.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1012,6 +1012,9 @@ impl<D: Dialect> CppCompiler<D> {
10121012
gpu::Arithmetic::Ceil(op) => {
10131013
instructions.push(Instruction::Ceil(self.compile_unary(op, out)))
10141014
}
1015+
gpu::Arithmetic::Trunc(op) => {
1016+
instructions.push(Instruction::Trunc(self.compile_unary(op, out)))
1017+
}
10151018
gpu::Arithmetic::Remainder(op) => {
10161019
instructions.push(Instruction::Remainder(self.compile_binary(op, out)))
10171020
}

crates/cubecl-cpp/src/shared/instruction.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,7 @@ pub enum Instruction<D: Dialect> {
200200
},
201201
Round(UnaryInstruction<D>),
202202
Ceil(UnaryInstruction<D>),
203+
Trunc(UnaryInstruction<D>),
203204
Floor(UnaryInstruction<D>),
204205
Warp(WarpInstruction<D>),
205206
Wmma(WmmaInstruction<D>),
@@ -538,6 +539,7 @@ for ({i_ty} {i} = {start}; {i} {cmp} {end}; {increment}) {{
538539
Instruction::ThreadFence => f.write_str("__threadfence();\n"),
539540
Instruction::Round(it) => Round::format(f, &it.input, &it.out),
540541
Instruction::Ceil(it) => Ceil::format(f, &it.input, &it.out),
542+
Instruction::Trunc(it) => Trunc::format(f, &it.input, &it.out),
541543
Instruction::Floor(it) => Floor::format(f, &it.input, &it.out),
542544
Instruction::SliceLength { input, out } => {
543545
let out = out.fmt_left();

crates/cubecl-cpp/src/shared/unary.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,7 @@ function!(Sin, "sin");
154154
function!(Sqrt, "sqrt");
155155
function!(Exp, "exp");
156156
function!(Ceil, "ceil");
157+
function!(Trunc, "trunc");
157158
function!(Floor, "floor");
158159
function!(Round, "rint");
159160

crates/cubecl-cpu/src/compiler/visitor/operation/arithmetic.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,15 @@ impl<'a> Visitor<'a> {
4040
));
4141
self.insert_variable(out, result);
4242
}
43+
Arithmetic::Trunc(trunc) => {
44+
let value = self.get_variable(trunc.input);
45+
let result = self.append_operation_with_result(llvm_ods::intr_trunc(
46+
self.context,
47+
value,
48+
self.location,
49+
));
50+
self.insert_variable(out, result);
51+
}
4352
Arithmetic::Clamp(clamp) => {
4453
let value = self.get_variable(clamp.input);
4554
let mut min = self.get_variable(clamp.min_value);

crates/cubecl-ir/src/arithmetic.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ pub enum Arithmetic {
3232
Round(UnaryOperator),
3333
Floor(UnaryOperator),
3434
Ceil(UnaryOperator),
35+
Trunc(UnaryOperator),
3536
Erf(UnaryOperator),
3637
Recip(UnaryOperator),
3738
Clamp(ClampOperator),
@@ -73,6 +74,7 @@ impl Display for Arithmetic {
7374
Arithmetic::Round(op) => write!(f, "{}.round()", op.input),
7475
Arithmetic::Floor(op) => write!(f, "{}.floor()", op.input),
7576
Arithmetic::Ceil(op) => write!(f, "{}.ceil()", op.input),
77+
Arithmetic::Trunc(op) => write!(f, "{}.trunc()", op.input),
7678
Arithmetic::Erf(op) => write!(f, "{}.erf()", op.input),
7779
Arithmetic::Recip(op) => write!(f, "{}.recip()", op.input),
7880
Arithmetic::Clamp(op) => {

0 commit comments

Comments
 (0)