Skip to content

Commit 8db3885

Browse files
author
andy-thomason
committed
cos, sin and start of a*
1 parent f7441e2 commit 8db3885

File tree

1 file changed

+216
-11
lines changed

1 file changed

+216
-11
lines changed

crates/std_float/src/lib.rs

Lines changed: 216 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,16 @@ pub trait StdFloat: Sealed + Sized {
117117
fn fract(self) -> Self;
118118

119119
fn sin(self) -> Self;
120+
121+
fn cos(self) -> Self;
122+
123+
fn tan(self) -> Self;
124+
125+
fn asin(self) -> Self;
126+
127+
fn acos(self) -> Self;
128+
129+
fn atan(self) -> Self;
120130
}
121131

122132
impl<const N: usize> Sealed for Simd<f32, N> where LaneCount<N>: SupportedLaneCount {}
@@ -135,6 +145,8 @@ where
135145
}
136146

137147
/// Calculate the sine of the angle
148+
/// Note: this is hand-edited from generated scalar code.
149+
/// In an ideal world, we would generate this directly by code transformation.
138150
#[inline]
139151
fn sin(self) -> Self {
140152
#[allow(non_snake_case)]
@@ -150,6 +162,93 @@ where
150162
.mul_add(x * x, Self::splat(6.28318452581127506328))
151163
* x
152164
}
165+
166+
fn cos(self) -> Self {
167+
#[allow(non_snake_case)]
168+
let RECIP_2PI = Self::splat(0.15915494);
169+
170+
let scaled = self * RECIP_2PI;
171+
let x = scaled - scaled.round();
172+
Self::splat(6.52865816174499269880)
173+
.mul_add(x * x, Self::splat(-25.97327546890330396608))
174+
.mul_add(x * x, Self::splat(60.17118230812820383560))
175+
.mul_add(x * x, Self::splat(-85.45091743827674607508))
176+
.mul_add(x * x, Self::splat(64.93918704099473042873))
177+
.mul_add(x * x, Self::splat(-19.73920667935656472596))
178+
.mul_add(x * x, Self::splat(1.00000000000000000000))
179+
}
180+
181+
fn tan(self) -> Self {
182+
use core::f32::consts::PI;
183+
let scaled: Self = self * Self::splat(1.0 / PI);
184+
let x: Self = scaled - scaled.round();
185+
let recip: Self = (x * x - Self::splat(0.25)).recip();
186+
let y: Self = Self::splat(0.01439730036301634345)
187+
.mul_add(x * x, Self::splat(0.02101734538976238579))
188+
.mul_add(x * x, Self::splat(0.05285888255895108345))
189+
.mul_add(x * x, Self::splat(0.13475448281475060771))
190+
.mul_add(x * x, Self::splat(0.55773663386075044866))
191+
.mul_add(x * x, Self::splat(-0.78539816491781455948))
192+
* x;
193+
y * recip
194+
}
195+
196+
fn asin(self) -> Self {
197+
use core::f32::consts::PI;
198+
let lim: Self = Self::splat(0.9);
199+
let c: Self = self.lanes_lt(Self::splat(0.0)).select(Self::splat(-PI / 2.0), Self::splat(PI / 2.0));
200+
let s: Self = self.lanes_lt(Self::splat(0.0)).select(Self::splat(-1.0), Self::splat(1.0));
201+
let x: Self = (self * self).lanes_lt(lim * lim).select(self, (Self::splat(1.0) - self * self).sqrt());
202+
let y: Self = Self::splat(4374.97702992533695457424)
203+
.mul_add(x * x, Self::splat(-13781.55764426881951685974))
204+
.mul_add(x * x, Self::splat(17105.69475701115952774357))
205+
.mul_add(x * x, Self::splat(-10486.64894150265898388567))
206+
.mul_add(x * x, Self::splat(3231.76028705607279348342))
207+
.mul_add(x * x, Self::splat(-447.56480696327035255708))
208+
.mul_add(x * x, Self::splat(21.78206149264184872939))
209+
.mul_add(x * x, Self::splat(0.84158415752395745675))
210+
* x;
211+
(self * self).lanes_lt(lim * lim).select(y, c - y * s)
212+
}
213+
214+
fn acos(self) -> Self {
215+
use core::f32::consts::PI;
216+
let lim: Self = Self::splat(0.9);
217+
let c: Self = self.lanes_lt(Self::splat(0.0)).select(Self::splat(PI), Self::splat(0.0));
218+
let s: Self = self.lanes_lt(Self::splat(0.0)).select(Self::splat(1.0), Self::splat(-1.0));
219+
let x: Self = (self * self).lanes_lt(lim * lim).select(self, (Self::splat(1.0) - self * self).sqrt());
220+
// let c: Self = select(self < 0.0, PI, 0.0);
221+
// let s: Self = select(self < 0.0, 1.0, -1.0);
222+
// let x: Self = select(self * self < lim * lim, self, (1.0 - self * self).sqrt());
223+
let y: Self = Self::splat(4374.97702992533695457424)
224+
.mul_add(x * x, Self::splat(-13781.55764426881951685974))
225+
.mul_add(x * x, Self::splat(17105.69475701115952774357))
226+
.mul_add(x * x, Self::splat(-10486.64894150265898388567))
227+
.mul_add(x * x, Self::splat(3231.76028705607279348342))
228+
.mul_add(x * x, Self::splat(-447.56480696327035255708))
229+
.mul_add(x * x, Self::splat(21.78206149264184872939))
230+
.mul_add(x * x, Self::splat(0.84158415752395745675))
231+
* x;
232+
(self * self).lanes_lt(lim * lim).select(y, c - y * s)
233+
}
234+
235+
fn atan(self) -> Self {
236+
use core::f32::consts::PI;
237+
let lim: Self = Self::splat(1.0);
238+
let c: Self = self.lanes_lt(Self::splat(0.0)).select(Self::splat(-PI / 2.0), Self::splat(PI / 2.0));
239+
let small = self.abs().lanes_lt(lim);
240+
let x: Self = small.select(self, self.recip());
241+
let y: Self = Self::splat(95.70126383842530559360)
242+
.mul_add(x * x, Self::splat(424.99907022806059540464))
243+
.mul_add(x * x, Self::splat(-767.48259680040570156003))
244+
.mul_add(x * x, Self::splat(714.51953012224223415829))
245+
.mul_add(x * x, Self::splat(-354.32654395426962592865))
246+
.mul_add(x * x, Self::splat(83.96179897148539189638))
247+
.mul_add(x * x, Self::splat(-6.23958170715441509270))
248+
.mul_add(x * x, Self::splat(1.05498514186427524914))
249+
* x;
250+
small.select(y, c - y)
251+
}
153252
}
154253

