Skip to content

Commit f045650

Browse files
committed
Rewrite BigDecimal#sqrt in ruby with improved Newton's method
1 parent 073edce commit f045650

File tree

4 files changed

+58
-205
lines changed

4 files changed

+58
-205
lines changed

ext/bigdecimal/bigdecimal.c

Lines changed: 0 additions & 202 deletions
Original file line numberDiff line numberDiff line change
@@ -2355,31 +2355,6 @@ BigDecimal_abs(VALUE self)
23552355
return CheckGetValue(c);
23562356
}
23572357

2358-
/* call-seq:
2359-
* sqrt(n)
2360-
*
2361-
* Returns the square root of the value.
2362-
*
2363-
* Result has at least n significant digits.
2364-
*/
2365-
static VALUE
2366-
BigDecimal_sqrt(VALUE self, VALUE nFig)
2367-
{
2368-
ENTER(5);
2369-
BDVALUE c, a;
2370-
size_t mx, n;
2371-
2372-
GUARD_OBJ(a, GetBDValueMust(self));
2373-
mx = a.real->Prec * (VpBaseFig() + 1);
2374-
2375-
n = check_int_precision(nFig);
2376-
n += VpDblFig() + VpBaseFig();
2377-
if (mx <= n) mx = n;
2378-
GUARD_OBJ(c, NewZeroWrapLimited(1, mx));
2379-
VpSqrt(c.real, a.real);
2380-
return CheckGetValue(c);
2381-
}
2382-
23832358
/* Return the integer part of the number, as a BigDecimal.
23842359
*/
23852360
static VALUE
@@ -4577,7 +4552,6 @@ Init_bigdecimal(void)
45774552
rb_define_method(rb_cBigDecimal, "dup", BigDecimal_clone, 0);
45784553
rb_define_method(rb_cBigDecimal, "to_f", BigDecimal_to_f, 0);
45794554
rb_define_method(rb_cBigDecimal, "abs", BigDecimal_abs, 0);
4580-
rb_define_method(rb_cBigDecimal, "sqrt", BigDecimal_sqrt, 1);
45814555
rb_define_method(rb_cBigDecimal, "fix", BigDecimal_fix, 0);
45824556
rb_define_method(rb_cBigDecimal, "round", BigDecimal_round, -1);
45834557
rb_define_method(rb_cBigDecimal, "frac", BigDecimal_frac, 0);
@@ -4655,9 +4629,6 @@ static int gfDebug = 1; /* Debug switch */
46554629
#endif /* BIGDECIMAL_DEBUG */
46564630

46574631
static Real *VpConstOne; /* constant 1.0 */
4658-
static Real *VpConstPt5; /* constant 0.5 */
4659-
#define maxnr 100UL /* Maximum iterations for calculating sqrt. */
4660-
/* used in VpSqrt() */
46614632

46624633
enum op_sw {
46634634
OP_SW_ADD = 1, /* + */
@@ -5063,11 +5034,6 @@ VpInit(DECDIG BaseVal)
50635034
/* Const 1.0 */
50645035
VpConstOne = NewOneNolimit(1, 1);
50655036

5066-
/* Const 0.5 */
5067-
VpConstPt5 = NewOneNolimit(1, 1);
5068-
VpConstPt5->exponent = 0;
5069-
VpConstPt5->frac[0] = 5*BASE1;
5070-
50715037
#ifdef BIGDECIMAL_DEBUG
50725038
gnAlloc = 0;
50735039
#endif /* BIGDECIMAL_DEBUG */
@@ -6892,174 +6858,6 @@ VpVtoD(double *d, SIGNED_VALUE *e, Real *m)
68926858
return f;
68936859
}
68946860

