Skip to content
This repository was archived by the owner on Apr 28, 2025. It is now read-only.

Commit b351893

Browse files
committed
Improve sqrt/sqrtf if stable intrinsics allow
1 parent a3a3595 commit b351893

File tree

2 files changed

+224
-188
lines changed

2 files changed

+224
-188
lines changed

src/math/sqrt.rs

Lines changed: 135 additions & 117 deletions
Original file line numberDiff line numberDiff line change
@@ -95,128 +95,146 @@ pub fn sqrt(x: f64) -> f64 {
95 95
}
96 96
}
97 97
}
98-
let mut z: f64;
99-
let sign: Wrapping<u32> = Wrapping(0x80000000);
100-
let mut ix0: i32;
101-
let mut s0: i32;
102-
let mut q: i32;
103-
let mut m: i32;
104-
let mut t: i32;
105-
let mut i: i32;
106-
let mut r: Wrapping<u32>;
107-
let mut t1: Wrapping<u32>;
108-
let mut s1: Wrapping<u32>;
109-
let mut ix1: Wrapping<u32>;
110-
let mut q1: Wrapping<u32>;
98+
#[cfg(target_feature="sse2")]
99+
{
100+
// Note(Lokathor): If compile time settings allow, we just use SSE2, since
101+
// the sqrt in `std` on these platforms also compiles down to an SSE2
102+
// instruction.
103+
#[cfg(target_arch="x86")]
104+
use core::arch::x86::*;
105+
#[cfg(target_arch="x86_64")]
106+
use core::arch::x86_64::*;
107+
unsafe {
108+
let m = _mm_set_sd(x);
109+
let m_sqrt = _mm_sqrt_pd(m);
110+
_mm_cvtsd_f64(m_sqrt)
111+
}
112+
}
113+
#[cfg(not(target_feature="sse2"))]
114+
{
115+
let mut z: f64;
116+
let sign: Wrapping<u32> = Wrapping(0x80000000);
117+
let mut ix0: i32;
118+
let mut s0: i32;
119+
let mut q: i32;
120+
let mut m: i32;
121+
let mut t: i32;
122+
let mut i: i32;
123+
let mut r: Wrapping<u32>;
124+
let mut t1: Wrapping<u32>;
125+
let mut s1: Wrapping<u32>;
126+
let mut ix1: Wrapping<u32>;
127+
let mut q1: Wrapping<u32>;
111 128

112-
ix0 = (x.to_bits() >> 32) as i32;
113-
ix1 = Wrapping(x.to_bits() as u32);
129+
ix0 = (x.to_bits() >> 32) as i32;
130+
ix1 = Wrapping(x.to_bits() as u32);
114 131

115-
/* take care of Inf and NaN */
116-
if (ix0 & 0x7ff00000) == 0x7ff00000 {
117-
return x * x + x; /* sqrt(NaN)=NaN, sqrt(+inf)=+inf, sqrt(-inf)=sNaN */
118-
}
119-
/* take care of zero */
120-
if ix0 <= 0 {
121-
if ((ix0 & !(sign.0 as i32)) | ix1.0 as i32) == 0 {
122-
return x; /* sqrt(+-0) = +-0 */
123-
}
124-
if ix0 < 0 {
125-
return (x - x) / (x - x); /* sqrt(-ve) = sNaN */
126-
}
127-
}
128-
/* normalize x */
129-
m = ix0 >> 20;
130-
if m == 0 {
131-
/* subnormal x */
132-
while ix0 == 0 {
133-
m -= 21;
134-
ix0 |= (ix1 >> 11).0 as i32;
135-
ix1 <<= 21;
136-
}
137-
i = 0;
138-
while (ix0 & 0x00100000) == 0 {
139-
i += 1;
140-
ix0 <<= 1;
141-
}
142-
m -= i - 1;
143-
ix0 |= (ix1 >> (32 - i) as usize).0 as i32;
144-
ix1 = ix1 << i as usize;
145-
}
146-
m -= 1023; /* unbias exponent */
147-
ix0 = (ix0 & 0x000fffff) | 0x00100000;
148-
if (m & 1) == 1 {
149-
/* odd m, double x to make it even */
150-
ix0 += ix0 + ((ix1 & sign) >> 31).0 as i32;
151-
ix1 += ix1;
152-
}
153-
m >>= 1; /* m = [m/2] */
132+
/* take care of Inf and NaN */
133+
if (ix0 & 0x7ff00000) == 0x7ff00000 {
134+
return x * x + x; /* sqrt(NaN)=NaN, sqrt(+inf)=+inf, sqrt(-inf)=sNaN */
135+
}
136+
/* take care of zero */
137+
if ix0 <= 0 {
138+
if ((ix0 & !(sign.0 as i32)) | ix1.0 as i32) == 0 {
139+
return x; /* sqrt(+-0) = +-0 */
140+
}
141+
if ix0 < 0 {
142+
return (x - x) / (x - x); /* sqrt(-ve) = sNaN */
143+
}
144+
}
145+
/* normalize x */
146+
m = ix0 >> 20;
147+
if m == 0 {
148+
/* subnormal x */
149+
while ix0 == 0 {
150+
m -= 21;
151+
ix0 |= (ix1 >> 11).0 as i32;
152+
ix1 <<= 21;
153+
}
154+
i = 0;
155+
while (ix0 & 0x00100000) == 0 {
156+
i += 1;
157+
ix0 <<= 1;
158+
}
159+
m -= i - 1;
160+
ix0 |= (ix1 >> (32 - i) as usize).0 as i32;
161+
ix1 = ix1 << i as usize;
162+
}
163+
m -= 1023; /* unbias exponent */
164+
ix0 = (ix0 & 0x000fffff) | 0x00100000;
165+
if (m & 1) == 1 {
166+
/* odd m, double x to make it even */
167+
ix0 += ix0 + ((ix1 & sign) >> 31).0 as i32;
168+
ix1 += ix1;
169+
}
170+
m >>= 1; /* m = [m/2] */
154 171

155-
/* generate sqrt(x) bit by bit */
156-
ix0 += ix0 + ((ix1 & sign) >> 31).0 as i32;
157-
ix1 += ix1;
158-
q = 0; /* [q,q1] = sqrt(x) */
159-
q1 = Wrapping(0);
160-
s0 = 0;
161-
s1 = Wrapping(0);
162-
r = Wrapping(0x00200000); /* r = moving bit from right to left */
172+
/* generate sqrt(x) bit by bit */
173+
ix0 += ix0 + ((ix1 & sign) >> 31).0 as i32;
174+
ix1 += ix1;
175+
q = 0; /* [q,q1] = sqrt(x) */
176+
q1 = Wrapping(0);
177+
s0 = 0;
178+
s1 = Wrapping(0);
179+
r = Wrapping(0x00200000); /* r = moving bit from right to left */
163 180

164-
while r != Wrapping(0) {
165-
t = s0 + r.0 as i32;
166-
if t <= ix0 {
167-
s0 = t + r.0 as i32;
168-
ix0 -= t;
169-
q += r.0 as i32;
170-
}
171-
ix0 += ix0 + ((ix1 & sign) >> 31).0 as i32;
172-
ix1 += ix1;
173-
r >>= 1;
174-
}
181+
while r != Wrapping(0) {
182+
t = s0 + r.0 as i32;
183+
if t <= ix0 {
184+
s0 = t + r.0 as i32;
185+
ix0 -= t;
186+
q += r.0 as i32;
187+
}
188+
ix0 += ix0 + ((ix1 & sign) >> 31).0 as i32;
189+
ix1 += ix1;
190+
r >>= 1;
191+
}
175 192

176-
r = sign;
177-
while r != Wrapping(0) {
178-
t1 = s1 + r;
179-
t = s0;
180-
if t < ix0 || (t == ix0 && t1 <= ix1) {
181-
s1 = t1 + r;
182-
if (t1 & sign) == sign && (s1 & sign) == Wrapping(0) {
183-
s0 += 1;
184-
}
185-
ix0 -= t;
186-
if ix1 < t1 {
187-
ix0 -= 1;
188-
}
189-
ix1 -= t1;
190-
q1 += r;
191-
}
192-
ix0 += ix0 + ((ix1 & sign) >> 31).0 as i32;
193-
ix1 += ix1;
194-
r >>= 1;
195-
}
193+
r = sign;
194+
while r != Wrapping(0) {
195+
t1 = s1 + r;
196+
t = s0;
197+
if t < ix0 || (t == ix0 && t1 <= ix1) {
198+
s1 = t1 + r;
199+
if (t1 & sign) == sign && (s1 & sign) == Wrapping(0) {
200+
s0 += 1;
201+
}
202+
ix0 -= t;
203+
if ix1 < t1 {
204+
ix0 -= 1;
205+
}
206+
ix1 -= t1;
207+
q1 += r;
208+
}
209+
ix0 += ix0 + ((ix1 & sign) >> 31).0 as i32;
210+
ix1 += ix1;
211+
r >>= 1;
212+
}
196 213

197-
/* use floating add to find out rounding direction */
198-
if (ix0 as u32 | ix1.0) != 0 {
199-
z = 1.0 - TINY; /* raise inexact flag */
200-
if z >= 1.0 {
201-
z = 1.0 + TINY;
202-
if q1.0 == 0xffffffff {
203-
q1 = Wrapping(0);
204-
q += 1;
205-
} else if z > 1.0 {
206-
if q1.0 == 0xfffffffe {
207-
q += 1;
208-
}
209-
q1 += Wrapping(2);
210-
} else {
211-
q1 += q1 & Wrapping(1);
212-
}
213-
}
214-
}
215-
ix0 = (q >> 1) + 0x3fe00000;
216-
ix1 = q1 >> 1;
217-
if (q & 1) == 1 {
218-
ix1 |= sign;
214+
/* use floating add to find out rounding direction */
215+
if (ix0 as u32 | ix1.0) != 0 {
216+
z = 1.0 - TINY; /* raise inexact flag */
217+
if z >= 1.0 {
218+
z = 1.0 + TINY;
219+
if q1.0 == 0xffffffff {
220+
q1 = Wrapping(0);
221+
q += 1;
222+
} else if z > 1.0 {
223+
if q1.0 == 0xfffffffe {
224+
q += 1;
225+
}
226+
q1 += Wrapping(2);
227+
} else {
228+
q1 += q1 & Wrapping(1);
229+
}
230+
}
231+
}
232+
ix0 = (q >> 1) + 0x3fe00000;
233+
ix1 = q1 >> 1;
234+
if (q & 1) == 1 {
235+
ix1 |= sign;
236+
}
237+
ix0 += m << 20;
238+
f64::from_bits((ix0 as u64) << 32 | ix1.0 as u64)
219 239
}
220-
ix0 += m << 20;
221-
f64::from_bits((ix0 as u64) << 32 | ix1.0 as u64)
222 240
}

0 commit comments

Comments
 (0)