155254
impl<const N: usize> StdFloat for Simd<f64, N>
@@ -167,6 +266,31 @@ where
167266
fn sin(self) -> Self {
168267
self
169268
}
269+
270+
#[inline]
271+
fn cos(self) -> Self {
272+
self
273+
}
274+
275+
#[inline]
276+
fn tan(self) -> Self {
277+
self
278+
}
279+
280+
#[inline]
281+
fn asin(self) -> Self {
282+
self
283+
}
284+
285+
#[inline]
286+
fn acos(self) -> Self {
287+
self
288+
}
289+
290+
#[inline]
291+
fn atan(self) -> Self {
292+
self
293+
}
170294
}
171295

172296
#[cfg(test)]
@@ -188,6 +312,8 @@ mod tests {
188312
let _ = x.sin();
189313
}
190314

315+
const NUM_ITER: usize = 0x10000;
316+
191317
macro_rules! test_range {
192318
(
193319
min: $min: expr,
@@ -198,7 +324,6 @@ mod tests {
198324
scalar_type: $scalar_type: ty,
199325
vector_type: $vector_type: ty,
200326
) => {{
201-
const NUM_ITER: usize = 0x10000;
202327
let limit = <$vector_type>::splat($limit);
203328
let b = (($max) - ($min)) * (1.0 / NUM_ITER as $scalar_type);
204329
let a = $min;
@@ -213,49 +338,129 @@ mod tests {
213338
(fi + 3.0) * b + a,
214339
]);
215340
let yref = <$vector_type>::from_array([sf(x[0]), sf(x[1]), sf(x[2]), sf(x[3])]);
216-
assert!(((vf(x) - yref).abs().lanes_le(limit)).all());
341+
let y = vf(x);
342+
let e = (y - yref);
343+
if !(e.abs().lanes_le(limit)).all() {
344+
panic!("\nx ={:20.16?}\ne ={:20.16?}\nlimit ={:20.16?}\nvector={:20.16?}\nscalar={:20.16?}\nvector_fn={}", x, e, limit, y, yref, stringify!($vector_fn));
345+
}
217346
}
218347
}};
219348
}
220349

