【筆記】大數乘法之Schönhage–Strassen演算法（由Java BigInteger原始碼移植的C#實作）
阿新 • 發佈：2018-12-20
BigInteger to/from uint[]
用uint[]來表示非負大數（大端序）：陣列開頭（索引0）存放大數的最高32位，陣列結尾存放大數的最低32位。
/// <summary>
/// Converts a big-endian <see cref="uint"/> array (element 0 holds the most significant 32 bits)
/// into a non-negative <see cref="BigInteger"/>. Leading zero words are ignored.
/// </summary>
private static BigInteger ValueOf(uint[] value)
{
    // fold the significant words into the accumulator, most significant first
    return value
        .SkipWhile(word => word == 0u)
        .Aggregate(BigInteger.Zero, (accumulator, word) => (accumulator << 32) | word);
}

/// <summary>
/// Converts a non-negative <see cref="BigInteger"/> into a big-endian <see cref="uint"/> array
/// (element 0 holds the most significant 32 bits). The array length is derived from
/// <see cref="BigInteger.GetByteCount"/>, so it may carry a leading zero word.
/// </summary>
private static uint[] ToIntArray(BigInteger value)
{
    var words = new uint[(value.GetByteCount() + 3) / 4];
    var remaining = value;
    // fill from the least significant end
    for (var index = words.Length - 1; index >= 0; index--)
    {
        words[index] = (uint)(remaining & uint.MaxValue);
        remaining >>= 32;
    }
    return words;
}
MutableModFn
/// <summary>
/// A mutable number modulo the Fermat number Fn = 2^(2^n) + 1, stored big-endian in a
/// <see cref="long"/> array of length 2^(n - 6) + 1. Element 0 holds the 2^(2^n) bit
/// (0 or 1 once reduced); elements 1..Length-1 hold the low 2^n bits.
/// </summary>
public class MutableModFn
{
    private readonly long[] _digits;

    /// <summary>Number of <see cref="long"/> elements in the backing array.</summary>
    public int Length => _digits.Length;

    ///<summary>
    ///<para>Creates a <see cref="MutableModFn"/> number from a <see cref="long"/> array whose length is
    ///2^(n - 6) + 1 for some n. The first element must be 0 or 1.</para>
    ///<para>The caller is trusted to pass in a valid array.</para>
    ///No copy of the array is made; its contents will reflect operations on the <see cref="MutableModFn"/> object.
    ///</summary>
    public MutableModFn(long[] digits)
    {
        this._digits = digits;
    }

    ///<summary>
    ///Creates a zero value. <paramref name="length"/> must be 2^(n - 6) + 1 for some n.
    ///</summary>
    public MutableModFn(int length)
    {
        _digits = new long[length];
    }

    ///<summary>
    ///Copies this <see cref="MutableModFn"/>'s value into another <see cref="MutableModFn"/>.
    ///</summary>
    public void CopyTo(MutableModFn value)
    {
        Array.Copy(_digits, 0, value._digits, 0, _digits.Length);
    }

    ///<summary>
    ///Returns the top bit of <paramref name="x"/> (0 or 1) as a signed value, so that a difference
    ///of two sign bits can become -1 (as it does with Java's <c>x &gt;&gt;&gt; 63</c> on <c>long</c>).
    ///</summary>
    private static long SignBit(long x) => (long)((ulong)x >> 63);

    ///<summary>
    ///Returns the top bit of <paramref name="x"/> (0 or 1) as a signed <see cref="int"/>, so that a
    ///difference of two sign bits can become -1 instead of wrapping to <see cref="uint.MaxValue"/>.
    ///</summary>
    private static int SignBit(uint x) => (int)(x >> 31);

    ///<summary>
    ///Adds another <see cref="MutableModFn"/> to this number (mod Fn).
    ///</summary>
    public void Add(MutableModFn value)
    {
        var carry = false;
        for (var i = _digits.Length - 1; i >= 0; i--)
        {
            var sum = _digits[i] + value._digits[i];
            if (carry)
                sum++;
            // carry if signBit(sum) < signBit(digits[i]) + signBit(addend[i])
            carry = SignBit(sum) < SignBit(_digits[i]) + SignBit(value._digits[i]);
            _digits[i] = sum;
        }
        // take a mod Fn by adding any remaining carry bit to the lowest bit;
        // since Fn is congruent to 1 (mod 2^n), it suffices to add 1
        var j = _digits.Length - 1;
        while (carry && j >= 0)
        {
            var sum = _digits[j] + 1;
            _digits[j] = sum;
            carry = sum == 0;
            j--;
        }
        Reduce();
    }

    ///<summary>
    ///Subtracts another <see cref="MutableModFn"/> from this number (mod Fn).
    ///</summary>
    public void Subtract(MutableModFn value)
    {
        var borrow = false;
        for (var i = _digits.Length - 1; i >= 0; i--)
        {
            var diff = _digits[i] - value._digits[i];
            if (borrow)
                diff--;
            // borrow if signBit(diff) > signBit(digits[i]) - signBit(b.digits[i]);
            // the right-hand side must be allowed to be -1 (a < 2^63 <= b always borrows)
            borrow = SignBit(diff) > SignBit(_digits[i]) - SignBit(value._digits[i]);
            _digits[i] = diff;
        }
        // if we borrowed from the most significant long, subtract 2^2^n which is the same as adding 1 (mod Fn)
        if (borrow)
        {
            _digits[0]++; // undo borrow
            var i = _digits.Length - 1;
            var carry = true;
            while (carry && i >= 0)
            {
                var sum = _digits[i] + 1;
                _digits[i] = sum;
                carry = sum == 0;
                i--;
            }
        }
    }

    ///<summary>
    ///Multiplies this number by another <see cref="MutableModFn"/> (mod Fn).
    ///</summary>
    public void Multiply(MutableModFn value)
    {
        // if a = b = 2^(2^n) (which is -1 mod Fn), a*b = 1 (mod Fn)
        if (_digits[0] == 1 && value._digits[0] == 1)
        {
            Array.Fill(_digits, 0);
            _digits[_digits.Length - 1] = 1;
        }
        // otherwise, a*b will fit into 2*2^n bits
        else
        {
            var intDigits = ToUIntArrayOdd(_digits);
            var aBigInt = (intDigits, true);
            var intBDigits = ToUIntArrayOdd(value._digits);
            var bBigInt = (intBDigits, true);
            var cInt = BigIntegerOperation.Multiply(aBigInt, bBigInt).Item1;
            cInt = cInt.SkipWhile(num => num == 0u).ToArray();
            // zero-pad c to make it 2*2^n bits in length, and convert it to long[]
            var cIntPad = new uint[intDigits.Length - 1 + intBDigits.Length - 1];
            Array.Copy(cInt, 0, cIntPad, cIntPad.Length - cInt.Length, cInt.Length);
            var c = ToLongArrayEven(cIntPad);
            // reduce c mod Fn which makes the first c.Length/2-1 longs zero; keep the others
            ReduceWide(c);
            Array.Copy(c, c.Length / 2 - 1, _digits, 0, c.Length / 2 + 1);
        }
    }

