Skip to content

Commit cac4ecc

Browse files
committed
Rewrite BigDecimal#sqrt in ruby with improved Newton's method
1 parent 0d854c4 commit cac4ecc

File tree

4 files changed

+56
-211
lines changed

4 files changed

+56
-211
lines changed

ext/bigdecimal/bigdecimal.c

Lines changed: 0 additions & 202 deletions
Original file line numberDiff line numberDiff line change
@@ -2354,31 +2354,6 @@ BigDecimal_abs(VALUE self)
23542354
return CheckGetValue(c);
23552355
}
23562356

2357-
/* call-seq:
2358-
* sqrt(n)
2359-
*
2360-
* Returns the square root of the value.
2361-
*
2362-
* Result has at least n significant digits.
2363-
*/
2364-
static VALUE
2365-
BigDecimal_sqrt(VALUE self, VALUE nFig)
2366-
{
2367-
ENTER(5);
2368-
BDVALUE c, a;
2369-
size_t mx, n;
2370-
2371-
GUARD_OBJ(a, GetBDValueMust(self));
2372-
mx = a.real->Prec * (VpBaseFig() + 1);
2373-
2374-
n = check_int_precision(nFig);
2375-
n += VpDblFig() + VpBaseFig();
2376-
if (mx <= n) mx = n;
2377-
GUARD_OBJ(c, NewZeroWrapLimited(1, mx));
2378-
VpSqrt(c.real, a.real);
2379-
return CheckGetValue(c);
2380-
}
2381-
23822357
/* Return the integer part of the number, as a BigDecimal.
23832358
*/
23842359
static VALUE
@@ -3916,7 +3891,6 @@ Init_bigdecimal(void)
39163891
rb_define_method(rb_cBigDecimal, "dup", BigDecimal_clone, 0);
39173892
rb_define_method(rb_cBigDecimal, "to_f", BigDecimal_to_f, 0);
39183893
rb_define_method(rb_cBigDecimal, "abs", BigDecimal_abs, 0);
3919-
rb_define_method(rb_cBigDecimal, "sqrt", BigDecimal_sqrt, 1);
39203894
rb_define_method(rb_cBigDecimal, "fix", BigDecimal_fix, 0);
39213895
rb_define_method(rb_cBigDecimal, "round", BigDecimal_round, -1);
39223896
rb_define_method(rb_cBigDecimal, "frac", BigDecimal_frac, 0);
@@ -3988,9 +3962,6 @@ static int gfDebug = 1; /* Debug switch */
39883962
#endif /* BIGDECIMAL_DEBUG */
39893963

39903964
static Real *VpConstOne; /* constant 1.0 */
3991-
static Real *VpConstPt5; /* constant 0.5 */
3992-
#define maxnr 100UL /* Maximum iterations for calculating sqrt. */
3993-
/* used in VpSqrt() */
39943965

39953966
enum op_sw {
39963967
OP_SW_ADD = 1, /* + */
@@ -4396,11 +4367,6 @@ VpInit(DECDIG BaseVal)
43964367
/* Const 1.0 */
43974368
VpConstOne = NewOneNolimit(1, 1);
43984369

4399-
/* Const 0.5 */
4400-
VpConstPt5 = NewOneNolimit(1, 1);
4401-
VpConstPt5->exponent = 0;
4402-
VpConstPt5->frac[0] = 5*BASE1;
4403-
44044370
#ifdef BIGDECIMAL_DEBUG
44054371
gnAlloc = 0;
44064372
#endif /* BIGDECIMAL_DEBUG */
@@ -6225,174 +6191,6 @@ VpVtoD(double *d, SIGNED_VALUE *e, Real *m)
62256191
return f;
62266192
}
62276193