221350
#[test]
222351
fn sin_f32() {
223352
use core::f32::consts::PI;
224-
let ulp = (2.0_f32).powi(-23);
353+
let one_ulp = (2.0_f32).powi(-23);
225354

226-
// In the range +/- pi/4 the input has 1 ulp of error.
227355
test_range!(
228356
min: -PI/4.0,
229357
max: PI/4.0,
230-
limit: ulp * 1.0,
358+
limit: one_ulp * 1.0,
231359
scalar_fn: |x : f32| x.sin(),
232360
vector_fn: |x : f32x4| x.sin(),
233361
scalar_type: f32,
234362
vector_type: f32x4,
235363
);
236364

237-
// In the range +/- pi/2 the input and output has 2 ulp of error.
238365
test_range!(
239366
min: -PI/2.0,
240367
max: PI/2.0,
241-
limit: ulp * 2.0,
368+
limit: one_ulp * 2.0,
242369
scalar_fn: |x : f32| x.sin(),
243370
vector_fn: |x : f32x4| x.sin(),
244371
scalar_type: f32,
245372
vector_type: f32x4,
246373
);
247374

248-
// In the range +/- pi the input has 2 ulp of error and the output has 5.
249-
// Note that the scalar sin also has this error but the implementation
250-
// is different.
251375
test_range!(
252376
min: -PI,
253377
max: PI,
254-
limit: ulp * 5.0,
378+
limit: one_ulp * 8.0,
255379
scalar_fn: |x : f32| x.sin(),
256380
vector_fn: |x : f32x4| x.sin(),
257381
scalar_type: f32,
258382
vector_type: f32x4,
259383
);
260384
}
385+
386+
#[test]
387+
fn cos_f32() {
388+
use core::f32::consts::PI;
389+
let one_ulp = (2.0_f32).powi(-23);
390+
391+
// In the range +/- pi/4 the input has 1 ulp of error.
392+
test_range!(
393+
min: -PI/4.0,
394+
max: PI/4.0,
395+
limit: one_ulp * 1.0,
396+
scalar_fn: |x : f32| x.cos(),
397+
vector_fn: |x : f32x4| x.cos(),
398+
scalar_type: f32,
399+
vector_type: f32x4,
400+
);
401+
402+
// In the range +/- pi/2 the input and output has 2 ulp of error.
403+
test_range!(
404+
min: -PI/2.0,
405+
max: PI/2.0,
406+
limit: one_ulp * 2.0,
407+
scalar_fn: |x : f32| x.cos(),
408+
vector_fn: |x : f32x4| x.cos(),
409+
scalar_type: f32,
410+
vector_type: f32x4,
411+
);
412+
413+
// In the range +/- pi the input has 4 ulp of error and the output has 5.
414+
// Note that the scalar cos also has this error but the implementation
415+
// is different.
416+
test_range!(
417+
min: -PI,
418+
max: PI,
419+
limit: one_ulp * 8.0,
420+
scalar_fn: |x : f32| x.cos(),
421+
vector_fn: |x : f32x4| x.cos(),
422+
scalar_type: f32,
423+
vector_type: f32x4,
424+
);
425+
}
426+
427+
#[test]
428+
fn tan_f32() {
429+
use core::f32::consts::PI;
430+
let one_ulp = (2.0_f32).powi(-23);
431+
432+
// For the outsides, reciprocal accuracy is important.
433+
// Note that the vector function correctly gets -inf for -PI/2
434+
// but the scalar function does not.
435+
test_range!(
436+
min: -PI/2.0 + 0.00001,
437+
max: -PI/4.0,
438+
limit: one_ulp * 3.0,
439+
scalar_fn: |x : f32| x.tan().recip(),
440+
vector_fn: |x : f32x4| x.tan().recip(),
441+
scalar_type: f32,
442+
vector_type: f32x4,
443+
);
444+
445+
// For the insides, absolute accuracy is important.
446+
test_range!(
447+
min: -PI/4.0,
448+
max: PI/4.0,
449+
limit: one_ulp * 2.0,
450+
scalar_fn: |x : f32| x.tan(),
451+
vector_fn: |x : f32x4| x.tan(),
452+
scalar_type: f32,
453+
vector_type: f32x4,
454+
);
455+
456+
test_range!(
457+
min: PI/4.0,
458+
max: PI/2.0 - 0.00001,
459+
limit: one_ulp * 3.0,
460+
scalar_fn: |x : f32| x.tan().recip(),
461+
vector_fn: |x : f32x4| x.tan().recip(),
462+
scalar_type: f32,
463+
vector_type: f32x4,
464+
);
465+
}
261466
}

0 commit comments

Comments
 (0)