    ///<summary>
    ///Squares this number (mod Fn).
    ///</summary>
    public void Square()
    {
        // if a = 2^(2^n) (which is -1 mod Fn), a^2 = 1 (mod Fn)
        if (_digits[0] == 1)
        {
            Array.Fill(_digits, 0);
            _digits[_digits.Length - 1] = 1;
        }
        // otherwise, a^2 will fit into 2*2^n bits
        else
        {
            var intDigits = ToUIntArrayOdd(_digits);
            var cInt = BigIntegerOperation.Square((intDigits, true)).Item1;
            cInt = cInt.SkipWhile(num => num == 0u).ToArray();
            // zero-pad cInt to make it 2*2^n bits in length, and convert it to long[]
            var cIntPad = new uint[2 * intDigits.Length - 2];
            Array.Copy(cInt, 0, cIntPad, cIntPad.Length - cInt.Length, cInt.Length);
            var c = ToLongArrayEven(cIntPad);
            // reduce c mod Fn which makes the first c.Length/2-1 longs zero; keep the others
            ReduceWide(c);
            Array.Copy(c, c.Length / 2 - 1, _digits, 0, c.Length / 2 + 1);
        }
    }

    ///<summary>
    ///Reduces this number modulo Fn. <see cref="_digits"/>[0] will be 0 or 1.
    ///</summary>
    private void Reduce()
    {
        // Reduction modulo Fn is done by subtracting the most significant long from the least significant long
        var len = _digits.Length;
        var bi = _digits[0];
        var diff = _digits[len - 1] - bi;
        // borrow if signBit(diff) > signBit(digits[len-1]) - signBit(digits[0])
        var borrow = SignBit(diff) > SignBit(_digits[len - 1]) - SignBit(bi);
        _digits[len - 1] = diff;
        _digits[0] = 0; // because we subtracted digits[0] from digits[len-1]
        if (borrow)
        {
            var i = len - 2;
            do
            {
                diff = _digits[i] - 1;
                _digits[i] = diff;
                borrow = diff == -1;
                i--;
            } while (borrow && i >= 0);
        }
        // if we borrowed from the most significant long, subtract 2^2^n which is the same as adding 1 (mod Fn)
        if (borrow)
        {
            var i = _digits.Length - 1;
            var carry = true;
            _digits[0] = 0; // increment digits[0] by 1 to make it 0
            while (carry && i >= 0)
            {
                var sum = _digits[i] + 1;
                _digits[i] = sum;
                carry = sum == 0;
                i--;
            }
        }
    }

    ///<summary>
    ///Like <see cref="Reduce()"/> but works on an even-length array holding a double-width (2^(n+1)-bit)
    ///product: the upper half is subtracted from the lower half. Afterwards the first a.Length/2 - 1
    ///longs are zero and the result occupies the last a.Length/2 + 1 longs.
    ///</summary>
    private static void ReduceWide(long[] a)
    {
        // Reduction modulo Fn is done by subtracting the upper half from the lower half
        var len = a.Length;
        var carry = false;
        for (var i = len - 1; i >= len / 2; i--)
        {
            var bi = a[i - len / 2];
            var diff = a[i] - bi;
            if (carry)
                diff--;
            // carry if signBit(diff) > signBit(a) - signBit(b)
            carry = SignBit(diff) > SignBit(a[i]) - SignBit(bi);
            a[i] = diff;
        }
        for (var i = len / 2 - 1; i >= 0; i--)
            a[i] = 0;
        // if result is negative, add Fn; since Fn is congruent to 1 (mod 2^n), it suffices to add 1
        if (carry)
        {
            var j = len - 1;
            do
            {
                var sum = a[j] + 1;
                a[j] = sum;
                carry = sum == 0;
                j--;
                if (j <= 0)
                    break;
            } while (carry);
        }
    }

    ///<summary>
    ///Like <see cref="Reduce()"/> but works on a big-endian <see cref="uint"/> array: the first
    ///element is subtracted from the last one, the first element is cleared, and any final borrow
    ///is compensated by adding 1.
    ///</summary>
    public static void Reduce(uint[] digits)
    {
        // Reduction modulo Fn is done by subtracting the most significant int from the least significant int
        var len = digits.Length;
        var bi = digits[0];
        var diff = digits[len - 1] - bi;
        // borrow if signBit(diff) > signBit(digits[len-1]) - signBit(digits[0]); compared as signed ints
        var borrow = SignBit(diff) > SignBit(digits[len - 1]) - SignBit(bi);
        digits[len - 1] = diff;
        digits[0] = 0; // because we subtracted digits[0] from digits[len-1]
        if (borrow)
        {
            var i = len - 2;
            do
            {
                diff = digits[i] - 1;
                digits[i] = diff;
                borrow = diff == uint.MaxValue;
                i--;
            } while (borrow && i >= 0);
        }
        // if we borrowed from the most significant int, subtract 2^2^n which is the same as adding 1 (mod Fn)
        if (borrow)
        {
            var i = digits.Length - 1;
            var carry = true;
            digits[0] = 0; // increment digits[0] by 1 to make it 0
            while (carry && i >= 0)
            {
                var sum = digits[i] + 1;
                digits[i] = sum;
                carry = sum == 0;
                i--;
            }
        }
    }