6228-
/*
6229-
* m <- d
6230-
*/
6231-
VP_EXPORT void
6232-
VpDtoV(Real *m, double d)
6233-
{
6234-
size_t ind_m, mm;
6235-
SIGNED_VALUE ne;
6236-
DECDIG i;
6237-
double val, val2;
6238-
6239-
if (isnan(d)) {
6240-
VpSetNaN(m);
6241-
goto Exit;
6242-
}
6243-
if (isinf(d)) {
6244-
if (d > 0.0) VpSetPosInf(m);
6245-
else VpSetNegInf(m);
6246-
goto Exit;
6247-
}
6248-
6249-
if (d == 0.0) {
6250-
VpSetZero(m, 1);
6251-
goto Exit;
6252-
}
6253-
val = (d > 0.) ? d : -d;
6254-
ne = 0;
6255-
if (val >= 1.0) {
6256-
while (val >= 1.0) {
6257-
val /= (double)BASE;
6258-
++ne;
6259-
}
6260-
}
6261-
else {
6262-
val2 = 1.0 / (double)BASE;
6263-
while (val < val2) {
6264-
val *= (double)BASE;
6265-
--ne;
6266-
}
6267-
}
6268-
/* Now val = 0.xxxxx*BASE**ne */
6269-
6270-
mm = m->MaxPrec;
6271-
memset(m->frac, 0, mm * sizeof(DECDIG));
6272-
for (ind_m = 0; val > 0.0 && ind_m < mm; ind_m++) {
6273-
val *= (double)BASE;
6274-
i = (DECDIG)val;
6275-
val -= (double)i;
6276-
m->frac[ind_m] = i;
6277-
}
6278-
if (ind_m >= mm) ind_m = mm - 1;
6279-
VpSetSign(m, (d > 0.0) ? 1 : -1);
6280-
m->Prec = ind_m + 1;
6281-
m->exponent = ne;
6282-
6283-
VpInternalRound(m, 0, (m->Prec > 0) ? m->frac[m->Prec-1] : 0,
6284-
(DECDIG)(val*(double)BASE));
6285-
6286-
Exit:
6287-
return;
6288-
}
6289-
6290-
/*
6291-
* y = SQRT(x), y*y - x =>0
6292-
*/
6293-
VP_EXPORT int
6294-
VpSqrt(Real *y, Real *x)
6295-
{
6296-
Real *f = NULL;
6297-
Real *r = NULL;
6298-
size_t y_prec;
6299-
SIGNED_VALUE n, e;
6300-
ssize_t nr;
6301-
double val;
6302-
6303-
/* Zero or +Infinity ? */
6304-
if (VpIsZero(x) || VpIsPosInf(x)) {
6305-
VpAsgn(y,x,1);
6306-
goto Exit;
6307-
}
6308-
6309-
/* Negative ? */
6310-
if (BIGDECIMAL_NEGATIVE_P(x)) {
6311-
VpSetNaN(y);
6312-
return VpException(VP_EXCEPTION_OP, "sqrt of negative value", 0);
6313-
}
6314-
6315-
/* NaN ? */
6316-
if (VpIsNaN(x)) {
6317-
VpSetNaN(y);
6318-
return VpException(VP_EXCEPTION_OP, "sqrt of 'NaN'(Not a Number)", 0);
6319-
}
6320-
6321-
/* One ? */
6322-
if (VpIsOne(x)) {
6323-
VpSetOne(y);
6324-
goto Exit;
6325-
}
6326-
6327-
n = (SIGNED_VALUE)y->MaxPrec;
6328-
if (x->MaxPrec > (size_t)n) n = (ssize_t)x->MaxPrec;
6329-
6330-
/* allocate temporally variables */
6331-
/* TODO: reconsider MaxPrec of f and r */
6332-
f = NewOneNolimit(1, y->MaxPrec * (BASE_FIG + 2));
6333-
r = NewOneNolimit(1, (n + n) * (BASE_FIG + 2));
6334-
6335-
nr = 0;
6336-
y_prec = y->MaxPrec;
6337-
6338-
VpVtoD(&val, &e, x); /* val <- x */
6339-
e /= (SIGNED_VALUE)BASE_FIG;
6340-
n = e / 2;
6341-
if (e - n * 2 != 0) {
6342-
val /= BASE;
6343-
n = (e + 1) / 2;
6344-
}
6345-
VpDtoV(y, sqrt(val)); /* y <- sqrt(val) */
6346-
y->exponent += n;
6347-
n = (SIGNED_VALUE)roomof(BIGDECIMAL_DOUBLE_FIGURES, BASE_FIG);
6348-
y->MaxPrec = Min((size_t)n , y_prec);
6349-
f->MaxPrec = y->MaxPrec + 1;
6350-
n = (SIGNED_VALUE)(y_prec * BASE_FIG);
6351-
if (n > (SIGNED_VALUE)maxnr) n = (SIGNED_VALUE)maxnr;
6352-
6353-
/*
6354-
* Perform: y_{n+1} = (y_n - x/y_n) / 2
6355-
*/
6356-
do {
6357-
y->MaxPrec *= 2;
6358-
if (y->MaxPrec > y_prec) y->MaxPrec = y_prec;
6359-
f->MaxPrec = y->MaxPrec;
6360-
VpDivd(f, r, x, y); /* f = x/y */
6361-
VpAddSub(r, f, y, -1); /* r = f - y */
6362-
VpMult(f, VpConstPt5, r); /* f = 0.5*r */
6363-
if (y_prec == y->MaxPrec && VpIsZero(f))
6364-
goto converge;
6365-
VpAddSub(r, f, y, 1); /* r = y + f */
6366-
VpAsgn(y, r, 1); /* y = r */
6367-
} while (++nr < n);
6368-
6369-
#ifdef BIGDECIMAL_DEBUG
6370-
if (gfDebug) {
6371-
printf("ERROR(VpSqrt): did not converge within %ld iterations.\n", nr);
6372-
}
6373-
#endif /* BIGDECIMAL_DEBUG */
6374-
y->MaxPrec = y_prec;
6375-
6376-
converge:
6377-
VpChangeSign(y, 1);
6378-
#ifdef BIGDECIMAL_DEBUG
6379-
if (gfDebug) {
6380-
VpMult(r, y, y);
6381-
VpAddSub(f, x, r, -1);
6382-
printf("VpSqrt: iterations = %"PRIdSIZE"\n", nr);
6383-
VPrint(stdout, " y =% \n", y);
6384-
VPrint(stdout, " x =% \n", x);
6385-
VPrint(stdout, " x-y*y = % \n", f);
6386-
}
6387-
#endif /* BIGDECIMAL_DEBUG */
6388-
y->MaxPrec = y_prec;
6389-
6390-
Exit:
6391-
rbd_free_struct(f);
6392-
rbd_free_struct(r);
6393-
return 1;
6394-
}
6395-
63966194
/*
63976195
* Round relatively from the decimal point.
63986196
* f: rounding mode

ext/bigdecimal/bigdecimal.h

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,6 @@ typedef struct {
195195
*/
196196

