Skip to content

Commit c580ef0

Browse files
committed
Use rounding instructions on aarch64
1 parent 7ccb126 commit c580ef0

File tree

7 files changed

+280
-17
lines changed

7 files changed

+280
-17
lines changed

libm/src/math/arch/aarch64.rs

Lines changed: 187 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,156 @@ pub fn fmaf(mut x: f32, y: f32, z: f32) -> f32 {
3030
x
3131
}
3232

33+
pub fn ceil(mut x: f64) -> f64 {
34+
// SAFETY: `frintp` is available with neon and has no side effects.
35+
unsafe {
36+
asm!(
37+
"frintp {x:d}, {x:d}",
38+
x = inout(vreg) x,
39+
options(nomem, nostack, pure)
40+
);
41+
}
42+
x
43+
}
44+
45+
pub fn ceilf(mut x: f32) -> f32 {
46+
// SAFETY: `frintp` is available with neon and has no side effects.
47+
unsafe {
48+
asm!(
49+
"frintp {x:s}, {x:s}",
50+
x = inout(vreg) x,
51+
options(nomem, nostack, pure)
52+
);
53+
}
54+
x
55+
}
56+
57+
/// Rounds `x` up to an integral value with a single half-precision `frintp` instruction.
#[cfg(all(f16_enabled, target_feature = "fp16"))]
pub fn ceilf16(x: f16) -> f16 {
    let rounded: f16;
    // SAFETY: `frintp` is available for `f16` with `fp16` (implies `neon`) and has no side effects.
    unsafe {
        asm!(
            // Source and destination share one half-precision register.
            "frintp {v:h}, {v:h}",
            v = inout(vreg) x => rounded,
            options(pure, nomem, nostack),
        );
    }
    rounded
}
69+
70+
pub fn floor(mut x: f64) -> f64 {
71+
// SAFETY: `frintm` is available with neon and has no side effects.
72+
unsafe {
73+
asm!(
74+
"frintm {x:d}, {x:d}",
75+
x = inout(vreg) x,
76+
options(nomem, nostack, pure)
77+
);
78+
}
79+
x
80+
}
81+
82+
pub fn floorf(mut x: f32) -> f32 {
83+
// SAFETY: `frintm` is available with neon and has no side effects.
84+
unsafe {
85+
asm!(
86+
"frintm {x:s}, {x:s}",
87+
x = inout(vreg) x,
88+
options(nomem, nostack, pure)
89+
);
90+
}
91+
x
92+
}
93+
94+
/// Rounds `x` down to an integral value with a single half-precision `frintm` instruction.
#[cfg(all(f16_enabled, target_feature = "fp16"))]
pub fn floorf16(x: f16) -> f16 {
    let rounded: f16;
    // SAFETY: `frintm` is available for `f16` with `fp16` (implies `neon`) and has no side effects.
    unsafe {
        asm!(
            // Source and destination share one half-precision register.
            "frintm {v:h}, {v:h}",
            v = inout(vreg) x => rounded,
            options(pure, nomem, nostack),
        );
    }
    rounded
}
106+
33107
pub fn rint(mut x: f64) -> f64 {
108+
// SAFETY: `frintx` is available with neon and has no side effects.
109+
unsafe {
110+
asm!(
111+
"frintx {x:d}, {x:d}",
112+
x = inout(vreg) x,
113+
options(nomem, nostack, pure)
114+
);
115+
}
116+
x
117+
}
118+
119+
pub fn rintf(mut x: f32) -> f32 {
120+
// SAFETY: `frintx` is available with neon and has no side effects.
121+
unsafe {
122+
asm!(
123+
"frintx {x:s}, {x:s}",
124+
x = inout(vreg) x,
125+
options(nomem, nostack, pure)
126+
);
127+
}
128+
x
129+
}
130+
131+
/// Rounds `x` to an integral value with a single half-precision `frintx` instruction.
///
/// Per the Arm reference, `frintx` rounds according to the rounding mode currently selected in
/// the floating-point control register.
#[cfg(all(f16_enabled, target_feature = "fp16"))]
pub fn rintf16(x: f16) -> f16 {
    let rounded: f16;
    // SAFETY: `frintx` is available for `f16` with `fp16` (implies `neon`) and has no side effects.
    unsafe {
        asm!(
            // Source and destination share one half-precision register.
            "frintx {v:h}, {v:h}",
            v = inout(vreg) x => rounded,
            options(pure, nomem, nostack),
        );
    }
    rounded
}
143+
144+
pub fn round(mut x: f64) -> f64 {
145+
// SAFETY: `frinta` is available with neon and has no side effects.
146+
unsafe {
147+
asm!(
148+
"frinta {x:d}, {x:d}",
149+
x = inout(vreg) x,
150+
options(nomem, nostack, pure)
151+
);
152+
}
153+
x
154+
}
155+
156+
pub fn roundf(mut x: f32) -> f32 {
157+
// SAFETY: `frinta` is available with neon and has no side effects.
158+
unsafe {
159+
asm!(
160+
"frinta {x:s}, {x:s}",
161+
x = inout(vreg) x,
162+
options(nomem, nostack, pure)
163+
);
164+
}
165+
x
166+
}
167+
168+
/// Rounds `x` to the nearest integral value, ties away from zero, with a single half-precision
/// `frinta` instruction.
#[cfg(all(f16_enabled, target_feature = "fp16"))]
pub fn roundf16(x: f16) -> f16 {
    let rounded: f16;
    // SAFETY: `frinta` is available for `f16` with `fp16` (implies `neon`) and has no side effects.
    unsafe {
        asm!(
            // Source and destination share one half-precision register.
            "frinta {v:h}, {v:h}",
            v = inout(vreg) x => rounded,
            options(pure, nomem, nostack),
        );
    }
    rounded
}
180+
181+
pub fn roundeven(mut x: f64) -> f64 {
34182
// SAFETY: `frintn` is available with neon and has no side effects.
35-
//
36-
// `frintn` is always round-to-nearest which does not match the C specification, but Rust does
37-
// not support rounding modes.
38183
unsafe {
39184
asm!(
40185
"frintn {x:d}, {x:d}",
@@ -45,11 +190,8 @@ pub fn rint(mut x: f64) -> f64 {
45190
x
46191
}
47192

48-
pub fn rintf(mut x: f32) -> f32 {
193+
pub fn roundevenf(mut x: f32) -> f32 {
49194
// SAFETY: `frintn` is available with neon and has no side effects.
50-
//
51-
// `frintn` is always round-to-nearest which does not match the C specification, but Rust does
52-
// not support rounding modes.
53195
unsafe {
54196
asm!(
55197
"frintn {x:s}, {x:s}",
@@ -61,11 +203,8 @@ pub fn rintf(mut x: f32) -> f32 {
61203
}
62204

63205
#[cfg(all(f16_enabled, target_feature = "fp16"))]
64-
pub fn rintf16(mut x: f16) -> f16 {
206+
pub fn roundevenf16(mut x: f16) -> f16 {
65207
// SAFETY: `frintn` is available for `f16` with `fp16` (implies `neon`) and has no side effects.
66-
//
67-
// `frintn` is always round-to-nearest which does not match the C specification, but Rust does
68-
// not support rounding modes.
69208
unsafe {
70209
asm!(
71210
"frintn {x:h}, {x:h}",
@@ -76,6 +215,43 @@ pub fn rintf16(mut x: f16) -> f16 {
76215
x
77216
}
78217

218+
pub fn trunc(mut x: f64) -> f64 {
219+
// SAFETY: `frintz` is available with neon and has no side effects.
220+
unsafe {
221+
asm!(
222+
"frintz {x:d}, {x:d}",
223+
x = inout(vreg) x,
224+
options(nomem, nostack, pure)
225+
);
226+
}
227+
x
228+
}
229+
230+
pub fn truncf(mut x: f32) -> f32 {
231+
// SAFETY: `frintz` is available with neon and has no side effects.
232+
unsafe {
233+
asm!(
234+
"frintz {x:s}, {x:s}",
235+
x = inout(vreg) x,
236+
options(nomem, nostack, pure)
237+
);
238+
}
239+
x
240+
}
241+
242+
/// Rounds `x` toward zero to an integral value with a single half-precision `frintz`
/// instruction.
#[cfg(all(f16_enabled, target_feature = "fp16"))]
pub fn truncf16(x: f16) -> f16 {
    let rounded: f16;
    // SAFETY: `frintz` is available for `f16` with `fp16` (implies `neon`) and has no side effects.
    unsafe {
        asm!(
            // Source and destination share one half-precision register.
            "frintz {v:h}, {v:h}",
            v = inout(vreg) x => rounded,
            options(pure, nomem, nostack),
        );
    }
    rounded
}
254+
79255
pub fn sqrt(mut x: f64) -> f64 {
80256
// SAFETY: `fsqrt` is available with neon and has no side effects.
81257
unsafe {

libm/src/math/arch/mod.rs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,15 +26,30 @@ cfg_if! {
2626
pub use aarch64::{
2727
fma,
2828
fmaf,
29+
ceil,
30+
ceilf,
31+
floor,
32+
floorf,
33+
round,
34+
roundf,
2935
rint,
3036
rintf,
37+
roundeven,
38+
roundevenf,
39+
trun,
40+
truncf
3141
sqrt,
3242
sqrtf,
3343
};
3444

3545
// Half-precision variants exist only when `f16` support is enabled and the target has `fp16`,
// mirroring the `cfg` on the corresponding definitions in `aarch64.rs`.
#[cfg(all(f16_enabled, target_feature = "fp16"))]
pub use aarch64::{
    ceilf16,
    floorf16,
    roundf16,
    rintf16,
    roundevenf16,
    truncf16,
    sqrtf16,
};
4055
}

libm/src/math/ceil.rs

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,12 @@
44
#[cfg(f16_enabled)]
#[cfg_attr(all(test, assert_no_panic), no_panic::no_panic)]
pub fn ceilf16(x: f16) -> f16 {
    // NOTE(review): presumably hands off to the `arch` implementation of the same name when
    // the `use_arch` predicate holds; confirm against the `select_implementation!` definition.
    select_implementation! {
        name: ceilf16,
        use_arch: all(target_arch = "aarch64", target_feature = "fp16"),
        args: x,
    }

    // Generic implementation shared by all float widths.
    super::generic::ceil(x)
}
915

@@ -14,7 +20,10 @@ pub fn ceilf16(x: f16) -> f16 {
1420
pub fn ceilf(x: f32) -> f32 {
1521
select_implementation! {
1622
name: ceilf,
17-
use_arch: all(target_arch = "wasm32", intrinsics_enabled),
23+
use_arch: any(
24+
all(target_arch = "aarch64", target_feature = "neon"),
25+
all(target_arch = "wasm32", intrinsics_enabled),
26+
),
1827
args: x,
1928
}
2029

@@ -28,7 +37,10 @@ pub fn ceilf(x: f32) -> f32 {
2837
pub fn ceil(x: f64) -> f64 {
2938
select_implementation! {
3039
name: ceil,
31-
use_arch: all(target_arch = "wasm32", intrinsics_enabled),
40+
use_arch: any(
41+
all(target_arch = "aarch64", target_feature = "neon"),
42+
all(target_arch = "wasm32", intrinsics_enabled),
43+
),
3244
use_arch_required: all(target_arch = "x86", not(target_feature = "sse2")),
3345
args: x,
3446
}

libm/src/math/floor.rs

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,12 @@
44
#[cfg(f16_enabled)]
#[cfg_attr(all(test, assert_no_panic), no_panic::no_panic)]
pub fn floorf16(x: f16) -> f16 {
    // NOTE(review): presumably hands off to the `arch` implementation of the same name when
    // the `use_arch` predicate holds; confirm against the `select_implementation!` definition.
    select_implementation! {
        name: floorf16,
        use_arch: all(target_arch = "aarch64", target_feature = "fp16"),
        args: x,
    }

    // Generic implementation shared by all float widths. Written as a tail expression, as in
    // the `ceilf16` and `roundf16` wrappers, rather than a trailing `return`.
    super::generic::floor(x)
}
915

@@ -14,7 +20,10 @@ pub fn floorf16(x: f16) -> f16 {
1420
pub fn floor(x: f64) -> f64 {
1521
select_implementation! {
1622
name: floor,
17-
use_arch: all(target_arch = "wasm32", intrinsics_enabled),
23+
use_arch: any(
24+
all(target_arch = "aarch64", target_feature = "neon"),
25+
all(target_arch = "wasm32", intrinsics_enabled),
26+
),
1827
use_arch_required: all(target_arch = "x86", not(target_feature = "sse2")),
1928
args: x,
2029
}
@@ -29,7 +38,10 @@ pub fn floor(x: f64) -> f64 {
2938
pub fn floorf(x: f32) -> f32 {
3039
select_implementation! {
3140
name: floorf,
32-
use_arch: all(target_arch = "wasm32", intrinsics_enabled),
41+
use_arch: any(
42+
all(target_arch = "aarch64", target_feature = "neon"),
43+
all(target_arch = "wasm32", intrinsics_enabled),
44+
),
3345
args: x,
3446
}
3547

libm/src/math/round.rs

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,18 +2,36 @@
22
#[cfg(f16_enabled)]
#[cfg_attr(all(test, assert_no_panic), no_panic::no_panic)]
pub fn roundf16(x: f16) -> f16 {
    // NOTE(review): presumably hands off to the `arch` implementation of the same name when
    // the `use_arch` predicate holds; confirm against the `select_implementation!` definition.
    select_implementation! {
        name: roundf16,
        use_arch: all(target_arch = "aarch64", target_feature = "fp16"),
        args: x,
    }

    // Generic implementation shared by all float widths.
    super::generic::round(x)
}
713

814
/// Round `x` to the nearest integer, breaking ties away from zero.
915
#[cfg_attr(all(test, assert_no_panic), no_panic::no_panic)]
1016
pub fn roundf(x: f32) -> f32 {
17+
select_implementation! {
18+
name: roundf,
19+
use_arch: all(target_arch = "aarch64", target_feature = "fp16"),
20+
args: x,
21+
}
22+
1123
super::generic::round(x)
1224
}
1325

1426
/// Round `x` to the nearest integer, breaking ties away from zero.
1527
#[cfg_attr(all(test, assert_no_panic), no_panic::no_panic)]
1628
pub fn round(x: f64) -> f64 {
29+
select_implementation! {
30+
name: round,
31+
use_arch: all(target_arch = "aarch64", target_feature = "fp16"),
32+
args: x,
33+
}
34+
1735
super::generic::round(x)
1836
}
1937

0 commit comments

Comments
 (0)