    ///<summary>
    ///<para>Multiplies this number by 2^-shiftAmtBits modulo 2^2^n + 1 where 2^n = <code>(digits.length-1)*64</code>.</para>
    ///<para>"Right" means towards the higher array indices and the lower bits.</para>
    ///<para>This is equivalent to extending the number to <code>2*(digits.length-1)</code> longs and cyclicly
    ///shifting to the right by <code>shiftAmt</code> bits.</para>
    ///The result is placed in the second argument.
    ///</summary>
    public void ShiftRight(int shiftAmtBits, MutableModFn b)
    {
        var len = _digits.Length;
        if (shiftAmtBits > 64 * (len - 1))
        {
            ShiftLeft(64 * 2 * (len - 1) - shiftAmtBits, b);
            return;
        }
        var shiftAmtLongs = shiftAmtBits / 64; // number of longs to shift
        if (shiftAmtLongs > 0)
        {
            var borrow = false;
            var diff = 0L;
            // shift the digits that stay positive, except a[len-1] which is special
            for (int i = 1; i < len - shiftAmtLongs; i++)
            {
                diff = _digits[i];
                if (borrow)
                    diff--;
                b._digits[shiftAmtLongs + i] = diff;
                borrow = diff == -1 && borrow;
            }
            // subtract a[len-1] from a[0]
            diff = _digits[0] - _digits[len - 1];
            if (borrow)
            {
                diff--;
                borrow = diff == -1;
            }
            else
                borrow = _digits[0] == 0 && _digits[len - 1] != 0; // a[0] can only be 0 or 1; if digits[0]!=0, digits[len-1]==0
            b._digits[shiftAmtLongs] = diff;
            // using the fact that adding x*(Fn-1) is the same as subtracting x,
            // subtract digits shifted off the right, except for a[0] which is special
            for (var i = 1; i < shiftAmtLongs; i++)
            {
                b._digits[shiftAmtLongs - i] = -_digits[len - 1 - i];
                if (borrow)
                    b._digits[shiftAmtLongs - i]--;
                borrow = b._digits[shiftAmtLongs - i] != 0 || borrow;
            }
            // if we borrowed from the most significant long, add 1 to the overall number
            var carry = borrow;
            if (carry)
            {
                // increment b[0] and decrement b[len-1]
                b._digits[0] = 0;
                var i = len - 1;
                do
                {
                    var sum = b._digits[i] + 1;
                    b._digits[i] = sum;
                    carry = sum == 0;
                    i--;
                } while (carry && i >= 0);
            }
            else
                b._digits[0] = 0;
        }
        else
            Array.Copy(_digits, 0, b._digits, 0, len);

        var shiftAmtFrac = shiftAmtBits % 64;
        if (shiftAmtFrac != 0)
        {
            long bhi = b._digits[len - 1] << (64 - shiftAmtFrac);
            // do remaining digits
            b._digits[len - 1] = (long)((ulong)b._digits[len - 1] >> shiftAmtFrac);
            for (var i = len - 1; i > 0; i--)
            {
                b._digits[i] |= b._digits[i - 1] << (64 - shiftAmtFrac);
                b._digits[i - 1] = (long)((ulong)b._digits[i - 1] >> shiftAmtFrac);
            }
            // b[len-1] spills over into b[1]
            var diff = b._digits[1] - bhi;
            // borrow if signBit(diff) > signBit(a) - signBit(b); compared as signed values
            var borrow = SignBit(diff) > SignBit(b._digits[1]) - SignBit(bhi);
            b._digits[1] = diff;
            // if we borrowed from b[0], add 1 to the overall number
            var carry = borrow;
            if (carry)
            {
                // increment b[0] and decrement b[len-1]
                b._digits[0] = 0;
                var i = len - 1;
                do
                {
                    var sum = b._digits[i] + 1;
                    b._digits[i] = sum;
                    carry = sum == 0;
                    i--;
                } while (carry && i >= 0);
            }
            else
                b._digits[0] = 0;
        }
    }

    ///<summary>
    ///<para>Multiplies this number by 2^shiftAmtBits modulo 2^2^n + 1 where 2^n = <code>(digits.length-1)*64</code>.</para>
    ///<para>"Left" means towards the lower array indices and the higher bits.</para>
    ///<para>This is equivalent to extending the number to <code>2*(digits.length-1)</code> longs and cyclicly
    ///shifting to the left by <code>shiftAmt</code> bits.</para>
    ///The result is placed in the second argument and reduced.
    ///</summary>
    public void ShiftLeft(int shiftAmtBits, MutableModFn b)
    {
        var len = _digits.Length;
        if (shiftAmtBits > 64 * (len - 1))
        {
            ShiftRight(64 * 2 * (len - 1) - shiftAmtBits, b);
            return;
        }
        var shiftAmtLongs = shiftAmtBits / 64; // number of longs to shift
        if (shiftAmtLongs > 0)
        {
            var borrow = false;
            // using the fact that adding x*(Fn-1) is the same as subtracting x,
            // subtract digits shifted outside the [0..Fn-2] range, except for digits[0] which is special
            for (var i = 0; i < shiftAmtLongs; i++)
            {
                b._digits[len - 1 - i] = -_digits[shiftAmtLongs - i];
                if (borrow)
                    b._digits[len - 1 - i]--;
                borrow = b._digits[len - 1 - i] != 0 || borrow;
            }
            // subtract digits[0] from digits[len-1] (they overlap unless numElements=len-1)
            long diff;
            if (shiftAmtLongs < len - 1)
                diff = _digits[len - 1] - _digits[0];
            else // no overlap
                diff = -_digits[0];
            if (borrow)
            {
                diff--;
                borrow = diff == -1;
            }
            else
                borrow = _digits[0] == 1 && diff == -1; // digits[0] can only be 0 or 1
            b._digits[len - 1 - shiftAmtLongs] = diff;
            // finally, shift the digits that stay in the [0..Fn-2] range
            for (var i = 1; i < len - shiftAmtLongs - 1; i++)
            {
                diff = _digits[len - 1 - i];
                if (borrow)
                    diff--;
                b._digits[len - 1 - shiftAmtLongs - i] = diff;
                borrow = diff == -1 && borrow;
            }
            // if we borrowed from the most significant long, add 1 to the overall number
            var carry = borrow;
            if (carry)
            {
                // increment b[0] and decrement b[len-1]
                b._digits[0] = 0;
                int i = len - 1;
                do
                {
                    var sum = b._digits[i] + 1;
                    b._digits[i] = sum;
                    carry = sum == 0;
                    i--;
                } while (carry && i >= 0);
            }
            else
                b._digits[0] = 0;
        }
        else
            Array.Copy(_digits, 0, b._digits, 0, len);

        var shiftAmtFrac = shiftAmtBits % 64;
        if (shiftAmtFrac != 0)
        {
            b._digits[0] <<= shiftAmtFrac; // no spill-over because 0<=digits[0]<=1 and shiftAmtFrac<=63
            for (var i = 1; i < len; i++)
            {
                b._digits[i - 1] |= (long)((ulong)b._digits[i] >> (64 - shiftAmtFrac));
                b._digits[i] <<= shiftAmtFrac;
            }
        }
        b.Reduce();
    }

    ///<summary>
    ///Packs pairs of big-endian <see cref="uint"/>s into <see cref="long"/>s.
    ///<paramref name="digits"/>.Length must be an even number.
    ///</summary>
    private static long[] ToLongArrayEven(uint[] digits)
    {
        var longDigits = new long[digits.Length / 2];
        for (var i = 0; i < longDigits.Length; i++)
            longDigits[i] = (((long)digits[2 * i]) << 32) | (digits[2 * i + 1] & 0xFFFFFFFFL);
        return longDigits;
    }

    ///<summary>
    ///Unpacks <paramref name="digits"/> into a big-endian <see cref="uint"/> array of odd length
    ///2*digits.Length - 1; only the low 32 bits of digits[0] are kept.
    ///</summary>
    public static uint[] ToUIntArrayOdd(long[] digits)
    {
        var intDigits = new uint[digits.Length * 2 - 1];
        intDigits[0] = (uint)digits[0];
        for (var i = 1; i < digits.Length; i++)
        {
            intDigits[2 * i - 1] = (uint)((ulong)digits[i] >> 32);
            intDigits[2 * i] = (uint)(digits[i] & -1);
        }
        return intDigits;
    }

