Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ static BigInteger computePowerOfTen(NavigableMap<Integer, BigInteger> powersOfTe
if (floorN == n) {
return floorEntry.getValue();
} else {
return FftMultiplier.multiply(floorEntry.getValue(), computePowerOfTen(powersOfTen, n - floorN));
return FftMultiplier.multiply(floorEntry.getValue(), computePowerOfTen(powersOfTen, n - floorN), n - floorN);
}
}
return FIVE.pow(n).shiftLeft(n);
Expand All @@ -80,7 +80,7 @@ static BigInteger computeTenRaisedByNFloor16Recursive(NavigableMap<Integer, BigI
diffValue = computeTenRaisedByNFloor16Recursive(powersOfTen, diff);
powersOfTen.put(diff, diffValue);
}
return FftMultiplier.multiply(floorValue, diffValue);
return FftMultiplier.multiply(floorValue, diffValue, diff);
}

static NavigableMap<Integer, BigInteger> createPowersOfTenFloor16Map() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
package ch.randelshofer.fastdoubleparser;

import java.math.BigInteger;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicReferenceArray;

import static ch.randelshofer.fastdoubleparser.FastDoubleMath.fastScalb;
import static ch.randelshofer.fastdoubleparser.FastDoubleSwar.fma;
Expand Down Expand Up @@ -48,7 +51,12 @@ class FftMultiplier {
/**
* for FFTs of length up to 2^19
*/
private static final int ROOTS_CACHE2_SIZE = 20;
static final int ROOTS2_CACHE_SIZE = 20;

private static final Map<Long, ComplexVector> FFT_POWER_OF_TEN_CACHE2 = new ConcurrentHashMap<>();
private static final Map<Long, ComplexVector> FFT_POWER_OF_TEN_CACHE3 = new ConcurrentHashMap<>();


/**
* The threshold value for using 3-way Toom-Cook multiplication.
*/
Expand All @@ -58,13 +66,13 @@ class FftMultiplier {
* elements representing all (2^(k+2))-th roots between 0 and pi/2.
* Used for FFT multiplication.
*/
private volatile static ComplexVector[] ROOTS2_CACHE = new ComplexVector[ROOTS_CACHE2_SIZE];
private final static AtomicReferenceArray<ComplexVector> ROOTS2_CACHE = new AtomicReferenceArray<>(ROOTS2_CACHE_SIZE);
/**
* Sets of complex roots of unity. The set at index k contains 3*2^k
* elements representing all (3*2^(k+2))-th roots between 0 and pi/2.
* Used for FFT multiplication.
*/
private volatile static ComplexVector[] ROOTS3_CACHE = new ComplexVector[ROOTS3_CACHE_SIZE];
private final static AtomicReferenceArray<ComplexVector> ROOTS3_CACHE = new AtomicReferenceArray<>(ROOTS3_CACHE_SIZE);

/**
* Returns the maximum number of bits that one double precision number can fit without
Expand Down Expand Up @@ -351,11 +359,11 @@ static BigInteger fromFftVector(ComplexVector fftVec, int signum, int bitsPerFft
private static ComplexVector[] getRootsOfUnity2(int logN) {
ComplexVector[] roots = new ComplexVector[logN + 1];
for (int i = logN; i >= 0; i -= 2) {
if (i < ROOTS_CACHE2_SIZE) {
if (ROOTS2_CACHE[i] == null) {
ROOTS2_CACHE[i] = calculateRootsOfUnity(1 << i);
if (i < ROOTS2_CACHE_SIZE) {
if (ROOTS2_CACHE.get(i) == null) {
ROOTS2_CACHE.set(i, calculateRootsOfUnity(1 << i));
}
roots[i] = ROOTS2_CACHE[i];
roots[i] = ROOTS2_CACHE.get(i);
} else {
roots[i] = calculateRootsOfUnity(1 << i);
}
Expand All @@ -371,10 +379,10 @@ private static ComplexVector[] getRootsOfUnity2(int logN) {
*/
private static ComplexVector getRootsOfUnity3(int logN) {
if (logN < ROOTS3_CACHE_SIZE) {
if (ROOTS3_CACHE[logN] == null) {
ROOTS3_CACHE[logN] = calculateRootsOfUnity(3 << logN);
if (ROOTS3_CACHE.get(logN) == null) {
ROOTS3_CACHE.set(logN, calculateRootsOfUnity(3 << logN));
}
return ROOTS3_CACHE[logN];
return ROOTS3_CACHE.get(logN);
} else {
return calculateRootsOfUnity(3 << logN);
}
Expand Down Expand Up @@ -536,6 +544,10 @@ private static void ifftMixedRadix(ComplexVector a, ComplexVector[] roots2, Comp
* performance when {@code a == b}.
*/
static BigInteger multiply(BigInteger a, BigInteger b) {
return multiply(a, b, -1);
}

static BigInteger multiply(BigInteger a, BigInteger b, int powerOfTen) {
if (b.signum() == 0 || a.signum() == 0) {
return BigInteger.ZERO;
}
Expand All @@ -554,7 +566,7 @@ static BigInteger multiply(BigInteger a, BigInteger b) {
if (xlen > TOOM_COOK_THRESHOLD
&& ylen > TOOM_COOK_THRESHOLD
&& (xlen > FFT_THRESHOLD || ylen > FFT_THRESHOLD)) {
return multiplyFft(a, b);
return multiplyFft(a, b, powerOfTen);
}
return a.multiply(b);
}
Expand Down Expand Up @@ -602,6 +614,9 @@ static BigInteger multiply(BigInteger a, BigInteger b) {
* @return a*b
*/
static BigInteger multiplyFft(BigInteger a, BigInteger b) {
return multiplyFft(a, b, -1);
}
private static BigInteger multiplyFft(BigInteger a, BigInteger b, int powerOfTen) {
int signum = a.signum() * b.signum();
byte[] aMag = (a.signum() < 0 ? a.negate() : a).toByteArray();
byte[] bMag = (b.signum() < 0 ? b.negate() : b).toByteArray();
Expand All @@ -613,35 +628,61 @@ static BigInteger multiplyFft(BigInteger a, BigInteger b) {
// Use a 2^n or 3*2^n transform, whichever is shortest
int fftLen2 = 1 << (logFFTLen); // rounded to 2^n
int fftLen3 = fftLen2 * 3 / 4; // rounded to 3*2^n
ComplexVector aVec;
ComplexVector weights;
if (fftLen < fftLen3 && logFFTLen > 3) {
ComplexVector[] roots2 = getRootsOfUnity2(logFFTLen - 2); // roots for length fftLen/3 which is a power of two
ComplexVector weights = getRootsOfUnity3(logFFTLen - 2);
ComplexVector[] roots = getRootsOfUnity2(logFFTLen - 2); // roots for length fftLen/3 which is a power of two
weights = getRootsOfUnity3(logFFTLen - 2);
ComplexVector twiddles = getRootsOfUnity3(logFFTLen - 4);
ComplexVector aVec = toFftVector(aMag, fftLen3, bitsPerPoint);

aVec = toFftVector(aMag, fftLen3, bitsPerPoint);
aVec.applyWeights(weights);
fftMixedRadix(aVec, roots2, twiddles);
ComplexVector bVec = toFftVector(bMag, fftLen3, bitsPerPoint);
bVec.applyWeights(weights);
fftMixedRadix(bVec, roots2, twiddles);
fftMixedRadix(aVec, roots, twiddles);

ComplexVector bVec = null;
if (powerOfTen != -1) {
bVec = FFT_POWER_OF_TEN_CACHE3.get(cacheFftIndex(powerOfTen, fftLen3));
}
if (bVec == null) {
bVec = toFftVector(bMag, fftLen3, bitsPerPoint);
bVec.applyWeights(weights);
fftMixedRadix(bVec, roots, twiddles);
if (powerOfTen != -1) {
FFT_POWER_OF_TEN_CACHE3.put(cacheFftIndex(powerOfTen, fftLen3), bVec);
}
}
aVec.multiplyPointwise(bVec);
ifftMixedRadix(aVec, roots2, twiddles);
aVec.applyInverseWeights(weights);
return fromFftVector(aVec, signum, bitsPerPoint);
ifftMixedRadix(aVec, roots, twiddles);
} else {
ComplexVector[] roots = getRootsOfUnity2(logFFTLen);
ComplexVector aVec = toFftVector(aMag, fftLen2, bitsPerPoint);
aVec.applyWeights(roots[logFFTLen]);
weights = roots[logFFTLen];
aVec = toFftVector(aMag, fftLen2, bitsPerPoint);
aVec.applyWeights(weights);
fft(aVec, roots);
ComplexVector bVec = toFftVector(bMag, fftLen2, bitsPerPoint);
bVec.applyWeights(roots[logFFTLen]);
fft(bVec, roots);

ComplexVector bVec = null;
if (powerOfTen != -1) {
bVec = FFT_POWER_OF_TEN_CACHE2.get(cacheFftIndex(powerOfTen, fftLen2));
}
if (bVec == null) {
bVec = toFftVector(bMag, fftLen2, bitsPerPoint);
bVec.applyWeights(weights);
fft(bVec, roots);
if (powerOfTen != -1) {
FFT_POWER_OF_TEN_CACHE2.put(cacheFftIndex(powerOfTen, fftLen2), bVec);// somewhere else index computation?
}
}

aVec.multiplyPointwise(bVec);
ifft(aVec, roots);
aVec.applyInverseWeights(roots[logFFTLen]);
return fromFftVector(aVec, signum, bitsPerPoint);
}
aVec.applyInverseWeights(weights);
return fromFftVector(aVec, signum, bitsPerPoint);
}

private static long cacheFftIndex(int power, int fftLength) {
return power + fftLength * 4_294_967_296L;
}
/**
* Returns a BigInteger whose value is {@code (this<sup>2</sup>)}.
*
Expand All @@ -664,29 +705,27 @@ static BigInteger squareFft(BigInteger a) {
// Use a 2^n or 3*2^n transform, whichever is shorter
int fftLen2 = 1 << (logFFTLen); // rounded to 2^n
int fftLen3 = fftLen2 * 3 / 4; // rounded to 3*2^n
ComplexVector vec, weights;
if (fftLen < fftLen3) {
fftLen = fftLen3;
ComplexVector vec = toFftVector(mag, fftLen, bitsPerPoint);
vec = toFftVector(mag, fftLen3, bitsPerPoint);
ComplexVector[] roots2 = getRootsOfUnity2(logFFTLen - 2); // roots for length fftLen/3 which is a power of two
ComplexVector weights = getRootsOfUnity3(logFFTLen - 2);
ComplexVector twiddles = getRootsOfUnity3(logFFTLen - 4);
weights = getRootsOfUnity3(logFFTLen - 2);
vec.applyWeights(weights);
fftMixedRadix(vec, roots2, twiddles);
vec.squarePointwise();
ifftMixedRadix(vec, roots2, twiddles);
vec.applyInverseWeights(weights);
return fromFftVector(vec, 1, bitsPerPoint);
} else {
fftLen = fftLen2;
ComplexVector vec = toFftVector(mag, fftLen, bitsPerPoint);
ComplexVector[] roots = getRootsOfUnity2(logFFTLen);
vec.applyWeights(roots[logFFTLen]);
fft(vec, roots);
vec = toFftVector(mag, fftLen2, bitsPerPoint);
ComplexVector[] roots2 = getRootsOfUnity2(logFFTLen);
weights = roots2[logFFTLen];
vec.applyWeights(weights);
fft(vec, roots2);
vec.squarePointwise();
ifft(vec, roots);
vec.applyInverseWeights(roots[logFFTLen]);
return fromFftVector(vec, 1, bitsPerPoint);
ifft(vec, roots2);
}
vec.applyInverseWeights(weights);
return fromFftVector(vec, 1, bitsPerPoint);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,7 @@ BigDecimal valueOfBigDecimalString(byte[] str, int integerPartIndex, int decimal
significand = fractionalPart;
} else {
BigInteger integerFactor = computePowerOfTen(powersOfTen, fractionDigitsCount);
significand = FftMultiplier.multiply(integerPart, integerFactor).add(fractionalPart);
significand = FftMultiplier.multiply(integerPart, integerFactor, fractionDigitsCount).add(fractionalPart);
}
} else {
significand = integerPart;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,7 @@ private BigDecimal valueOfBigDecimalString(char[] str, int integerPartIndex, int
significand = fractionalPart;
} else {
BigInteger integerFactor = computePowerOfTen(powersOfTen, integerExponent);
significand = FftMultiplier.multiply(integerPart, integerFactor).add(fractionalPart);
significand = FftMultiplier.multiply(integerPart, integerFactor, integerExponent).add(fractionalPart);
}
} else {
significand = integerPart;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,7 @@ private BigDecimal valueOfBigDecimalString(CharSequence str, int integerPartInde
significand = fractionalPart;
} else {
BigInteger integerFactor = computePowerOfTen(powersOfTen, fractionDigitsCount);
significand = FftMultiplier.multiply(integerPart, integerFactor).add(fractionalPart);
significand = FftMultiplier.multiply(integerPart, integerFactor, fractionDigitsCount).add(fractionalPart);
}
} else {
significand = integerPart;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,7 @@ static BigInteger parseDigitsRecursive(byte[] str, int from, int to, Map<Integer
BigInteger high = parseDigitsRecursive(str, from, mid, powersOfTen);
BigInteger low = parseDigitsRecursive(str, mid, to, powersOfTen);

//high = high.multiply(powersOfTen.get(to - mid));
high = FftMultiplier.multiply(high, powersOfTen.get(to - mid));
high = FftMultiplier.multiply(high, powersOfTen.get(to - mid), to - mid);
return low.add(high);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ static BigInteger parseDigitsRecursive(char[] str, int from, int to, Map<Integer
BigInteger high = parseDigitsRecursive(str, from, mid, powersOfTen);
BigInteger low = parseDigitsRecursive(str, mid, to, powersOfTen);

high = FftMultiplier.multiply(high, powersOfTen.get(to - mid));
high = FftMultiplier.multiply(high, powersOfTen.get(to - mid), to - mid);
return low.add(high);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,7 @@ static BigInteger parseDigitsRecursive(CharSequence str, int from, int to, Map<I
BigInteger high = parseDigitsRecursive(str, from, mid, powersOfTen);
BigInteger low = parseDigitsRecursive(str, mid, to, powersOfTen);

//high = high.multiply(powersOfTen.get(to - mid));
high = FftMultiplier.multiply(high, powersOfTen.get(to - mid));
high = FftMultiplier.multiply(high, powersOfTen.get(to - mid), to - mid);
return low.add(high);
}
}