diff --git a/bigdecimal.gemspec b/bigdecimal.gemspec index b6ef8fd9..6b20ac08 100644 --- a/bigdecimal.gemspec +++ b/bigdecimal.gemspec @@ -43,9 +43,11 @@ Gem::Specification.new do |s| ext/bigdecimal/bigdecimal.c ext/bigdecimal/bigdecimal.h ext/bigdecimal/bits.h + ext/bigdecimal/div.h ext/bigdecimal/feature.h ext/bigdecimal/missing.c ext/bigdecimal/missing.h + ext/bigdecimal/ntt.h ext/bigdecimal/missing/dtoa.c ext/bigdecimal/static_assert.h ] diff --git a/ext/bigdecimal/bigdecimal.c b/ext/bigdecimal/bigdecimal.c index 88f26f2b..3afe2dc5 100644 --- a/ext/bigdecimal/bigdecimal.c +++ b/ext/bigdecimal/bigdecimal.c @@ -29,10 +29,18 @@ #endif #include "bits.h" +#include "div.h" #include "static_assert.h" #define BIGDECIMAL_VERSION "3.3.0" +#if SIZEOF_DECDIG == 4 +#define USE_NTT_MULTIPLICATION 1 +#include "ntt.h" +#define NTT_MULTIPLICATION_THRESHOLD 100 +#define NEWTON_RAPHSON_DIVISION_THRESHOLD 200 +#endif + /* #define ENABLE_NUMERIC_STRING */ #define SIGNED_VALUE_MAX INTPTR_MAX @@ -75,11 +83,6 @@ static struct { uint8_t mode; } rbd_rounding_modes[RBD_NUM_ROUNDING_MODES]; -typedef struct { - VALUE bigdecimal; - Real *real; -} BDVALUE; - typedef struct { VALUE bigdecimal_or_nil; Real *real_or_null; @@ -207,7 +210,6 @@ rbd_allocate_struct_zero(int sign, size_t const digits) static unsigned short VpGetException(void); static void VpSetException(unsigned short f); static void VpCheckException(Real *p, bool always); -static int AddExponent(Real *a, SIGNED_VALUE n); static VALUE CheckGetValue(BDVALUE v); static void VpInternalRound(Real *c, size_t ixDigit, DECDIG vPrev, DECDIG v); static int VpLimitRound(Real *c, size_t ixDigit); @@ -1112,9 +1114,6 @@ BigDecimal_check_num(Real *p) VpCheckException(p, true); } -static VALUE BigDecimal_fix(VALUE self); -static VALUE BigDecimal_split(VALUE self); - /* Returns the value as an Integer. * * If the BigDecimal is infinity or NaN, raises FloatDomainError. 
@@ -3256,19 +3255,39 @@ BigDecimal_literal(const char *str) #ifdef BIGDECIMAL_USE_VP_TEST_METHODS VALUE -BigDecimal_vpdivd(VALUE self, VALUE r, VALUE cprec) { - BDVALUE a,b,c,d; +BigDecimal_vpdivd_generic(VALUE self, VALUE r, VALUE cprec, void (*vpdivd_func)(Real*, Real*, Real*, Real*)) { + BDVALUE a, b, c, d; size_t cn = NUM2INT(cprec); a = GetBDValueMust(self); b = GetBDValueMust(r); c = NewZeroWrap(1, cn * BASE_FIG); d = NewZeroWrap(1, VPDIVD_REM_PREC(a.real, b.real, c.real) * BASE_FIG); - VpDivd(c.real, d.real, a.real, b.real); + vpdivd_func(c.real, d.real, a.real, b.real); RB_GC_GUARD(a.bigdecimal); RB_GC_GUARD(b.bigdecimal); return rb_assoc_new(c.bigdecimal, d.bigdecimal); } +void +VpDivdNormal(Real *c, Real *r, Real *a, Real *b) { + VpDivd(c, r, a, b); +} + +VALUE +BigDecimal_vpdivd(VALUE self, VALUE r, VALUE cprec) { + return BigDecimal_vpdivd_generic(self, r, cprec, VpDivdNormal); +} + +VALUE +BigDecimal_vpdivd_newton(VALUE self, VALUE r, VALUE cprec) { + return BigDecimal_vpdivd_generic(self, r, cprec, VpDivdNewton); +} + +VALUE +BigDecimal_newton_raphson_inverse(VALUE self, VALUE prec) { + return newton_raphson_inverse(self, NUM2SIZET(prec)); +} + VALUE BigDecimal_vpmult(VALUE self, VALUE v) { BDVALUE a,b,c; @@ -3280,6 +3299,25 @@ BigDecimal_vpmult(VALUE self, VALUE v) { RB_GC_GUARD(b.bigdecimal); return c.bigdecimal; } + +#if SIZEOF_DECDIG == 4 +VALUE +BigDecimal_nttmult(VALUE self, VALUE v) { + BDVALUE a,b,c; + a = GetBDValueMust(self); + b = GetBDValueMust(v); + c = NewZeroWrap(1, VPMULT_RESULT_PREC(a.real, b.real) * BASE_FIG); + ntt_multiply(a.real->Prec, b.real->Prec, a.real->frac, b.real->frac, c.real->frac); + VpSetSign(c.real, a.real->sign * b.real->sign); + c.real->exponent = a.real->exponent + b.real->exponent; + c.real->Prec = a.real->Prec + b.real->Prec; + VpNmlz(c.real); + RB_GC_GUARD(a.bigdecimal); + RB_GC_GUARD(b.bigdecimal); + return c.bigdecimal; +} +#endif + #endif /* BIGDECIMAL_USE_VP_TEST_METHODS */ /* Document-class: BigDecimal @@ 
-3651,7 +3689,12 @@ Init_bigdecimal(void) #ifdef BIGDECIMAL_USE_VP_TEST_METHODS rb_define_method(rb_cBigDecimal, "vpdivd", BigDecimal_vpdivd, 2); + rb_define_method(rb_cBigDecimal, "vpdivd_newton", BigDecimal_vpdivd_newton, 2); + rb_define_method(rb_cBigDecimal, "newton_raphson_inverse", BigDecimal_newton_raphson_inverse, 1); rb_define_method(rb_cBigDecimal, "vpmult", BigDecimal_vpmult, 1); +#ifdef USE_NTT_MULTIPLICATION + rb_define_method(rb_cBigDecimal, "nttmult", BigDecimal_nttmult, 1); +#endif #endif /* BIGDECIMAL_USE_VP_TEST_METHODS */ #define ROUNDING_MODE(i, name, value) \ @@ -4934,6 +4977,15 @@ VpMult(Real *c, Real *a, Real *b) c->exponent = a->exponent; /* set exponent */ VpSetSign(c, VpGetSign(a) * VpGetSign(b)); /* set sign */ if (!AddExponent(c, b->exponent)) return 0; + +#ifdef USE_NTT_MULTIPLICATION + if (b->Prec >= NTT_MULTIPLICATION_THRESHOLD) { + ntt_multiply((uint32_t)a->Prec, (uint32_t)b->Prec, a->frac, b->frac, c->frac); + c->Prec = a->Prec + b->Prec; + goto Cleanup; + } +#endif + carry = 0; nc = ind_c = MxIndAB; memset(c->frac, 0, (nc + 1) * sizeof(DECDIG)); /* Initialize c */ @@ -4980,6 +5032,8 @@ VpMult(Real *c, Real *a, Real *b) } } } + +Cleanup: VpNmlz(c); Exit: @@ -5027,6 +5081,14 @@ VpDivd(Real *c, Real *r, Real *a, Real *b) if (word_a > word_r || word_b + word_c - 2 >= word_r) goto space_error; +#ifdef USE_NTT_MULTIPLICATION + // Newton-Raphson division requires multiplication to be faster than O(n^2) + if (word_c >= NEWTON_RAPHSON_DIVISION_THRESHOLD && word_b >= NEWTON_RAPHSON_DIVISION_THRESHOLD) { + VpDivdNewton(c, r, a, b); + goto Exit; + } +#endif + for (i = 0; i < word_a; ++i) r->frac[i] = a->frac[i]; for (i = word_a; i < word_r; ++i) r->frac[i] = 0; for (i = 0; i < word_c; ++i) c->frac[i] = 0; diff --git a/ext/bigdecimal/bigdecimal.h b/ext/bigdecimal/bigdecimal.h index 82c88a2a..71ddb21f 100644 --- a/ext/bigdecimal/bigdecimal.h +++ b/ext/bigdecimal/bigdecimal.h @@ -188,6 +188,11 @@ typedef struct { DECDIG frac[FLEXIBLE_ARRAY_SIZE]; 
/* Array of fraction part. */ } Real; +typedef struct { + VALUE bigdecimal; + Real *real; +} BDVALUE; + /* * ------------------ * EXPORTables. @@ -232,10 +237,31 @@ VP_EXPORT int VpActiveRound(Real *y, Real *x, unsigned short f, ssize_t il); VP_EXPORT int VpMidRound(Real *y, unsigned short f, ssize_t nf); VP_EXPORT int VpLeftRound(Real *y, unsigned short f, ssize_t nf); VP_EXPORT void VpFrac(Real *y, Real *x); +VP_EXPORT int AddExponent(Real *a, SIGNED_VALUE n); /* VP constants */ VP_EXPORT Real *VpOne(void); +/* + * **** BigDecimal part **** + */ +VP_EXPORT VALUE BigDecimal_lt(VALUE self, VALUE r); +VP_EXPORT VALUE BigDecimal_ge(VALUE self, VALUE r); +VP_EXPORT VALUE BigDecimal_exponent(VALUE self); +VP_EXPORT VALUE BigDecimal_fix(VALUE self); +VP_EXPORT VALUE BigDecimal_frac(VALUE self); +VP_EXPORT VALUE BigDecimal_add(VALUE self, VALUE b); +VP_EXPORT VALUE BigDecimal_sub(VALUE self, VALUE b); +VP_EXPORT VALUE BigDecimal_mult(VALUE self, VALUE b); +VP_EXPORT VALUE BigDecimal_add2(VALUE self, VALUE b, VALUE n); +VP_EXPORT VALUE BigDecimal_sub2(VALUE self, VALUE b, VALUE n); +VP_EXPORT VALUE BigDecimal_mult2(VALUE self, VALUE b, VALUE n); +VP_EXPORT VALUE BigDecimal_split(VALUE self); +VP_EXPORT VALUE BigDecimal_decimal_shift(VALUE self, VALUE v); +VP_EXPORT inline BDVALUE GetBDValueMust(VALUE v); +VP_EXPORT inline BDVALUE rbd_allocate_struct_zero_wrap(int sign, size_t const digits); +#define NewZeroWrap rbd_allocate_struct_zero_wrap + /* * ------------------ * MACRO definitions. diff --git a/ext/bigdecimal/div.h b/ext/bigdecimal/div.h new file mode 100644 index 00000000..e6dd89c9 --- /dev/null +++ b/ext/bigdecimal/div.h @@ -0,0 +1,192 @@ +// Calculate the inverse of x using the Newton-Raphson method. 
+static VALUE +newton_raphson_inverse(VALUE x, size_t prec) { + BDVALUE bdone = NewZeroWrap(1, 1); + VpSetOne(bdone.real); + VALUE one = bdone.bigdecimal; + + // Initial approximation in 2 digits + BDVALUE bdx = GetBDValueMust(x); + BDVALUE inv0 = NewZeroWrap(1, 2 * BIGDECIMAL_COMPONENT_FIGURES); + VpSetOne(inv0.real); + DECDIG_DBL numerator = (DECDIG_DBL)BIGDECIMAL_BASE * 100; + DECDIG_DBL denominator = (DECDIG_DBL)bdx.real->frac[0] * 100 + (DECDIG_DBL)(bdx.real->Prec >= 2 ? bdx.real->frac[1] : 0) * 100 / BIGDECIMAL_BASE; + inv0.real->frac[0] = (DECDIG)(numerator / denominator); + inv0.real->frac[1] = (DECDIG)((numerator % denominator) * (BIGDECIMAL_BASE / 100) / denominator * 100); + inv0.real->Prec = 2; + inv0.real->exponent = 1 - bdx.real->exponent; + VpNmlz(inv0.real); + RB_GC_GUARD(bdx.bigdecimal); + VALUE inv = inv0.bigdecimal; + + int bl = 1; + while (((size_t)1 << bl) < prec) bl++; + + for (int i = bl; i >= 0; i--) { + size_t n = (prec >> i) + 2; + if (n > prec) n = prec; + // Newton-Raphson iteration: inv_next = inv + inv * (1 - x * inv), with the residual (1 - x * inv) rounded to n / 2 digits + VALUE one_minus_x_inv = BigDecimal_sub2( + one, + BigDecimal_mult(BigDecimal_mult2(x, one, SIZET2NUM(n + 1)), inv), + SIZET2NUM(n / 2) + ); + inv = BigDecimal_add2( + inv, + BigDecimal_mult(inv, one_minus_x_inv), + SIZET2NUM(n) + ); + } + return inv; +} + +// Calculates divmod by multiplying approximate reciprocal of y +static void +divmod_by_inv_mul(VALUE x, VALUE y, VALUE inv, VALUE *res_div, VALUE *res_mod) { + VALUE div = BigDecimal_fix(BigDecimal_mult(x, inv)); + VALUE mod = BigDecimal_sub(x, BigDecimal_mult(div, y)); + while (RTEST(BigDecimal_lt(mod, INT2FIX(0)))) { + mod = BigDecimal_add(mod, y); + div = BigDecimal_sub(div, INT2FIX(1)); + } + while (RTEST(BigDecimal_ge(mod, y))) { + mod = BigDecimal_sub(mod, y); + div = BigDecimal_add(div, INT2FIX(1)); + } + *res_div = div; + *res_mod = mod; +} + +static void +slice_copy(DECDIG *dest, Real *src, size_t rshift, size_t length) { + ssize_t start = 
src->exponent - rshift - length; + if (start >= (ssize_t)src->Prec) return; + if (start < 0) { + dest -= start; + length += start; + start = 0; + } + size_t max_length = src->Prec - start; + memcpy(dest, src->frac + start, Min(length, max_length) * sizeof(DECDIG)); +} + +/* Calculates divmod using Newton-Raphson method. + * x and y must be a BigDecimal representing an integer value. + * + * To calculate with low cost, we need to split x into blocks and perform divmod for each block. + * x_digits = remaining_digits(<= y_digits) + block_digits * num_blocks + * + * Example: + * xxx_xxxxx_xxxxx_xxxxx(18 digits) / yyyyy(5 digits) + * remaining_digits = 3, block_digits = 5, num_blocks = 3 + * repeating xxxxx_xxxxxx.divmod(yyyyy) calculation 3 times. + * + * In each divmod step, dividend is at most (y_digits + block_digits) digits and divisor is y_digits digits. + * Reciprocal of y needs block_digits + 1 precision. + */ +static void +divmod_newton(VALUE x, VALUE y, VALUE *div_out, VALUE *mod_out) { + size_t x_digits = NUM2SIZET(BigDecimal_exponent(x)); + size_t y_digits = NUM2SIZET(BigDecimal_exponent(y)); + if (x_digits <= y_digits) x_digits = y_digits + 1; + + size_t n = x_digits / y_digits; + size_t block_figs = (x_digits - y_digits) / n / BIGDECIMAL_COMPONENT_FIGURES + 1; + size_t block_digits = block_figs * BIGDECIMAL_COMPONENT_FIGURES; + size_t num_blocks = (x_digits - y_digits + block_digits - 1) / block_digits; + size_t y_figs = (y_digits - 1) / BIGDECIMAL_COMPONENT_FIGURES + 1; + VALUE yinv = newton_raphson_inverse(y, block_digits + 1); + + BDVALUE divident = NewZeroWrap(1, BIGDECIMAL_COMPONENT_FIGURES * (y_figs + block_figs)); + BDVALUE div_result = NewZeroWrap(1, BIGDECIMAL_COMPONENT_FIGURES * (num_blocks * block_figs + 1)); + BDVALUE bdx = GetBDValueMust(x); + + VALUE mod = BigDecimal_fix(BigDecimal_decimal_shift(x, SSIZET2NUM(-num_blocks * block_digits))); + for (ssize_t i = num_blocks - 1; i >= 0; i--) { + memset(divident.real->frac, 0, (y_figs + block_figs) 
* sizeof(DECDIG)); + + BDVALUE bdmod = GetBDValueMust(mod); + slice_copy(divident.real->frac, bdmod.real, 0, y_figs); + slice_copy(divident.real->frac + y_figs, bdx.real, i * block_figs, block_figs); + RB_GC_GUARD(bdmod.bigdecimal); + + VpSetSign(divident.real, 1); + divident.real->exponent = y_figs + block_figs; + divident.real->Prec = y_figs + block_figs; + VpNmlz(divident.real); + + VALUE div; + divmod_by_inv_mul(divident.bigdecimal, y, yinv, &div, &mod); + BDVALUE bddiv = GetBDValueMust(div); + slice_copy(div_result.real->frac + (num_blocks - i - 1) * block_figs, bddiv.real, 0, block_figs + 1); + RB_GC_GUARD(bddiv.bigdecimal); + } + VpSetSign(div_result.real, 1); + div_result.real->exponent = num_blocks * block_figs + 1; + div_result.real->Prec = num_blocks * block_figs + 1; + VpNmlz(div_result.real); + RB_GC_GUARD(bdx.bigdecimal); + RB_GC_GUARD(divident.bigdecimal); + RB_GC_GUARD(div_result.bigdecimal); + *div_out = div_result.bigdecimal; + *mod_out = mod; +} + +static VALUE +VpDivdNewtonInner(VALUE args_ptr) +{ + Real **args = (Real**)args_ptr; + Real *c = args[0], *r = args[1], *a = args[2], *b = args[3]; + BDVALUE a2, b2, c2, r2; + VALUE div, mod, a2_frac = Qnil; + size_t div_prec = c->MaxPrec - 1; + size_t base_prec = b->Prec; + + a2 = NewZeroWrap(1, a->Prec * BIGDECIMAL_COMPONENT_FIGURES); + b2 = NewZeroWrap(1, b->Prec * BIGDECIMAL_COMPONENT_FIGURES); + VpAsgn(a2.real, a, 1); + VpAsgn(b2.real, b, 1); + VpSetSign(a2.real, 1); + VpSetSign(b2.real, 1); + a2.real->exponent = base_prec + div_prec; + b2.real->exponent = base_prec; + + if ((ssize_t)a2.real->Prec > a2.real->exponent) { + a2_frac = BigDecimal_frac(a2.bigdecimal); + VpMidRound(a2.real, VP_ROUND_DOWN, 0); + } + divmod_newton(a2.bigdecimal, b2.bigdecimal, &div, &mod); + if (a2_frac != Qnil) mod = BigDecimal_add(mod, a2_frac); + + c2 = GetBDValueMust(div); + r2 = GetBDValueMust(mod); + VpAsgn(c, c2.real, VpGetSign(a) * VpGetSign(b)); + VpAsgn(r, r2.real, VpGetSign(a)); + AddExponent(c, a->exponent); + 
AddExponent(c, -b->exponent); + AddExponent(c, -div_prec); + AddExponent(r, a->exponent); + AddExponent(r, -base_prec - div_prec); + RB_GC_GUARD(a2.bigdecimal); + RB_GC_GUARD(b2.bigdecimal); /* keep the divisor copy visible to the GC as well */ + RB_GC_GUARD(c2.bigdecimal); + RB_GC_GUARD(r2.bigdecimal); + return Qnil; +} + +static VALUE +ensure_restore_prec_limit(VALUE limit) +{ + VpSetPrecLimit(NUM2SIZET(limit)); + return Qnil; +} + +static void +VpDivdNewton(Real *c, Real *r, Real *a, Real *b) +{ + Real *args[4] = {c, r, a, b}; + size_t pl = VpGetPrecLimit(); + VpSetPrecLimit(0); + // Ensure restoring prec limit because some methods used in VpDivdNewtonInner may raise an exception + rb_ensure(VpDivdNewtonInner, (VALUE)args, ensure_restore_prec_limit, SIZET2NUM(pl)); +} diff --git a/ext/bigdecimal/ntt.h b/ext/bigdecimal/ntt.h new file mode 100644 index 00000000..941f23f7 --- /dev/null +++ b/ext/bigdecimal/ntt.h @@ -0,0 +1,191 @@ +// NTT (Number Theoretic Transform) implementation for BigDecimal multiplication + +#define NTT_PRIMITIVE_ROOT 17 +#define NTT_PRIME_BASE1 24 +#define NTT_PRIME_BASE2 26 +#define NTT_PRIME_BASE3 29 +#define NTT_PRIME_SHIFT 27 +#define NTT_PRIME1 (((uint32_t)NTT_PRIME_BASE1 << NTT_PRIME_SHIFT) | 1) +#define NTT_PRIME2 (((uint32_t)NTT_PRIME_BASE2 << NTT_PRIME_SHIFT) | 1) +#define NTT_PRIME3 (((uint32_t)NTT_PRIME_BASE3 << NTT_PRIME_SHIFT) | 1) +#define MAX_NTT32_BITS 27 +#define NTT_DECDIG_BASE 1000000000 + +// Calculates base**ex % mod +static uint32_t +mod_pow(uint32_t base, uint32_t ex, uint32_t mod) { + uint32_t res = 1; + uint32_t bit = 1; + while (true) { + if (ex & bit) { + ex ^= bit; + res = ((uint64_t)res * base) % mod; + } + if (!ex) break; + base = ((uint64_t)base * base) % mod; + bit <<= 1; + } + return res; +} + +// Recursively performs butterfly operations of NTT +static void +ntt_recursive(int size_bits, uint32_t *input, uint32_t *output, uint32_t *tmp, int depth, uint32_t r, uint32_t prime) { + if (depth > 0) { + ntt_recursive(size_bits, input, tmp, output, depth - 1, 
((uint64_t)r * r) % prime, prime); + } else { + tmp = input; + } + uint32_t size_half = (uint32_t)1 << (size_bits - 1); + uint32_t stride = (uint32_t)1 << (size_bits - depth - 1); + uint32_t n = size_half / stride; + uint32_t rn = 1, rm = prime - 1; + uint32_t idx = 0; + for (uint32_t i = 0; i < n; i++) { + uint32_t j = i * 2 * stride; + for (uint32_t k = 0; k < stride; k++, j++, idx++) { + uint32_t a = tmp[j], b = tmp[j + stride]; + output[idx] = (a + (uint64_t)rn * b) % prime; + output[idx + size_half] = (a + (uint64_t)rm * b) % prime; + } + rn = ((uint64_t)rn * r) % prime; + rm = ((uint64_t)rm * r) % prime; + } +} + +/* Perform NTT on input array. + * base, shift: Represent the prime number as (base << shift | 1) + * r_base: Primitive root of unity modulo prime + * size_bits: log2 of the size of the input array. Should be less or equal to shift + * input: input array of size (1 << size_bits) + */ +static void +ntt(int size_bits, uint32_t *input, uint32_t *output, uint32_t *tmp, int r_base, int base, int shift, int dir) { + uint32_t size = (uint32_t)1 << size_bits; + uint32_t prime = ((uint32_t)base << shift) | 1; + + // rmax**(1 << shift) % prime == 1 + // r**size % prime == 1 + uint32_t rmax = mod_pow(r_base, base, prime); + uint32_t r = mod_pow(rmax, (uint32_t)1 << (shift - size_bits), prime); + + if (dir < 0) r = mod_pow(r, prime - 2, prime); + ntt_recursive(size_bits, input, output, tmp, size_bits - 1, r, prime); + if (dir < 0) { + uint32_t n_inv = mod_pow((uint32_t)size, prime - 2, prime); + for (uint32_t i = 0; i < size; i++) { + output[i] = ((uint64_t)output[i] * n_inv) % prime; + } + } +} + +/* Calculate c that satisfies: c % PRIME1 == mod1 && c % PRIME2 == mod2 && c % PRIME3 == mod3 + * c = (mod1 * 35002755423056150739595925972 + mod2 * 14584479687667766215746868453 + mod3 * 37919651490985126265126719818) % (PRIME1 * PRIME2 * PRIME3) + * Assume c <= 999999999**2*(1<<27) + */ +static inline void +mod_restore_prime_24_26_29_shift_27(uint32_t mod1, 
uint32_t mod2, uint32_t mod3, uint32_t *digits) { + // Use mixed radix notation to eliminate modulo by PRIME1 * PRIME2 * PRIME3 + // [DIG0, DIG1, DIG2] = DIG0 + DIG1 * PRIME1 + DIG2 * PRIME1 * PRIME2 + // DIG0: 0...PRIME1, DIG1: 0...PRIME2, DIG2: 0...PRIME3 + // 35002755423056150739595925972 = [1, 3489660916, 3113851359] + // 14584479687667766215746868453 = [0, 13, 1297437912] + // 37919651490985126265126719818 = [0, 0, 3373338954] + uint64_t c0 = mod1; + uint64_t c1 = (uint64_t)mod2 * 13 + (uint64_t)mod1 * 3489660916; + uint64_t c2 = (uint64_t)mod3 * 3373338954 % NTT_PRIME3 + (uint64_t)mod2 * 1297437912 % NTT_PRIME3 + (uint64_t)mod1 * 3113851359 % NTT_PRIME3; + c2 += c1 / NTT_PRIME2; + c1 %= NTT_PRIME2; + c2 %= NTT_PRIME3; + // Base conversion. c fits in 3 digits. + c1 += c2 % NTT_DECDIG_BASE * NTT_PRIME2; + c0 += c1 % NTT_DECDIG_BASE * NTT_PRIME1; + c1 /= NTT_DECDIG_BASE; + digits[0] = c0 % NTT_DECDIG_BASE; + c0 /= NTT_DECDIG_BASE; + c1 += c2 / NTT_DECDIG_BASE % NTT_DECDIG_BASE * NTT_PRIME2; + c0 += c1 % NTT_DECDIG_BASE * NTT_PRIME1; + c1 /= NTT_DECDIG_BASE; + digits[1] = c0 % NTT_DECDIG_BASE; + digits[2] = (uint32_t)(c0 / NTT_DECDIG_BASE + c1 % NTT_DECDIG_BASE * NTT_PRIME1); +} + +/* + * NTT multiplication + * Uses three NTTs with mod (24 << 27 | 1), (26 << 27 | 1), and (29 << 27 | 1) + */ +static void +ntt_multiply(size_t a_size, size_t b_size, uint32_t *a, uint32_t *b, uint32_t *c) { + if (a_size < b_size) { + ntt_multiply(b_size, a_size, b, a, c); + return; + } + + int b_bits = 0; + while (((uint32_t)1 << b_bits) < (uint32_t)b_size) b_bits++; + int ntt_size_bits = b_bits + 1; + if (ntt_size_bits > MAX_NTT32_BITS) { + rb_raise(rb_eArgError, "Multiply size too large"); + } + + // To calculate large_a * small_b faster, split into several batches. 
+ uint32_t ntt_size = (uint32_t)1 << ntt_size_bits; + uint32_t batch_size = ntt_size - (uint32_t)b_size; + uint32_t batch_count = (uint32_t)((a_size + batch_size - 1) / batch_size); + + uint32_t *mem = ruby_xcalloc(sizeof(uint32_t), ntt_size * 9); + uint32_t *ntt1 = mem; + uint32_t *ntt2 = mem + ntt_size; + uint32_t *ntt3 = mem + ntt_size * 2; + uint32_t *tmp1 = mem + ntt_size * 3; + uint32_t *tmp2 = mem + ntt_size * 4; + uint32_t *tmp3 = mem + ntt_size * 5; + uint32_t *conv1 = mem + ntt_size * 6; + uint32_t *conv2 = mem + ntt_size * 7; + uint32_t *conv3 = mem + ntt_size * 8; + + // Calculate NTT for b in three primes. Result is reused for each batch of a. + memcpy(tmp1, b, b_size * sizeof(uint32_t)); + memset(tmp1 + b_size, 0, (ntt_size - b_size) * sizeof(uint32_t)); + ntt(ntt_size_bits, tmp1, ntt1, tmp2, NTT_PRIMITIVE_ROOT, NTT_PRIME_BASE1, NTT_PRIME_SHIFT, +1); + ntt(ntt_size_bits, tmp1, ntt2, tmp2, NTT_PRIMITIVE_ROOT, NTT_PRIME_BASE2, NTT_PRIME_SHIFT, +1); + ntt(ntt_size_bits, tmp1, ntt3, tmp2, NTT_PRIMITIVE_ROOT, NTT_PRIME_BASE3, NTT_PRIME_SHIFT, +1); + + memset(c, 0, (a_size + b_size) * sizeof(uint32_t)); + for (uint32_t idx = 0; idx < batch_count; idx++) { + uint32_t len = idx == batch_count - 1 ? 
(uint32_t)a_size - idx * batch_size : batch_size; + memcpy(tmp1, a + idx * batch_size, len * sizeof(uint32_t)); + memset(tmp1 + len, 0, (ntt_size - len) * sizeof(uint32_t)); + // Calculate convolution for this batch in three primes + ntt(ntt_size_bits, tmp1, tmp2, tmp3, NTT_PRIMITIVE_ROOT, NTT_PRIME_BASE1, NTT_PRIME_SHIFT, +1); + for (uint32_t i = 0; i < ntt_size; i++) tmp2[i] = ((uint64_t)tmp2[i] * ntt1[i]) % NTT_PRIME1; + ntt(ntt_size_bits, tmp2, conv1, tmp3, NTT_PRIMITIVE_ROOT, NTT_PRIME_BASE1, NTT_PRIME_SHIFT, -1); + ntt(ntt_size_bits, tmp1, tmp2, tmp3, NTT_PRIMITIVE_ROOT, NTT_PRIME_BASE2, NTT_PRIME_SHIFT, +1); + for (uint32_t i = 0; i < ntt_size; i++) tmp2[i] = ((uint64_t)tmp2[i] * ntt2[i]) % NTT_PRIME2; + ntt(ntt_size_bits, tmp2, conv2, tmp3, NTT_PRIMITIVE_ROOT, NTT_PRIME_BASE2, NTT_PRIME_SHIFT, -1); + ntt(ntt_size_bits, tmp1, tmp2, tmp3, NTT_PRIMITIVE_ROOT, NTT_PRIME_BASE3, NTT_PRIME_SHIFT, +1); + for (uint32_t i = 0; i < ntt_size; i++) tmp2[i] = ((uint64_t)tmp2[i] * ntt3[i]) % NTT_PRIME3; + ntt(ntt_size_bits, tmp2, conv3, tmp3, NTT_PRIMITIVE_ROOT, NTT_PRIME_BASE3, NTT_PRIME_SHIFT, -1); + + // Restore the original convolution value from three convolutions calculated in three primes. + // Each convolution value is maximum 999999999**2*(1<<27)/2 + for (uint32_t i = 0; i < ntt_size; i++) { + uint32_t dig[3]; + mod_restore_prime_24_26_29_shift_27(conv1[i], conv2[i], conv3[i], dig); + // Maximum values of dig[0], dig[1], and dig[2] are 999999999, 999999999 and 67108863 respectively + // Maximum overlapped sum (considering overlaps between 2 batches) is less than 4134217722 + // so this sum doesn't overflow uint32_t. + for (int j = 0; j < 3; j++) { + // Index check: if dig[j] is non-zero, assign index is within valid range. 
+ if (dig[j]) c[idx * batch_size + i + 1 - j] += dig[j]; + } + } + } + uint32_t carry = 0; + for (int32_t i = (uint32_t)(a_size + b_size - 1); i >= 0; i--) { + uint32_t v = c[i] + carry; + c[i] = v % NTT_DECDIG_BASE; + carry = v / NTT_DECDIG_BASE; + } + ruby_xfree(mem); +} diff --git a/test/bigdecimal/test_vp_operation.rb b/test/bigdecimal/test_vp_operation.rb index 075df0b6..389aba77 100644 --- a/test/bigdecimal/test_vp_operation.rb +++ b/test/bigdecimal/test_vp_operation.rb @@ -13,6 +13,10 @@ def setup end end + def ntt_mult_available? + BASE_FIG == 9 + end + def test_vpmult assert_equal(BigDecimal('121932631112635269'), BigDecimal('123456789').vpmult(BigDecimal('987654321'))) assert_equal(BigDecimal('12193263.1112635269'), BigDecimal('123.456789').vpmult(BigDecimal('98765.4321'))) @@ -21,6 +25,68 @@ def test_vpmult assert_equal(BigDecimal("#{x * y}e-300"), BigDecimal("#{x}e-100").vpmult(BigDecimal("#{y}e-200"))) end + def test_nttmult + omit 'NTT multiplication is only available for 32-bit DECDIG' unless ntt_mult_available? 
+ [*1..32].repeated_permutation(2) do |a, b| + x = BigDecimal(10 ** (BASE_FIG * a) / 7) + y = BigDecimal(10 ** (BASE_FIG * b) / 13) + assert_equal(x.to_i * y.to_i, x.nttmult(y)) + end + end + + def test_newton_inverse + xs = [BigDecimal(3), BigDecimal('123e50'), BigDecimal('13' * 44), BigDecimal('17' * 45), BigDecimal('19' * 46)] + %i[up half_up down].each do |rounding_mode| + BigDecimal.save_rounding_mode do + BigDecimal.mode(BigDecimal::ROUND_MODE, rounding_mode) + [*1..32, 50, 100, 200, 300].each do |prec| + xs.each do |x| + inv = x.newton_raphson_inverse(prec) + assert_in_delta(1, x * inv, BigDecimal("1e#{1 - prec}")) + + high_precision_inv = inv * (2 - x * inv) + expected_inv = high_precision_inv.mult(1, prec) + last_digit = BigDecimal("1e#{expected_inv.exponent - prec}") + assert_include([expected_inv - last_digit, expected_inv, expected_inv + last_digit], inv) + end + end + end + end + end + + def test_not_affected_by_limit + x_int = 123**135 + y_int = 135**123 + xy_int = x_int * y_int + mod_int = 111**111 + x = BigDecimal(x_int) + y = BigDecimal(y_int) + xy = BigDecimal(xy_int) + mod = BigDecimal(mod_int) + z = BigDecimal(xy_int + mod_int) + BigDecimal.save_limit do + BigDecimal.limit 3 + assert_equal(xy, x.vpmult(y)) + assert_equal(3, BigDecimal.limit) + if ntt_mult_available? 
+ assert_equal(xy, x.nttmult(y)) + assert_equal(3, BigDecimal.limit) + end + + prec = (z.exponent - 1) / BASE_FIG - (y.exponent - 1) / BASE_FIG + 1 + assert_equal([x, mod], z.vpdivd(y, prec)) + assert_equal(3, BigDecimal.limit) + assert_equal([x, mod], z.vpdivd_newton(y, prec)) + assert_equal(3, BigDecimal.limit) + end + end + + def assert_vpdivd_equal(expected_divmod, x_y_n) + x, *args = x_y_n + assert_equal(expected_divmod, x.vpdivd(*args)) + assert_equal(expected_divmod, x.vpdivd_newton(*args)) + end + def test_vpdivd # a[0] > b[0] # XXXX_YYYY_ZZZZ / 1111 #=> 000X_000Y_000Z @@ -31,11 +97,11 @@ def test_vpdivd d3 = BigDecimal("4e#{BASE_FIG * 2}") + d2 d4 = BigDecimal("5e#{BASE_FIG}") + d3 d5 = BigDecimal(6) + d4 - assert_equal([d1, x1 - d1 * y], x1.vpdivd(y, 1)) - assert_equal([d2, x1 - d2 * y], x1.vpdivd(y, 2)) - assert_equal([d3, x1 - d3 * y], x1.vpdivd(y, 3)) - assert_equal([d4, x1 - d4 * y], x1.vpdivd(y, 4)) - assert_equal([d5, x1 - d5 * y], x1.vpdivd(y, 5)) + assert_vpdivd_equal([d1, x1 - d1 * y], [x1, y, 1]) + assert_vpdivd_equal([d2, x1 - d2 * y], [x1, y, 2]) + assert_vpdivd_equal([d3, x1 - d3 * y], [x1, y, 3]) + assert_vpdivd_equal([d4, x1 - d4 * y], [x1, y, 4]) + assert_vpdivd_equal([d5, x1 - d5 * y], [x1, y, 5]) # a[0] < b[0] # 00XX_XXYY_YYZZ_ZZ00 / 1111 #=> 0000_0X00_0Y00_0Z00 @@ -46,28 +112,28 @@ def test_vpdivd d3 = BigDecimal("4e#{2 * BASE_FIG + shift}") + d2 d4 = BigDecimal("5e#{BASE_FIG + shift}") + d3 d5 = BigDecimal("6e#{shift}") + d4 - assert_equal([0, x2], x2.vpdivd(y, 1)) - assert_equal([d1, x2 - d1 * y], x2.vpdivd(y, 2)) - assert_equal([d2, x2 - d2 * y], x2.vpdivd(y, 3)) - assert_equal([d3, x2 - d3 * y], x2.vpdivd(y, 4)) - assert_equal([d4, x2 - d4 * y], x2.vpdivd(y, 5)) - assert_equal([d5, x2 - d5 * y], x2.vpdivd(y, 6)) + assert_vpdivd_equal([0, x2], [x2, y, 1]) + assert_vpdivd_equal([d1, x2 - d1 * y], [x2, y, 2]) + assert_vpdivd_equal([d2, x2 - d2 * y], [x2, y, 3]) + assert_vpdivd_equal([d3, x2 - d3 * y], [x2, y, 4]) + 
assert_vpdivd_equal([d4, x2 - d4 * y], [x2, y, 5]) + assert_vpdivd_equal([d5, x2 - d5 * y], [x2, y, 6]) end def test_vpdivd_large_quotient_prec # 0001 / 0003 = 0000_3333_3333 - assert_equal([BigDecimal('0.' + '3' * BASE_FIG * 9), BigDecimal("1e-#{9 * BASE_FIG}")], BigDecimal(1).vpdivd(BigDecimal(3), 10)) + assert_vpdivd_equal([BigDecimal('0.' + '3' * BASE_FIG * 9), BigDecimal("1e-#{9 * BASE_FIG}")], [BigDecimal(1), BigDecimal(3), 10]) # 1000 / 0003 = 0333_3333_3333 - assert_equal([BigDecimal('3' * (BASE_FIG - 1) + '.' + '3' * BASE_FIG * 9), BigDecimal("1e-#{9 * BASE_FIG}")], BigDecimal(BASE / 10).vpdivd(BigDecimal(3), 10)) + assert_vpdivd_equal([BigDecimal('3' * (BASE_FIG - 1) + '.' + '3' * BASE_FIG * 9), BigDecimal("1e-#{9 * BASE_FIG}")], [BigDecimal(BASE / 10), BigDecimal(3), 10]) end def test_vpdivd_with_one x = BigDecimal('1234.2468000001234') - assert_equal([BigDecimal('1234'), BigDecimal('0.2468000001234')], x.vpdivd(BigDecimal(1), 1)) - assert_equal([BigDecimal('+1234.2468'), BigDecimal('+0.1234e-9')], (+x).vpdivd(BigDecimal(+1), 2)) - assert_equal([BigDecimal('-1234.2468'), BigDecimal('+0.1234e-9')], (+x).vpdivd(BigDecimal(-1), 2)) - assert_equal([BigDecimal('-1234.2468'), BigDecimal('-0.1234e-9')], (-x).vpdivd(BigDecimal(+1), 2)) - assert_equal([BigDecimal('+1234.2468'), BigDecimal('-0.1234e-9')], (-x).vpdivd(BigDecimal(-1), 2)) + assert_vpdivd_equal([BigDecimal('1234'), BigDecimal('0.2468000001234')], [x, BigDecimal(1), 1]) + assert_vpdivd_equal([BigDecimal('+1234.2468'), BigDecimal('+0.1234e-9')], [+x, BigDecimal(+1), 2]) + assert_vpdivd_equal([BigDecimal('-1234.2468'), BigDecimal('+0.1234e-9')], [+x, BigDecimal(-1), 2]) + assert_vpdivd_equal([BigDecimal('-1234.2468'), BigDecimal('-0.1234e-9')], [-x, BigDecimal(+1), 2]) + assert_vpdivd_equal([BigDecimal('+1234.2468'), BigDecimal('-0.1234e-9')], [-x, BigDecimal(-1), 2]) end def test_vpdivd_precisions @@ -79,7 +145,7 @@ def test_vpdivd_precisions yn = (y.digits.size + BASE_FIG - 1) / BASE_FIG base = BASE 
** (n - xn + yn - 1) div = BigDecimal((x * base / y).to_i) / base - assert_equal([div, x - y * div], BigDecimal(x).vpdivd(y, n)) + assert_vpdivd_equal([div, x - y * div], [BigDecimal(x), BigDecimal(y), n]) end end end @@ -92,7 +158,7 @@ def test_vpdivd_borrow x = y * (3 * BASE**4 + a * BASE**3 + b * BASE**2 + c * BASE + d) / BASE div = BigDecimal(x * BASE / y) / BASE mod = BigDecimal(x) - div * y - assert_equal([div, mod], BigDecimal(x).vpdivd(BigDecimal(y), 5)) + assert_vpdivd_equal([div, mod], [BigDecimal(x), BigDecimal(y), 5]) end end end @@ -104,22 +170,22 @@ def test_vpdivd_large_prec_divisor divy1_1 = BigDecimal(2) divy2_1 = BigDecimal(1) divy2_2 = BigDecimal('1.' + '9' * BASE_FIG) - assert_equal([divy1_1, x - y1 * divy1_1], x.vpdivd(y1, 1)) - assert_equal([divy2_1, x - y2 * divy2_1], x.vpdivd(y2, 1)) - assert_equal([divy2_2, x - y2 * divy2_2], x.vpdivd(y2, 2)) + assert_vpdivd_equal([divy1_1, x - y1 * divy1_1], [x, y1, 1]) + assert_vpdivd_equal([divy2_1, x - y2 * divy2_1], [x, y2, 1]) + assert_vpdivd_equal([divy2_2, x - y2 * divy2_2], [x, y2, 2]) end def test_vpdivd_intermediate_zero if BASE_FIG == 9 x = BigDecimal('123456789.246913578000000000123456789') y = BigDecimal('123456789') - assert_equal([BigDecimal('1.000000002000000000000000001'), BigDecimal(0)], x.vpdivd(y, 4)) - assert_equal([BigDecimal('1.000000000049999999'), BigDecimal('1e-18')], BigDecimal("2.000000000099999999").vpdivd(2, 3)) + assert_vpdivd_equal([BigDecimal('1.000000002000000000000000001'), BigDecimal(0)], [x, y, 4]) + assert_vpdivd_equal([BigDecimal('1.000000000049999999'), BigDecimal('1e-18')], [BigDecimal("2.000000000099999999"), 2, 3]) else x = BigDecimal('1234.246800001234') y = BigDecimal('1234') - assert_equal([BigDecimal('1.000200000001'), BigDecimal(0)], x.vpdivd(y, 4)) - assert_equal([BigDecimal('1.00000499'), BigDecimal('1e-8')], BigDecimal("2.00000999").vpdivd(2, 3)) + assert_vpdivd_equal([BigDecimal('1.000200000001'), BigDecimal(0)], [x, y, 4]) + 
assert_vpdivd_equal([BigDecimal('1.00000499'), BigDecimal('1e-8')], [BigDecimal("2.00000999"), 2, 3]) end end end