    ///<summary>
    ///Returns this number as a big-endian <see cref="uint"/> array of odd length 2*Length - 1.
    ///</summary>
    public uint[] ToUIntArrayOdd()
    {
        return ToUIntArrayOdd(_digits);
    }
}
Fast Fourier Transform
/// <summary>
/// <para>Performs a modified DFT Fermat Number Transform on an array of <see cref="MutableModFn"/> elements,
/// using Bailey's four-step (rows x columns) decomposition.</para>
/// <para>The modification is that the first step is omitted because only the upper half of the result is needed.</para>
/// <paramref name="value"/> is assumed to be the lower half of the full array and the upper half is assumed to be all zeros.
/// </summary>
private static void DFT(MutableModFn[] value, int omega)
{
    // arrange the elements of A in a matrix roughly sqrt(A.length) by sqrt(A.length) in size
    var rows = 1 << ((31 - value.Length.NumberOfLeadingZeros()) >> 1); // number of rows
    var cols = value.Length / rows; // number of columns
    // step 1: perform a DFT on each column, that is, on the vector
    // A[colIdx], A[colIdx+cols], A[colIdx+2*cols], ..., A[colIdx+(rows-1)*cols].
    for (var i = 0; i < cols; i++)
        DFTDirect(value, omega, rows, rows, cols, i, cols);
    // step 2: multiply by powers of omega
    ApplyDFTWeights(value, omega, rows, cols);
    // step 3 is built into step 1 by making the stride length a multiple of the row length
    // step 4: perform a DFT on each row, that is, on the vector
    // A[rowIdx*cols], A[rowIdx*cols+1], ..., A[rowIdx*cols+cols-1].
    for (var i = 0; i < rows; i++)
        DFTDirect(value, omega, cols, 0, rows, i * cols, 1);
}

/// <summary>
/// Performs a DFT on <paramref name="value"/>.
/// This implementation uses the radix-4 technique which combines two levels of butterflies.
/// </summary>
private static void DFTDirect(MutableModFn[] value, int omega, int len, int expOffset, int expScale, int idxOffset, int stride)
{
    // multiply by 2 because we're doing a half DFT and we need the n that corresponds to the full DFT length
    var n = 31 - (2 * len).NumberOfLeadingZeros();
    var v = 1; // v starts at 1 rather than 0 for the same reason
    MutableModFn d = new MutableModFn(value[0].Length);
    var slen = len >> 1;
    while (slen > 1)
    {
        // slen = #consecutive coefficients for which the sign (add/sub) and x are constant
        for (var j = 0; j < len; j += (slen << 1))
        {
            var x1 = GetDFTExponent(n, v + 1, j + expOffset, omega) * expScale; // for level v+2
            var x2 = GetDFTExponent(n, v, j + expOffset, omega) * expScale; // for level v+1
            var x3 = GetDFTExponent(n, v + 1, j + slen + expOffset, omega) * expScale; // for level v+2
            // stride length = stride*slen elements
            var idx0 = stride * j + idxOffset;
            var idx1 = stride * j + stride * slen / 2 + idxOffset;
            var idx2 = idx0 + stride * slen;
            var idx3 = idx1 + stride * slen;
            for (var k = slen - 1; k >= 0; k -= 2)
            {
                // do level v+1
                value[idx2].ShiftLeft(x2, d);
                value[idx0].CopyTo(value[idx2]);
                value[idx0].Add(d);
                value[idx2].Subtract(d);
                value[idx3].ShiftLeft(x2, d);
                value[idx1].CopyTo(value[idx3]);
                value[idx1].Add(d);
                value[idx3].Subtract(d);
                // do level v+2
                value[idx1].ShiftLeft(x1, d);
                value[idx0].CopyTo(value[idx1]);
                value[idx0].Add(d);
                value[idx1].Subtract(d);
                value[idx3].ShiftLeft(x3, d);
                value[idx2].CopyTo(value[idx3]);
                value[idx2].Add(d);
                value[idx3].Subtract(d);
                idx0 += stride;
                idx1 += stride;
                idx2 += stride;
                idx3 += stride;
            }
        }
        v += 2;
        slen >>= 2;
    }
    // if there is an odd number of levels, do the remaining one now
    if (slen > 0)
        for (var j = 0; j < len; j += 2 * slen)
        {
            var x = GetDFTExponent(n, v, j + expOffset, omega) * expScale;
            var idx = stride * j + idxOffset;
            var idx2 = idx + stride * slen; // stride length = stride*slen elements
            for (var k = slen - 1; k >= 0; k--)
            {
                value[idx2].ShiftLeft(x, d);
                value[idx].CopyTo(value[idx2]);
                value[idx].Add(d);
                value[idx2].Subtract(d);
                idx += stride;
                idx2 += stride;
            }
        }
}

/// <summary>
/// <para>Returns the power to which to raise omega in a DFT.</para>
/// When <code>omega</code>=4, this method doubles the exponent so
/// <code>omega</code> can be assumed always to be 2 in the
/// <see cref="DFTDirect"/> and <see cref="IDFTDirect"/> methods.
/// </summary>
private static int GetDFTExponent(int n, int v, int idx, int omega)
{
    // x = 2^(n-1-v) * s, where s is the v (out of n) high bits of idx in reverse order
    var x = (int)(ReverseBits((uint)idx >> (n - v)) >> (32 - v));
    x <<= n - v - 1;
    // if omega=4, double the shift amount
    if (omega == 4)
        x <<= 1;
    return x;
}

/// <summary>
/// Multiplies vector elements by powers of omega (aka twiddle factors). Used by Bailey's algorithm.
/// </summary>
private static void ApplyDFTWeights(MutableModFn[] value, int omega, int rows, int cols)
{
    var v = 31 - rows.NumberOfLeadingZeros() + 1;
    // one scratch value suffices: ShiftLeft overwrites every digit of its destination
    // (all elements share one length, as DFTDirect already assumes)
    var temp = new MutableModFn(value[0].Length);
    for (var i = 0; i < rows; i++)
        for (var j = 0; j < cols; j++)
        {
            var idx = i * cols + j;
            var shiftAmt = GetBaileyShiftAmount(i, j, rows, v);
            if (omega == 4)
                shiftAmt *= 2;
            value[idx].ShiftLeft(shiftAmt, temp);
            temp.CopyTo(value[idx]);
        }
}

/// <summary>
/// Returns the twiddle-factor exponent for matrix cell (<paramref name="i"/>, <paramref name="j"/>):
/// the bit-reversed row index (offset by <paramref name="rows"/>, over <paramref name="v"/> bits) times the column index.
/// </summary>
private static int GetBaileyShiftAmount(int i, int j, int rows, int v)
{
    var iRev = (int)(ReverseBits((uint)(i + rows)) >> (32 - v));
    return iRev * j;
}