6895-
/*
6896-
* m <- d
6897-
*/
6898-
VP_EXPORT void
6899-
VpDtoV(Real *m, double d)
6900-
{
6901-
size_t ind_m, mm;
6902-
SIGNED_VALUE ne;
6903-
DECDIG i;
6904-
double val, val2;
6905-
6906-
if (isnan(d)) {
6907-
VpSetNaN(m);
6908-
goto Exit;
6909-
}
6910-
if (isinf(d)) {
6911-
if (d > 0.0) VpSetPosInf(m);
6912-
else VpSetNegInf(m);
6913-
goto Exit;
6914-
}
6915-
6916-
if (d == 0.0) {
6917-
VpSetZero(m, 1);
6918-
goto Exit;
6919-
}
6920-
val = (d > 0.) ? d : -d;
6921-
ne = 0;
6922-
if (val >= 1.0) {
6923-
while (val >= 1.0) {
6924-
val /= (double)BASE;
6925-
++ne;
6926-
}
6927-
}
6928-
else {
6929-
val2 = 1.0 / (double)BASE;
6930-
while (val < val2) {
6931-
val *= (double)BASE;
6932-
--ne;
6933-
}
6934-
}
6935-
/* Now val = 0.xxxxx*BASE**ne */
6936-
6937-
mm = m->MaxPrec;
6938-
memset(m->frac, 0, mm * sizeof(DECDIG));
6939-
for (ind_m = 0; val > 0.0 && ind_m < mm; ind_m++) {
6940-
val *= (double)BASE;
6941-
i = (DECDIG)val;
6942-
val -= (double)i;
6943-
m->frac[ind_m] = i;
6944-
}
6945-
if (ind_m >= mm) ind_m = mm - 1;
6946-
VpSetSign(m, (d > 0.0) ? 1 : -1);
6947-
m->Prec = ind_m + 1;
6948-
m->exponent = ne;
6949-
6950-
VpInternalRound(m, 0, (m->Prec > 0) ? m->frac[m->Prec-1] : 0,
6951-
(DECDIG)(val*(double)BASE));
6952-
6953-
Exit:
6954-
return;
6955-
}
6956-
6957-
/*
6958-
* y = SQRT(x), y*y - x =>0
6959-
*/
6960-
VP_EXPORT int
6961-
VpSqrt(Real *y, Real *x)
6962-
{
6963-
Real *f = NULL;
6964-
Real *r = NULL;
6965-
size_t y_prec;
6966-
SIGNED_VALUE n, e;
6967-
ssize_t nr;
6968-
double val;
6969-
6970-
/* Zero or +Infinity ? */
6971-
if (VpIsZero(x) || VpIsPosInf(x)) {
6972-
VpAsgn(y,x,1);
6973-
goto Exit;
6974-
}
6975-
6976-
/* Negative ? */
6977-
if (BIGDECIMAL_NEGATIVE_P(x)) {
6978-
VpSetNaN(y);
6979-
return VpException(VP_EXCEPTION_OP, "sqrt of negative value", 0);
6980-
}
6981-
6982-
/* NaN ? */
6983-
if (VpIsNaN(x)) {
6984-
VpSetNaN(y);
6985-
return VpException(VP_EXCEPTION_OP, "sqrt of 'NaN'(Not a Number)", 0);
6986-
}
6987-
6988-
/* One ? */
6989-
if (VpIsOne(x)) {
6990-
VpSetOne(y);
6991-
goto Exit;
6992-
}
6993-
6994-
n = (SIGNED_VALUE)y->MaxPrec;
6995-
if (x->MaxPrec > (size_t)n) n = (ssize_t)x->MaxPrec;
6996-
6997-
/* allocate temporally variables */
6998-
/* TODO: reconsider MaxPrec of f and r */
6999-
f = NewOneNolimit(1, y->MaxPrec * (BASE_FIG + 2));
7000-
r = NewOneNolimit(1, (n + n) * (BASE_FIG + 2));
7001-
7002-
nr = 0;
7003-
y_prec = y->MaxPrec;
7004-
7005-
VpVtoD(&val, &e, x); /* val <- x */
7006-
e /= (SIGNED_VALUE)BASE_FIG;
7007-
n = e / 2;
7008-
if (e - n * 2 != 0) {
7009-
val /= BASE;
7010-
n = (e + 1) / 2;
7011-
}
7012-
VpDtoV(y, sqrt(val)); /* y <- sqrt(val) */
7013-
y->exponent += n;
7014-
n = (SIGNED_VALUE)roomof(BIGDECIMAL_DOUBLE_FIGURES, BASE_FIG);
7015-
y->MaxPrec = Min((size_t)n , y_prec);
7016-
f->MaxPrec = y->MaxPrec + 1;
7017-
n = (SIGNED_VALUE)(y_prec * BASE_FIG);
7018-
if (n > (SIGNED_VALUE)maxnr) n = (SIGNED_VALUE)maxnr;
7019-
7020-
/*
7021-
* Perform: y_{n+1} = (y_n - x/y_n) / 2
7022-
*/
7023-
do {
7024-
y->MaxPrec *= 2;
7025-
if (y->MaxPrec > y_prec) y->MaxPrec = y_prec;
7026-
f->MaxPrec = y->MaxPrec;
7027-
VpDivd(f, r, x, y); /* f = x/y */
7028-
VpAddSub(r, f, y, -1); /* r = f - y */
7029-
VpMult(f, VpConstPt5, r); /* f = 0.5*r */
7030-
if (y_prec == y->MaxPrec && VpIsZero(f))
7031-
goto converge;
7032-
VpAddSub(r, f, y, 1); /* r = y + f */
7033-
VpAsgn(y, r, 1); /* y = r */
7034-
} while (++nr < n);
7035-
7036-
#ifdef BIGDECIMAL_DEBUG
7037-
if (gfDebug) {
7038-
printf("ERROR(VpSqrt): did not converge within %ld iterations.\n", nr);
7039-
}
7040-
#endif /* BIGDECIMAL_DEBUG */
7041-
y->MaxPrec = y_prec;
7042-
7043-
converge:
7044-
VpChangeSign(y, 1);
7045-
#ifdef BIGDECIMAL_DEBUG
7046-
if (gfDebug) {
7047-
VpMult(r, y, y);
7048-
VpAddSub(f, x, r, -1);
7049-
printf("VpSqrt: iterations = %"PRIdSIZE"\n", nr);
7050-
VPrint(stdout, " y =% \n", y);
7051-
VPrint(stdout, " x =% \n", x);
7052-
VPrint(stdout, " x-y*y = % \n", f);
7053-
}
7054-
#endif /* BIGDECIMAL_DEBUG */
7055-
y->MaxPrec = y_prec;
7056-
7057-
Exit:
7058-
rbd_free_struct(f);
7059-
rbd_free_struct(r);
7060-
return 1;
7061-
}
7062-
70636861
/*
70646862
* Round relatively from the decimal point.
70656863
* f: rounding mode

ext/bigdecimal/bigdecimal.h

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,6 @@ typedef struct {
195195
*/
196196