197197
#define VpBaseFig() BIGDECIMAL_COMPONENT_FIGURES
198-
#define VpDblFig() BIGDECIMAL_DOUBLE_FIGURES
199198

200199
/* Zero,Inf,NaN (isinf(),isnan() used to check) */
201200
VP_EXPORT double VpGetDoubleNaN(void);
@@ -229,8 +228,6 @@ VP_EXPORT void VpToString(Real *a, char *buf, size_t bufsize, size_t fFmt, int f
229228
VP_EXPORT void VpToFString(Real *a, char *buf, size_t bufsize, size_t fFmt, int fPlus);
230229
VP_EXPORT int VpCtoV(Real *a, const char *int_chr, size_t ni, const char *frac, size_t nf, const char *exp_chr, size_t ne);
231230
VP_EXPORT int VpVtoD(double *d, SIGNED_VALUE *e, Real *m);
232-
VP_EXPORT void VpDtoV(Real *m,double d);
233-
VP_EXPORT int VpSqrt(Real *y,Real *x);
234231
VP_EXPORT int VpActiveRound(Real *y, Real *x, unsigned short f, ssize_t il);
235232
VP_EXPORT int VpMidRound(Real *y, unsigned short f, ssize_t nf);
236233
VP_EXPORT int VpLeftRound(Real *y, unsigned short f, ssize_t nf);

lib/bigdecimal.rb

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,37 @@ def power(y, prec = nil)
127127
end
128128
ans.mult(1, prec)
129129
end
130+
131+
# Returns the square root of the value.
132+
#
133+
# Result has at least prec significant digits.
134+
#
135+
def sqrt(prec)
136+
if infinite? == 1
137+
exception_mode = BigDecimal.mode(BigDecimal::EXCEPTION_ALL)
138+
raise FloatDomainError, "Computation results in 'Infinity'" if exception_mode.anybits?(BigDecimal::EXCEPTION_INFINITY)
139+
return INFINITY
140+
end
141+
raise ArgumentError, 'negative precision' if prec < 0
142+
raise FloatDomainError, 'sqrt of negative value' if self < 0
143+
raise FloatDomainError, "sqrt of 'NaN'(Not a Number)" if nan?
144+
return self if zero?
145+
146+
# BigDecimal#sqrt calculates at least n_significant_digits precision.
147+
# This feature maybe problematic for some cases.
148+
n_digits = n_significant_digits
149+
prec = [prec, n_digits].max
150+
151+
ex = exponent / 2
152+
x = self * BigDecimal("1e#{-ex * 2}")
153+
y = BigDecimal(Math.sqrt(x.to_f))
154+
precs = [prec + BigDecimal.double_fig]
155+
precs << 2 + precs.last / 2 while precs.last > BigDecimal.double_fig
156+
precs.reverse_each do |p|
157+
y = (y + x.div(y, p)).div(2, p)
158+
end
159+
y * BigDecimal("1e#{ex}")
160+
end
130161
end
131162

132163
# Core BigMath methods for BigDecimal (log, exp) are defined here.
@@ -186,12 +217,7 @@ def self.log(x, prec)
186217
prec += BigDecimal.double_fig
187218

188219
# log(x) = log(sqrt(sqrt(sqrt(sqrt(x))))) * 2**sqrt_steps
189-
sqrt_steps = [2 * Integer.sqrt(prec) + 3 * x_minus_one_exponent, 0].max
190-
191-
# Reduce sqrt_step until sqrt gets fast
192-
# https://github.com/ruby/bigdecimal/pull/323
193-
# https://github.com/ruby/bigdecimal/pull/343
194-
sqrt_steps /= 10
220+
sqrt_steps = [Integer.sqrt(prec) + 3 * x_minus_one_exponent, 0].max
195221

196222
lg2 = 0.3010299956639812
197223
prec2 = prec + [-x_minus_one_exponent, 0].max + (sqrt_steps * lg2).ceil

test/bigdecimal/test_bigdecimal.rb

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1359,6 +1359,16 @@ def test_sqrt_bigdecimal
13591359
assert_equal(0, BigDecimal("-0").sqrt(1))
13601360
assert_equal(1, BigDecimal("1").sqrt(1))
13611361
assert_positive_infinite(BigDecimal("Infinity").sqrt(1))
1362+
1363+
# Out of float range
1364+
assert_equal(BigDecimal('12e1024'), BigDecimal('144e2048').sqrt(10))
1365+
assert_equal(BigDecimal('12e-1024'), BigDecimal('144e-2048').sqrt(10))
1366+
1367+
sqrt2_300 = BigDecimal(2).sqrt(300)
1368+
(250..270).each do |prec|
1369+
sqrt_prec = prec + BigDecimal.double_fig - 1
1370+
assert_in_delta(sqrt2_300, BigDecimal(2).sqrt(prec), BigDecimal("1e#{-sqrt_prec}"))
1371+
end
13621372
end
13631373

13641374
def test_sqrt_5266
@@ -1375,6 +1385,20 @@ def test_sqrt_5266
13751385
x.sqrt(109).to_s(109).split(' ')[0])
13761386
end
13771387

1388+
def test_sqrt_minimum_precision
1389+
x = BigDecimal((2**200).to_s)
1390+
assert_equal(2**100, x.sqrt(1))
1391+
1392+
x = BigDecimal('1' * 60 + '.' + '1' * 40)
1393+
assert_in_delta(BigDecimal('3' * 30 + '.' + '3' * 70), x.sqrt(1), BigDecimal('1e-70'))
1394+
1395+
x = BigDecimal('1' * 40 + '.' + '1' * 60)
1396+
assert_in_delta(BigDecimal('3' * 20 + '.' + '3' * 80), x.sqrt(1), BigDecimal('1e-80'))
1397+
1398+
x = BigDecimal('0.' + '0' * 50 + '1' * 100)
1399+
assert_in_delta(BigDecimal('0.' + '0' * 25 + '3' * 100), x.sqrt(1), BigDecimal('1e-125'))
1400+
end
1401+
13781402
def test_fix
13791403
x = BigDecimal("1.1")
13801404
assert_equal(1, x.fix)

0 commit comments

Comments
 (0)