/// <summary>
/// <para>Performs a modified Inverse Fermat Number Transform on an array of <see cref="MutableModFn"/> elements.</para>
/// <para>The modification is that the last step (the one where the upper half is subtracted from the lower half) is omitted.</para>
/// <paramref name="value"/> is assumed to be the upper half of the full array and the lower half is assumed to be all zeros.
/// </summary>
private static void IDFT(MutableModFn[] value, int omega)
{
    // arrange the elements of A in a matrix roughly sqrt(A.length) by sqrt(A.length) in size
    var rows = 1 << ((31 - value.Length.NumberOfLeadingZeros()) >> 1); // number of rows
    var cols = value.Length / rows; // number of columns
    // step 1: perform an IDFT on each row, that is, on the vector
    // A[rowIdx*cols], A[rowIdx*cols+1], ..., A[rowIdx*cols+cols-1].
    for (var i = 0; i < rows; i++)
        IDFTDirect(value, omega, cols, 0, rows, i * cols, 1);
    // step 2: multiply by powers of omega
    ApplyIDFTWeights(value, omega, rows, cols);
    // step 3 is built into step 4 by making the stride length a multiple of the row length
    // step 4: perform an IDFT on each column, that is, on the vector
    // A[colIdx], A[colIdx+cols], A[colIdx+2*cols], ..., A[colIdx+(rows-1)*cols].
    for (var i = 0; i < cols; i++)
        IDFTDirect(value, omega, rows, rows, cols, i, cols);
}

/// <summary>
/// Performs an inverse DFT on <paramref name="value"/>; each butterfly also halves its outputs (ShiftRight by 1).
/// This implementation uses the radix-4 technique which combines two levels of butterflies.
/// </summary>
private static void IDFTDirect(MutableModFn[] value, int omega, int len, int expOffset, int expScale, int idxOffset, int stride)
{
    // multiply by 2 because we're doing a half DFT and we need the n that corresponds to the full DFT length
    var n = 31 - (len << 1).NumberOfLeadingZeros();
    var v = 31 - len.NumberOfLeadingZeros();
    var c = new MutableModFn(value[0].Length);
    var slen = 1;
    while (slen <= len / 4)
    {
        // slen = #consecutive coefficients for which the sign (add/sub) and x are constant
        for (var j = 0; j < len; j += 4 * slen)
        {
            var x1 = GetDFTExponent(n, v, j + expOffset, omega) * expScale + 1; // for level v
            var x2 = GetDFTExponent(n, v - 1, j + expOffset, omega) * expScale + 1; // for level v-1
            var x3 = GetDFTExponent(n, v, j + slen * 2 + expOffset, omega) * expScale + 1; // for level v
            // stride length = stride*slen elements
            var idx0 = stride * j + idxOffset;
            var idx1 = stride * j + stride * slen + idxOffset;
            var idx2 = idx0 + stride * slen * 2;
            var idx3 = idx1 + stride * slen * 2;
            for (var k = slen - 1; k >= 0; k--)
            {
                // do level v
                value[idx0].CopyTo(c);
                value[idx0].Add(value[idx1]);
                value[idx0].ShiftRight(1, value[idx0]);
                c.Subtract(value[idx1]);
                c.ShiftRight(x1, value[idx1]);
                value[idx2].CopyTo(c);
                value[idx2].Add(value[idx3]);
                value[idx2].ShiftRight(1, value[idx2]);
                c.Subtract(value[idx3]);
                c.ShiftRight(x3, value[idx3]);
                // do level v-1
                value[idx0].CopyTo(c);
                value[idx0].Add(value[idx2]);
                value[idx0].ShiftRight(1, value[idx0]);
                c.Subtract(value[idx2]);
                c.ShiftRight(x2, value[idx2]);
                value[idx1].CopyTo(c);
                value[idx1].Add(value[idx3]);
                value[idx1].ShiftRight(1, value[idx1]);
                c.Subtract(value[idx3]);
                c.ShiftRight(x2, value[idx3]);
                idx0 += stride;
                idx1 += stride;
                idx2 += stride;
                idx3 += stride;
            }
        }
        v -= 2;
        slen *= 4;
    }
    // if there is an odd number of levels, do the remaining one now
    if (slen <= len / 2)
        for (int j = 0; j < len; j += 2 * slen)
        {
            int x = GetDFTExponent(n, v, j + expOffset, omega) * expScale + 1;
            int idx = stride * j + idxOffset;
            int idx2 = idx + stride * slen; // stride length = stride*slen elements
            for (int k = slen - 1; k >= 0; k--)
            {
                value[idx].CopyTo(c);
                value[idx].Add(value[idx2]);
                value[idx].ShiftRight(1, value[idx]);
                c.Subtract(value[idx2]);
                c.ShiftRight(x, value[idx2]);
                idx += stride;
                idx2 += stride;
            }
        }
}

/// <summary>
/// Divides vector elements by powers of omega (aka twiddle factors)
/// </summary>
private static void ApplyIDFTWeights(MutableModFn[] value, int omega, int rows, int cols)
{
    var v = 31 - rows.NumberOfLeadingZeros() + 1;
    // one scratch value suffices: ShiftRight overwrites every digit of its destination
    // (all elements share one length, as IDFTDirect already assumes)
    var temp = new MutableModFn(value[0].Length);
    for (int i = 0; i < rows; i++)
        for (var j = 0; j < cols; j++)
        {
            var idx = i * cols + j;
            var shiftAmt = GetBaileyShiftAmount(i, j, rows, v);
            if (omega == 4)
                shiftAmt *= 2;
            value[idx].ShiftRight(shiftAmt, temp);
            temp.CopyTo(value[idx]);
        }
}

/// <summary>
/// Converts each element to its big-endian <see cref="uint"/> representation.
/// </summary>
private static uint[][] ToIntArray(MutableModFn[] value)
{
    var aInt = new uint[value.Length][];
    for (var i = 0; i < value.Length; i++)
        aInt[i] = value[i].ToUIntArrayOdd();
    return aInt;
}

/// <summary>
/// Calls <see cref="MutableModFn.Multiply(MutableModFn)"/> for each element of
/// <paramref name="left"/> and <paramref name="right"/> and places the result into <paramref name="left"/>.
/// </summary>
private static void MultiplyElements(MutableModFn[] left, MutableModFn[] right)
{
    for (int i = 0; i < left.Length; i++)
        left[i].Multiply(right[i]);
}

/// <summary>
/// Calls <see cref="MutableModFn.Square"/> for each element of
/// <paramref name="value"/> and places the result into <paramref name="value"/>.
/// </summary>
private static void SquareElements(MutableModFn[] value)
{
    for (int i = 0; i < value.Length; i++)
        value[i].Square();
}