197197
#define VpBaseFig() BIGDECIMAL_COMPONENT_FIGURES
198-
#define VpDblFig() BIGDECIMAL_DOUBLE_FIGURES
199198

200199
/* Zero,Inf,NaN (isinf(),isnan() used to check) */
201200
VP_EXPORT double VpGetDoubleNaN(void);
@@ -229,8 +228,6 @@ VP_EXPORT void VpToString(Real *a, char *buf, size_t bufsize, size_t fFmt, int f
229228
VP_EXPORT void VpToFString(Real *a, char *buf, size_t bufsize, size_t fFmt, int fPlus);
230229
VP_EXPORT int VpCtoV(Real *a, const char *int_chr, size_t ni, const char *frac, size_t nf, const char *exp_chr, size_t ne);
231230
VP_EXPORT int VpVtoD(double *d, SIGNED_VALUE *e, Real *m);
232-
VP_EXPORT void VpDtoV(Real *m,double d);
233-
VP_EXPORT int VpSqrt(Real *y,Real *x);
234231
VP_EXPORT int VpActiveRound(Real *y, Real *x, unsigned short f, ssize_t il);
235232
VP_EXPORT int VpMidRound(Real *y, unsigned short f, ssize_t nf);
236233
VP_EXPORT int VpLeftRound(Real *y, unsigned short f, ssize_t nf);

lib/bigdecimal.rb

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,37 @@
33
else
44
require 'bigdecimal.so'
55
end
6+
7+
class BigDecimal
8+
9+
# Returns the square root of the value.
10+
#
11+
# Result has at least prec significant digits.
12+
#
13+
def sqrt(prec)
14+
if infinite? == 1
15+
exception_mode = BigDecimal.mode(BigDecimal::EXCEPTION_ALL)
16+
raise FloatDomainError, "Computation results in 'Infinity'" if exception_mode.anybits?(BigDecimal::EXCEPTION_INFINITY)
17+
return INFINITY
18+
end
19+
raise ArgumentError, 'negative precision' if prec < 0
20+
raise FloatDomainError, 'sqrt of negative value' if self < 0
21+
raise FloatDomainError, "sqrt of 'NaN'(Not a Number)" if nan?
22+
return self if zero?
23+
24+
# BigDecimal#sqrt calculates at least n_significant_digits precision.
25+
# This feature maybe problematic for some cases.
26+
n_digits = n_significant_digits
27+
prec = [prec, n_digits].max
28+
29+
ex = exponent / 2
30+
x = self * BigDecimal("1e#{-ex * 2}")
31+
y = BigDecimal(Math.sqrt(x.to_f))
32+
precs = [prec + BigDecimal.double_fig]
33+
precs << 2 + precs.last / 2 while precs.last > BigDecimal.double_fig
34+
precs.reverse_each do |p|
35+
y = (y + x.div(y, p)).div(2, p)
36+
end
37+
y * BigDecimal("1e#{ex}")
38+
end
39+
end

test/bigdecimal/test_bigdecimal.rb

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1359,6 +1359,16 @@ def test_sqrt_bigdecimal
13591359
assert_equal(0, BigDecimal("-0").sqrt(1))
13601360
assert_equal(1, BigDecimal("1").sqrt(1))
13611361
assert_positive_infinite(BigDecimal("Infinity").sqrt(1))
1362+
1363+
# Out of float range
1364+
assert_equal(BigDecimal('12e1024'), BigDecimal('144e2048').sqrt(10))
1365+
assert_equal(BigDecimal('12e-1024'), BigDecimal('144e-2048').sqrt(10))
1366+
1367+
sqrt2_300 = BigDecimal(2).sqrt(300)
1368+
(250..270).each do |prec|
1369+
sqrt_prec = prec + BigDecimal.double_fig - 1
1370+
assert_in_delta(sqrt2_300, BigDecimal(2).sqrt(prec), BigDecimal("1e#{-sqrt_prec}"))
1371+
end
13621372
end
13631373

13641374
def test_sqrt_5266
@@ -1375,6 +1385,20 @@ def test_sqrt_5266
13751385
x.sqrt(109).to_s(109).split(' ')[0])
13761386
end
13771387

1388+
def test_sqrt_minimum_precision
1389+
x = BigDecimal((2**200).to_s)
1390+
assert_equal(2**100, x.sqrt(1))
1391+
1392+
x = BigDecimal('1' * 60 + '.' + '1' * 40)
1393+
assert_in_delta(BigDecimal('3' * 30 + '.' + '3' * 70), x.sqrt(1), BigDecimal('1e-70'))
1394+
1395+
x = BigDecimal('1' * 40 + '.' + '1' * 60)
1396+
assert_in_delta(BigDecimal('3' * 20 + '.' + '3' * 80), x.sqrt(1), BigDecimal('1e-80'))
1397+
1398+
x = BigDecimal('0.' + '0' * 50 + '1' * 100)
1399+
assert_in_delta(BigDecimal('0.' + '0' * 25 + '3' * 100), x.sqrt(1), BigDecimal('1e-125'))
1400+
end
1401+
13781402
def test_fix
13791403
x = BigDecimal("1.1")
13801404
assert_equal(1, x.fix)

0 commit comments

Comments
 (0)