/// <summary>
/// <para>Adds two numbers, <paramref name="left"/> and <paramref name="right"/>,
/// after shifting <paramref name="right"/> by <paramref name="numElements"/> elements.</para>
/// <para>Both numbers are given as big-endian <see cref="uint"/> arrays and are interpreted as unsigned.</para>
/// The result is returned in the first argument.
/// If any elements of <paramref name="right"/> are shifted outside the valid range
/// for <paramref name="left"/>, they are dropped.
/// </summary>
private static void AddShifted(uint[] left, uint[] right, int numElements)
{
    var carry = false;
    var aIdx = left.Length - 1 - numElements;
    var bIdx = right.Length - 1;
    var i = Math.Min(aIdx, bIdx);
    while (i >= 0)
    {
        var ai = left[aIdx];
        var sum = ai + right[bIdx];
        if (carry)
            sum++;
        carry = ((sum >> 31) < (ai >> 31) + (right[bIdx] >> 31)); // carry if signBit(sum) < signBit(a)+signBit(b)
        left[aIdx] = sum;
        i--;
        aIdx--;
        bIdx--;
    }
    while (carry && aIdx >= 0)
    {
        left[aIdx]++;
        carry = left[aIdx] == 0;
        aIdx--;
    }
}

/// <summary>
/// <para>Adds two unsigned numbers modulo 2^<paramref name="numBits"/>.</para>
/// Both input values are given as big-endian <see cref="uint"/> arrays.
/// The result is returned in the first argument.
/// </summary>
private static void AddModPow2(uint[] left, uint[] right, int numBits)
{
    var numElements = (numBits + 31) >> 5;
    var carry = false;
    int i;
    var aIdx = left.Length - 1;
    var bIdx = right.Length - 1;
    for (i = numElements - 1; i >= 0; i--)
    {
        var sum = left[aIdx] + right[bIdx];
        if (carry)
            sum++;
        carry = ((sum >> 31) < (left[aIdx] >> 31) + (right[bIdx] >> 31)); // carry if signBit(sum) < signBit(a)+signBit(b)
        left[aIdx] = sum;
        aIdx--;
        bIdx--;
    }
    if (numElements > 0)
        left[aIdx + 1] &= uint.MaxValue >> (32 - (numBits % 32));
    for (; aIdx >= 0; aIdx--)
        left[aIdx] = 0;
}

/// <summary>
/// <para>Subtracts two unsigned numbers modulo 2^<paramref name="numBits"/>.</para>
/// Both input values are given as big-endian <see cref="uint"/> arrays.
/// The result is returned in the first argument.
/// </summary>
private static void SubModPow2(uint[] left, uint[] right, int numBits)
{
    var numElements = (numBits + 31) >> 5;
    var carry = false;
    int i;
    var aIdx = left.Length - 1;
    var bIdx = right.Length - 1;
    for (i = numElements - 1; i >= 0; i--)
    {
        var diff = left[aIdx] - right[bIdx];
        if (carry)
            diff--;
        // borrow if signBit(diff) > signBit(a) - signBit(b); the sign bits are compared as signed ints
        // so that signBit(a) - signBit(b) can be -1 instead of wrapping to uint.MaxValue
        carry = (int)(diff >> 31) > (int)(left[aIdx] >> 31) - (int)(right[bIdx] >> 31);
        left[aIdx] = diff;
        aIdx--;
        bIdx--;
    }
    if (numElements > 0)
        left[aIdx + 1] &= uint.MaxValue >> (32 - (numBits % 32));
    for (; aIdx >= 0; aIdx--)
        left[aIdx] = 0;
}
Unsigned Schönhage–Strassen Multiplication
/// <summary>
/// Operand length, in <see cref="uint"/> words, below which <c>MultiplySchönhageStrassenNonegative</c>
/// hands the work to <c>MultiplyNonegative</c> instead of running Schönhage–Strassen
/// (despite its name, this is the Schönhage–Strassen cut-over). 0x80 = 128 words = 4096 bits.
/// </summary>
private static readonly int KARATSUBA_SQUARE_THRESHOLD = 0x80;
/// <summary>
/// Schönhage–Strassen multiplication of two non-negative big integers.
/// Operands and result are big-endian <see cref="uint"/> arrays: the first element holds the most
/// significant 32 bits, the last element the least significant 32 bits.
/// If either operand is shorter than <c>KARATSUBA_SQUARE_THRESHOLD</c> words, the work is
/// delegated to <c>MultiplyNonegative</c>. The result may contain leading zero words.
/// </summary>
public static uint[] MultiplySchönhageStrassenNonegative(uint[] left, uint[] right)
{
    if (left.Length < KARATSUBA_SQUARE_THRESHOLD ||
        right.Length < KARATSUBA_SQUARE_THRESHOLD)
        return MultiplyNonegative(left, right);
    var square = left.SequenceEqual(right);
    // set M to the number of binary digits in a or b, whichever is greater
    var M = Math.Max(left.Length << 5, right.Length << 5);
    // find the lowest m such that m>=log2(2M)
    var m = 32 - (2 * M - 1 - 1).NumberOfLeadingZeros();
    var n = (m >> 1) + 1;
    // split a and b into pieces 1<<(n-1) bits long; assume n>=6 so pieces start and end at int boundaries
    var even = (m & 1) == 0;
    var numPieces = even ? 1 << n : 1 << (n + 1);
    var pieceSize = 1 << (n - 1 - 5); // in ints
    // zi mod 2^(n+2): build u and v from a and b, allocating 3n+5 bits in u and v per n+2 bits from a and b, resp.
    var numPiecesA = (left.Length + pieceSize) / pieceSize;
    var u = new uint[(numPiecesA * (3 * n + 5) + 31) / 32];
    var uBitLength = 0;
    for (var i = 0; i < numPiecesA && i * pieceSize < left.Length; i++)
    {
        AppendBits(u, uBitLength, left, i * pieceSize, n + 2);
        uBitLength += 3 * n + 5;
    }
    uint[] gamma;
    if (square)
        gamma = SquareToomCook3((u, true)).Item1; // gamma = u * u
    else
    {
        var numPiecesB = (right.Length + pieceSize) / pieceSize;
        var v = new uint[(numPiecesB * (3 * n + 5) + 31) / 32];
        var vBitLength = 0;
        for (var i = 0; i < numPiecesB && i * pieceSize < right.Length; i++)
        {
            AppendBits(v, vBitLength, right, i * pieceSize, n + 2);
            vBitLength += 3 * n + 5;
        }
        gamma = MultiplySchönhageStrassenNonegative(u, v); // gamma = u * v
    }
    var gammai = SplitBits(gamma, 3 * n + 5);
    var halfNumPcs = numPieces / 2;
    // zi = gammai[i] - gammai[i+h] + gammai[i+2h] - gammai[i+3h] (mod 2^(n+2)).
    // Copy the pieces: zi is updated in place below, while the gammai pieces it reads must stay unmodified.
    var zi = new uint[gammai.Length][];
    for (var i = 0; i < gammai.Length; i++)
        zi[i] = (uint[])gammai[i].Clone();
    for (var i = 0; i < gammai.Length - halfNumPcs; i++)
        SubModPow2(zi[i], gammai[i + halfNumPcs], n + 2);
    for (var i = 0; i < gammai.Length - 2 * halfNumPcs; i++)
        AddModPow2(zi[i], gammai[i + 2 * halfNumPcs], n + 2);
    for (var i = 0; i < gammai.Length - 3 * halfNumPcs; i++)
        SubModPow2(zi[i], gammai[i + 3 * halfNumPcs], n + 2);
    // zr mod Fn
    var fnLength = (1 << (n - 6)) + 1; // long count of a MutableModFn for Fn; assume n>=6
    var ai = Split(left, halfNumPcs, pieceSize, fnLength);
    MutableModFn[] bi = null;
    if (!square)
        bi = Split(right, halfNumPcs, pieceSize, fnLength);
    var omega = even ? 4 : 2;
    if (square)
    {
        DFT(ai, omega);
        SquareElements(ai);
    }
    else
    {
        DFT(ai, omega);
        DFT(bi, omega);
        MultiplyElements(ai, bi);
    }
    var c = ai;
    IDFT(c, omega);
    var cInt = ToIntArray(c);
    var z = new uint[(1 << (m - 5)) + 1];
    // calculate zr mod Fm from zr mod Fn and zr mod 2^(n+2), then add to z
    for (var i = 0; i < halfNumPcs; i++)
    {
        var eta = i >= zi.Length ? new uint[(n + 2 + 31) / 32] : zi[i];
        // zi = delta = (zi-c[i]) % 2^(n+2)
        SubModPow2(eta, cInt[i], n + 2);
        // z += zr<<shift = [ci + delta*(2^2^n+1)] << [i*2^(n-1)]
        var shift = i * (1 << (n - 1 - 5)); // assume n>=6
        AddShifted(z, cInt[i], shift);
        AddShifted(z, eta, shift);
        AddShifted(z, eta, shift + (1 << (n - 5)));
    }
    MutableModFn.Reduce(z); // assume m>=5
    return z;
}
/// <summary>
/// Copies the low <paramref name="rightBitLength"/> bits of the big-endian number that ends
/// <paramref name="rightStart"/> words before the end of <paramref name="right"/> into
/// <paramref name="left"/>, starting at bit position <paramref name="leftBitLength"/>
/// (counted from the least significant end). Whole words are assigned and the final
/// partial word is OR-ed in; the result is returned in <paramref name="left"/>.
/// </summary>
private static void AppendBits(uint[] left, int leftBitLength, uint[] right, int rightStart, int rightBitLength)
{
    var dst = left.Length - 1 - (leftBitLength >> 5);
    var offset = leftBitLength % 32;
    var src = right.Length - 1 - rightStart;
    var fullWords = rightBitLength >> 5;
    var tailBits = rightBitLength % 32;

    // whole source words, least significant first
    for (var copied = 0; copied < fullWords; copied++, src--)
    {
        var word = right[src];
        if (offset <= 0)
        {
            left[dst--] = word;
            continue;
        }
        // word straddles two destination words
        left[dst] |= word << offset;
        dst--;
        left[dst] = word >> (32 - offset);
    }

    if (tailBits <= 0)
        return;

    // remaining partial source word, masked to tailBits
    var bits = right[src] & (uint.MaxValue >> (32 - tailBits));
    left[dst] |= bits << offset;
    if (offset + tailBits > 32)
        left[dst - 1] = bits >> (32 - offset);
}
/// <summary>
/// Divides a big-endian <see cref="uint"/> array into pieces <paramref name="bitLength"/> bits long,
/// least significant piece first. Each piece is itself a big-endian array of
/// ceil(<paramref name="bitLength"/>/32) words; the last piece may be shorter than
/// <paramref name="bitLength"/> bits and is zero-extended.
/// </summary>
private static uint[][] SplitBits(uint[] value, int bitLength)
{
    var totalBits = value.Length * 32;
    var pieceCount = (totalBits + bitLength - 1) / bitLength;
    var wordsPerPiece = (bitLength + 31) / 32;
    // index of the word holding the lowest bits of a piece
    var lowWord = bitLength % 32 == 0 ? bitLength / 32 - 1 : bitLength / 32;

    // read cursor into value: word index and bit offset within that word
    var srcWord = value.Length - 1;
    var srcBit = 0;

    var pieces = new uint[pieceCount][];
    for (var p = 0; p < pieceCount; p++)
    {
        var piece = new uint[wordsPerPiece];
        pieces[p] = piece;
        var dstWord = lowWord;
        var dstBit = 0;
        var remaining = Math.Min(bitLength, totalBits - p * bitLength);
        while (remaining > 0)
        {
            // copy as many bits as fit in both the current source and destination words
            var take = Math.Min(remaining, Math.Min(32 - srcBit, 32 - dstBit));
            var chunk = (value[srcWord] >> srcBit) & (uint.MaxValue >> (32 - take));
            piece[dstWord] |= chunk << dstBit;
            remaining -= take;

            srcBit += take;
            if (srcBit >= 32)
            {
                srcBit -= 32;
                srcWord--;
            }

            dstBit += take;
            if (dstBit >= 32)
            {
                dstBit -= 32;
                dstWord--;
            }
        }
    }
    return pieces;
}
/// <summary>
/// Splits a big-endian <see cref="uint"/> array into <paramref name="numPieces"/> pieces of
/// <paramref name="sourcePieceSize"/> <see cref="uint"/>s each, least significant piece first,
/// right-aligns each piece in a zeroed <see cref="long"/> array of length <paramref name="targetPieceSize"/>,
/// and wraps it in a <see cref="MutableModFn"/>
/// (this implies <paramref name="targetPieceSize"/> = 2^k + 1 for some k).
/// Pieces past the end of <paramref name="value"/> are zero.
/// </summary>
/// <remarks>
/// NOTE(review): the full-piece loop reads words in pairs, so it assumes <paramref name="sourcePieceSize"/> is even.
/// NOTE(review): a partial (possibly empty) most significant piece is always stored after the full pieces,
/// so this assumes value.Length / sourcePieceSize &lt; numPieces — confirm at the call sites.
/// </remarks>
private static MutableModFn[] Split(uint[] value, int numPieces, int sourcePieceSize, int targetPieceSize)
{
var ai = new MutableModFn[numPieces];
// aIdx = first (most significant) word of the current full piece, walking back from the end of value
var aIdx = value.Length - sourcePieceSize;
var pieceIdx = 0;
while (aIdx >= 0)
{
var digits = new long[targetPieceSize];
// pack the piece's words pairwise into longs, right-aligned in the target array
for (var i = 0; i < sourcePieceSize; i += 2)
digits[targetPieceSize - sourcePieceSize / 2 + i / 2] = (((long)value[aIdx + i]) << 32) | (value[aIdx + i + 1] & 0xFFFFFFFFL);
ai[pieceIdx] = new MutableModFn(digits);
aIdx -= sourcePieceSize;
pieceIdx++;
}
// the leading value.Length % sourcePieceSize words form a final, partial piece
var digits2 = new long[targetPieceSize];
// even number of leftover words: pack them pairwise, like the full pieces
if ((value.Length % sourcePieceSize) % 2 == 0)
for (int i = 0; i < value.Length % sourcePieceSize; i += 2)
digits2[targetPieceSize - (value.Length % sourcePieceSize) / 2 + i / 2] = (((long)value[i]) << 32) | (value[i + 1] & 0xFFFFFFFFL);
else
{
// odd number of leftover words: value[i + 1] becomes the high half of one long and value[i]
// the low half of the next more significant long, so the word pairs straddle long boundaries
for (int i = 0; i < value.Length % sourcePieceSize - 2; i += 2)
{
digits2[targetPieceSize - (value.Length % sourcePieceSize) / 2 + i / 2] = ((long)value[i + 1]) << 32;
digits2[targetPieceSize - (value.Length % sourcePieceSize) / 2 + i / 2 - 1] |= value[i] & 0xFFFFFFFFL;
}
// the remaining half-long: the least significant leftover word goes into the low half of the last long
digits2[targetPieceSize - 1] |= value[value.Length % sourcePieceSize - 1] & 0xFFFFFFFFL;
}
ai[pieceIdx] = new MutableModFn(digits2);
// fill the rest with zero-valued pieces
while (++pieceIdx < numPieces)
ai[pieceIdx] = new MutableModFn(targetPieceSize);
return ai;
}
BigInteger to/from (uint[], bool)
用(uint[], bool)來表示有符號大數：uint[]是大數的絕對值（格式同上，陣列開頭為最高32位），bool為true時表示非負數（零或正數），為false時表示負數。
/// <summary>
/// Converts a signed (magnitude, isNonNegative) pair into a <see cref="BigInteger"/>.
/// The magnitude is big-endian (element 0 holds the most significant 32 bits);
/// leading zero words are ignored, and a <c>false</c> flag negates the result.
/// </summary>
private BigInteger ValueOf((uint[], bool) value)
{
    var (magnitude, nonNegative) = value;
    var absolute = magnitude
        .SkipWhile(word => word == 0u)
        .Aggregate(BigInteger.Zero, (accumulator, word) => (accumulator << 32) | word);
    if (!nonNegative)
        absolute = -absolute;
    return absolute;
}
/// <summary>
/// Converts a <see cref="BigInteger"/> into a (magnitude, isNonNegative) pair. The magnitude is a
/// big-endian <see cref="uint"/> array sized from <see cref="BigInteger.GetByteCount"/> of |value|,
/// so it may carry a leading zero word; the flag is <c>true</c> for zero and positive values.
/// </summary>
private (uint[], bool) ToTuple(BigInteger value)
{
    var magnitude = BigInteger.Abs(value);
    var words = new uint[(magnitude.GetByteCount() + 3) / 4];
    // fill from the least significant end
    for (var index = words.Length - 1; index >= 0; index--)
    {
        words[index] = (uint)(magnitude & uint.MaxValue);
        magnitude >>= 32;
    }
    var nonNegative = value >= 0;
    return (words, nonNegative);
}
Signed Schönhage–Strassen Multiplication
/// <summary>
/// Signed Schönhage–Strassen multiplication. Each operand is (magnitude, isNonNegative), where the
/// magnitude is big-endian (first <see cref="uint"/> = most significant 32 bits, last = least significant).
/// A zero operand is returned as-is; a ±1 operand returns the other magnitude with the combined sign.
/// The product is non-negative exactly when both signs agree.
/// </summary>
public static (uint[], bool) MultiplySchönhageStrassen((uint[], bool) left, (uint[], bool) right)
{
    // zero operands short-circuit and are returned unchanged
    if (IsZero(left))
        return left;
    if (IsZero(right))
        return right;

    var sameSign = left.Item2 == right.Item2;

    // multiplying by ±1 only affects the sign
    if (IsAbsOne(left))
        return (right.Item1, sameSign);
    if (IsAbsOne(right))
        return (left.Item1, sameSign);

    var magnitude = MultiplySchönhageStrassenNonegative(left.Item1, right.Item1);
    return (magnitude, sameSign);
}
測試
[TestMethod]
public void MultiplySchönhageStrassenNonegativeTest()
{
    // Fixed operands well above the 128-word Schönhage–Strassen cut-over, so the FFT path
    // (not the small-operand fallback) is guaranteed to run, including the squaring branch.
    var fixedLeft = (BigInteger.One << 6000) - 1;
    var fixedRight = (BigInteger.One << 7000) + 12345;
    Assert.AreEqual(fixedLeft * fixedRight,
        ValueOf(MultiplySchönhageStrassenNonegative(ToIntArray(fixedLeft), ToIntArray(fixedRight))));
    Assert.AreEqual(fixedLeft * fixedLeft,
        ValueOf(MultiplySchönhageStrassenNonegative(ToIntArray(fixedLeft), ToIntArray(fixedLeft))));

    // random operands
    for (var i = 0; i < 100; ++i)
    {
        var left = BigInteger.Abs(RandomBigInteger());
        var right = BigInteger.Abs(RandomBigInteger());
        var test = MultiplySchönhageStrassenNonegative(ToIntArray(left), ToIntArray(right));
        var expected = left * right;
        Assert.AreEqual(expected, ValueOf(test));
    }
}
[TestMethod]
public void MultiplySchönhageStrassenTest()
{
    // compare the signed product against BigInteger's own multiplication on random operands
    for (var round = 0; round < 100; round++)
    {
        BigInteger a = RandomBigInteger(), b = RandomBigInteger();
        var product = MultiplySchönhageStrassen(ToTuple(a), ToTuple(b));
        Assert.AreEqual(a * b, ValueOf(product));
    }
}
/// <summary>
/// Returns a random, possibly negative <see cref="BigInteger"/> built from 300–1999 random bytes.
/// The range straddles the 128-word (512-byte) Schönhage–Strassen cut-over, so the tests
/// exercise both the small-operand fallback and the FFT path. The previous 300–499 range
/// never reached the FFT path.
/// </summary>
private BigInteger RandomBigInteger()
{
    var ran = new Random(Guid.NewGuid().GetHashCode());
    var bytes = new byte[ran.Next(300, 2000)];
    ran.NextBytes(bytes);
    return new BigInteger(bytes);
}