程式人生 > 【筆記】大數乘法之Schönhage–Strassen演算法 (Java BigInteger原始碼)

【筆記】大數乘法之Schönhage–Strassen演算法 (Java BigInteger原始碼)

BigInteger to/from uint[]

用uint[]來表示非負大數,其中陣列開頭是大數的最高32位,陣列結尾是大數最低32位。

/// <summary>
/// Converts a <see cref="uint"/> array into a non-negative big integer.
/// The array is read most-significant word first: element 0 holds the highest 32 bits
/// and the last element holds the lowest 32 bits. Leading zero words are skipped.
/// </summary>
/// <param name="value">Words of the number, most significant first.</param>
/// <returns>The value the words represent; zero when every word is zero or the array is empty.</returns>
private static BigInteger ValueOf(uint[] value)
{
    // Fold the words from most to least significant: shift the running value up by
    // one word and merge the next word into the freed low 32 bits.
    return value
        .SkipWhile(word => word == 0u)
        .Aggregate(BigInteger.Zero, (accumulated, word) => (accumulated << 32) | word);
}

/// <summary>
/// Converts a non-negative big integer into a <see cref="uint"/> array,
/// most significant word first.
/// </summary>
/// <param name="value">The number to split into 32-bit words.</param>
/// <returns>
/// An array sized from <see cref="BigInteger.GetByteCount()"/> rounded up to whole
/// 32-bit words; the last element holds the lowest 32 bits.
/// </returns>
private static uint[] ToIntArray(BigInteger value)
{
    // Round the byte count up to a whole number of 4-byte words.
    var wordCount = (value.GetByteCount() + 3) / 4;
    var words = new uint[wordCount];

    // Peel off the low 32 bits repeatedly, filling the array from the back.
    var remaining = value;
    var position = wordCount;
    while (position-- > 0)
    {
        words[position] = (uint)(remaining & 0xFFFF_FFFF);
        remaining >>= 32;
    }

    return words;
}

MutableModFn

/// <summary>
/// A mutable number modulo the Fermat number Fn = 2^2^n + 1, stored as an array of
/// <see cref="long"/> words with the most significant word at index 0.
/// Index 0 acts as an overflow word that is expected to be 0 or 1 once reduced.
/// </summary>
public class MutableModFn
{
    private readonly long[] _digits;

    /// <summary>Number of <see cref="long"/> words in this number.</summary>
    public int Length => _digits.Length;

    ///<summary>
    ///<para>Creates a <see cref="MutableModFn"/> number from a <see cref="long"/> array whose length is
    ///2^(n - 6) + 1 for some n. The first element must be 0 or 1.</para>
    ///<para>The caller is trusted to pass in a valid array.</para>
    ///No copy of the array is made; its contents will reflect operations on the <see cref="MutableModFn"/> object.
    ///</summary>
    public MutableModFn(long[] digits)
    {
        this._digits = digits;
    }

    ///<summary>
    ///Creates a zero value. <paramref name="length"/> must be 2^(n - 6) + 1 for some n.
    ///</summary>
    public MutableModFn(int length)
    {
        _digits = new long[length];
    }

    ///<summary>
    ///Returns <c>true</c> when <c>minuend - subtrahend - borrowIn</c> underflows, with both
    ///words interpreted as unsigned 64-bit values.
    ///</summary>
    ///<remarks>
    ///A sign-bit difference such as <c>(a &gt;&gt; 63) - (b &gt;&gt; 63)</c> must not be used on
    ///<see cref="ulong"/> operands: <c>0 - 1</c> wraps to <see cref="ulong.MaxValue"/> and the borrow is lost.
    ///</remarks>
    private static bool Borrows(long minuend, long subtrahend, bool borrowIn)
    {
        unchecked
        {
            var a = (ulong)minuend;
            var b = (ulong)subtrahend;
            return a < b || (borrowIn && a == b);
        }
    }

    ///<summary>
    ///Copies this <see cref="MutableModFn"/>'s value into another <see cref="MutableModFn"/>.
    ///</summary>
    public void CopyTo(MutableModFn value)
    {
        Array.Copy(_digits, 0, value._digits, 0, _digits.Length);
    }

    ///<summary>
    ///Adds another <see cref="MutableModFn"/> to this number.
    ///</summary>
    public void Add(MutableModFn value)
    {
        var carry = false;
        for (var i = _digits.Length - 1; i >= 0; i--)
        {
            var sum = _digits[i] + value._digits[i];
            if (carry)
                sum++;
            // carry if signBit(sum) < signBit(digits[i])+signBit(addend[i]); the right-hand side is 0..2, so no unsigned wrap
            carry = (((ulong)sum >> 63) < ((ulong)_digits[i] >> 63) + ((ulong)value._digits[i] >> 63));
            _digits[i] = sum;
        }

        // take a mod Fn by adding any remaining carry bit to the lowest bit;
        // since Fn is congruent to 1 (mod 2^n), it suffices to add 1
        var j = _digits.Length - 1;
        while (carry && j >= 0)
        {
            var sum = _digits[j] + 1;
            _digits[j] = sum;
            carry = sum == 0;
            j--;
        }

        Reduce();
    }

    ///<summary>
    ///Subtracts another <see cref="MutableModFn"/> from this number.
    ///</summary>
    public void Subtract(MutableModFn value)
    {
        var borrow = false;
        for (var i = _digits.Length - 1; i >= 0; i--)
        {
            var minuend = _digits[i];
            var subtrahend = value._digits[i];
            var diff = minuend - subtrahend;
            if (borrow)
                diff--;
            // unsigned comparison; see Borrows for why a sign-bit difference is not used
            borrow = Borrows(minuend, subtrahend, borrow);
            _digits[i] = diff;
        }

        // if we borrowed from the most significant long, subtract 2^2^n which is the same as adding 1 (mod Fn)
        if (borrow)
        {
            _digits[0]++;   // undo borrow
            var i = _digits.Length - 1;
            var carry = true;
            while (carry && i >= 0)
            {
                var sum = _digits[i] + 1;
                _digits[i] = sum;
                carry = sum == 0;
                i--;
            }
        }
    }

    ///<summary>
    ///Multiplies this number by another <see cref="MutableModFn"/>; the result replaces this number.
    ///</summary>
    public void Multiply(MutableModFn value)
    {
        // if a=b=2^n, a*b=1 (mod Fn)
        if (_digits[0] == 1 && value._digits[0] == 1)
        {
            Array.Fill(_digits, 0);
            _digits[_digits.Length - 1] = 1;
        }
        // otherwise, a*b will fit into 2*2^n bits
        else
        {
            var intDigits = ToUIntArrayOdd(_digits);
            var aBigInt = (intDigits, true);
            var intBDigits = ToUIntArrayOdd(value._digits);
            var bBigInt = (intBDigits, true);
            var cInt = BigIntegerOperation.Multiply(aBigInt, bBigInt).Item1;
            cInt = cInt.SkipWhile(num => num == 0u).ToArray();
            // zero-pad c to make it 2*2^n in length, and convert it to long[]
            var cIntPad = new uint[intDigits.Length - 1 + intBDigits.Length - 1];
            Array.Copy(cInt, 0, cIntPad, cIntPad.Length - cInt.Length, cInt.Length);
            var c = ToLongArrayEven(cIntPad);
            // reduce c mod Fn which makes the first c.length/2-1 longs zero; return the others
            ReduceWide(c);
            Array.Copy(c, c.Length / 2 - 1, _digits, 0, c.Length / 2 + 1);
        }
    }

    ///<summary>
    ///Squares this number.
    ///</summary>
    public void Square()
    {
        // if a=2^n, a^2=1 (mod Fn)
        if (_digits[0] == 1)
        {
            Array.Fill(_digits, 0);
            _digits[_digits.Length - 1] = 1;
        }
        // otherwise, a^2 will fit into 2*2^n bits
        else
        {
            var intDigits = ToUIntArrayOdd(_digits);
            var cInt = BigIntegerOperation.Square((intDigits, true)).Item1;
            cInt = cInt.SkipWhile(num => num == 0u).ToArray();
            // zero-pad cInt to make it 2*2^n bits in length, and convert it to long[]
            var cIntPad = new uint[2 * intDigits.Length - 2];
            Array.Copy(cInt, 0, cIntPad, cIntPad.Length - cInt.Length, cInt.Length);
            var c = ToLongArrayEven(cIntPad);
            // reduce c mod Fn which makes the first c.length/2-1 longs zero; return the others
            ReduceWide(c);
            Array.Copy(c, c.Length / 2 - 1, _digits, 0, c.Length / 2 + 1);
        }
    }

    ///<summary>
    ///Reduces this number modulo Fn. <see cref="_digits"/>[0] will be 0 or 1.
    ///</summary>
    private void Reduce()
    {
        // Reduction modulo Fn is done by subtracting the most significant long from the least significant long
        var len = _digits.Length;
        var bi = _digits[0];
        var lowest = _digits[len - 1];
        var diff = lowest - bi;
        // unsigned comparison; see Borrows for why a sign-bit difference is not used
        var borrow = Borrows(lowest, bi, false);
        _digits[len - 1] = diff;
        _digits[0] = 0;   // because we subtracted digits[0] from digits[len-1]
        if (borrow)
        {
            var i = len - 2;
            do
            {
                diff = _digits[i] - 1;
                _digits[i] = diff;
                borrow = diff == -1;
                i--;
            } while (borrow && i >= 0);
        }

        // if we borrowed from the most significant long, subtract 2^2^n which is the same as adding 1 (mod Fn)
        if (borrow)
        {
            var i = _digits.Length - 1;
            var carry = true;
            _digits[0] = 0;   // increment digits[0] by 1 to make it 0
            while (carry && i >= 0)
            {
                var sum = _digits[i] + 1;
                _digits[i] = sum;
                carry = sum == 0;
                i--;
            }
        }
    }

    ///<summary>
    ///Like <see cref="Reduce()"/> but works on an array of length 2^(n+1).
    ///</summary>
    private static void ReduceWide(long[] a)
    {
        // Reduction modulo Fn is done by subtracting the upper half from the lower half
        var len = a.Length;
        var carry = false;
        for (var i = len - 1; i >= len / 2; i--)
        {
            var ai = a[i];
            var bi = a[i - len / 2];
            var diff = ai - bi;
            if (carry)
                diff--;
            // unsigned comparison; see Borrows for why a sign-bit difference is not used
            carry = Borrows(ai, bi, carry);
            a[i] = diff;
        }
        for (var i = len / 2 - 1; i >= 0; i--)
            a[i] = 0;
        // if result is negative, add Fn; since Fn is congruent to 1 (mod 2^n), it suffices to add 1
        if (carry)
        {
            var j = len - 1;
            do
            {
                var sum = a[j] + 1;
                a[j] = sum;
                carry = sum == 0;
                j--;
                if (j <= 0)
                    break;
            } while (carry);
        }
    }

    ///<summary>
    ///Like <see cref="Reduce()"/> but works on a <see cref="uint"/> array.
    ///</summary>
    public static void Reduce(uint[] digits)
    {
        // Reduction modulo Fn is done by subtracting the most significant int from the least significant int
        var len = digits.Length;
        var bi = digits[0];
        var lowest = digits[len - 1];
        var diff = lowest - bi;
        // unsigned comparison: (lowest >> 31) - (bi >> 31) would wrap to uint.MaxValue and lose the borrow
        var borrow = lowest < bi;
        digits[len - 1] = diff;
        digits[0] = 0;   // because we subtracted digits[0] from digits[len-1]
        if (borrow)
        {
            var i = len - 2;
            do
            {
                diff = digits[i] - 1;
                digits[i] = diff;
                borrow = diff == uint.MaxValue;
                i--;
            } while (borrow && i >= 0);
        }

        // if we borrowed from the most significant int, subtract 2^2^n which is the same as adding 1 (mod Fn)
        if (borrow)
        {
            var i = digits.Length - 1;
            var carry = true;
            digits[0] = 0;   // increment digits[0] by 1 to make it 0
            while (carry && i >= 0)
            {
                var sum = digits[i] + 1;
                digits[i] = sum;
                carry = sum == 0;
                i--;
            }
        }
    }

    ///<summary>
    ///<para>Multiplies this number by 2^-shiftAmtBits modulo 2^2^n + 1 where 2^n = <c>(digits.length-1)*64</c>.</para>
    ///<para>"Right" means towards the higher array indices and the lower bits.</para>
    ///<para>This is equivalent to extending the number to <c>2*(digits.length-1)</c> longs and cyclicly
    ///shifting to the right by <c>shiftAmt</c> bits.</para>
    ///The result is placed in the second argument.
    ///</summary>
    public void ShiftRight(int shiftAmtBits, MutableModFn b)
    {
        var len = _digits.Length;
        if (shiftAmtBits > 64 * (len - 1))
        {
            // a right shift by more than half the cycle is a left shift by the complement
            ShiftLeft(64 * 2 * (len - 1) - shiftAmtBits, b);
            return;
        }

        var shiftAmtLongs = shiftAmtBits / 64;   // number of longs to shift
        if (shiftAmtLongs > 0)
        {
            var borrow = false;
            var diff = 0L;
            // shift the digits that stay positive, except a[len-1] which is special
            for (int i = 1; i < len - shiftAmtLongs; i++)
            {
                diff = _digits[i];
                if (borrow)
                    diff--;
                b._digits[shiftAmtLongs + i] = diff;
                borrow = diff == -1 && borrow;
            }

            // subtract a[len-1] from a[0]
            diff = _digits[0] - _digits[len - 1];
            if (borrow)
            {
                diff--;
                borrow = diff == -1;
            }
            else
                borrow = _digits[0] == 0 && _digits[len - 1] != 0;   // a[0] can only be 0 or 1; if digits[0]!=0, digits[len-1]==0
            b._digits[shiftAmtLongs] = diff;

            // using the fact that adding x*(Fn-1) is the same as subtracting x,
            // subtract digits shifted off the right, except for a[0] which is special
            for (var i = 1; i < shiftAmtLongs; i++)
            {
                b._digits[shiftAmtLongs - i] = -_digits[len - 1 - i];
                if (borrow)
                    b._digits[shiftAmtLongs - i]--;
                borrow = b._digits[shiftAmtLongs - i] != 0 || borrow;
            }

            // if we borrowed from the most significant long, add 1 to the overall number
            var carry = borrow;
            if (carry)
            {
                // increment b[0] and decrement b[len-1]
                b._digits[0] = 0;
                var i = len - 1;
                do
                {
                    var sum = b._digits[i] + 1;
                    b._digits[i] = sum;
                    carry = sum == 0;
                    i--;
                } while (carry && i >= 0);
            }
            else
                b._digits[0] = 0;
        }
        else
            Array.Copy(_digits, 0, b._digits, 0, len);

        var shiftAmtFrac = shiftAmtBits % 64;
        if (shiftAmtFrac != 0)
        {
            long bhi = b._digits[len - 1] << (64 - shiftAmtFrac);

            // do remaining digits
            b._digits[len - 1] = (long)((ulong)b._digits[len - 1] >> shiftAmtFrac);
            for (var i = len - 1; i > 0; i--)
            {
                b._digits[i] |= b._digits[i - 1] << (64 - shiftAmtFrac);
                b._digits[i - 1] = (long)((ulong)b._digits[i - 1] >> shiftAmtFrac);
            }

            // b[len-1] spills over into b[1]
            var diff = b._digits[1] - bhi;
            // unsigned comparison; see Borrows for why a sign-bit difference is not used
            var borrow = Borrows(b._digits[1], bhi, false);
            b._digits[1] = diff;

            // if we borrowed from b[0], add 1 to the overall number
            var carry = borrow;
            if (carry)
            {
                // increment b[0] and decrement b[len-1]
                b._digits[0] = 0;
                var i = len - 1;
                do
                {
                    var sum = b._digits[i] + 1;
                    b._digits[i] = sum;
                    carry = sum == 0;
                    i--;
                } while (carry && i >= 0);
            }
            else
                b._digits[0] = 0;
        }
    }

    ///<summary>
    ///<para>Multiplies this number by 2^shiftAmtBits modulo 2^2^n + 1 where 2^n = <c>(digits.length-1)*64</c>.</para>
    ///<para>"Left" means towards the lower array indices and the higher bits.</para>
    ///<para>This is equivalent to extending the number to <c>2*(digits.length-1)</c> longs and cyclicly
    ///shifting to the left by <c>shiftAmt</c> bits.</para>
    ///The result is placed in the second argument.
    ///</summary>
    public void ShiftLeft(int shiftAmtBits, MutableModFn b)
    {
        var len = _digits.Length;

        if (shiftAmtBits > 64 * (len - 1))
        {
            // a left shift by more than half the cycle is a right shift by the complement
            ShiftRight(64 * 2 * (len - 1) - shiftAmtBits, b);
            return;
        }

        var shiftAmtLongs = shiftAmtBits / 64;   // number of longs to shift
        if (shiftAmtLongs > 0)
        {
            var borrow = false;
            // using the fact that adding x*(Fn-1) is the same as subtracting x,
            // subtract digits shifted outside the [0..Fn-2] range, except for digits[0] which is special
            for (var i = 0; i < shiftAmtLongs; i++)
            {
                b._digits[len - 1 - i] = -_digits[shiftAmtLongs - i];
                if (borrow)
                    b._digits[len - 1 - i]--;
                borrow = b._digits[len - 1 - i] != 0 || borrow;
            }

            // subtract digits[0] from digits[len-1] (they overlap unless numElements=len-1)
            long diff;
            if (shiftAmtLongs < len - 1)
                diff = _digits[len - 1] - _digits[0];
            else   // no overlap
                diff = -_digits[0];
            if (borrow)
            {
                diff--;
                borrow = diff == -1;
            }
            else
                borrow = _digits[0] == 1 && diff == -1;   // digits[0] can only be 0 or 1
            b._digits[len - 1 - shiftAmtLongs] = diff;

            // finally, shift the digits that stay in the [0..Fn-2] range
            for (var i = 1; i < len - shiftAmtLongs - 1; i++)
            {
                diff = _digits[len - 1 - i];
                if (borrow)
                    diff--;
                b._digits[len - 1 - shiftAmtLongs - i] = diff;
                borrow = diff == -1 && borrow;
            }

            // if we borrowed from the most significant long, add 1 to the overall number
            var carry = borrow;
            if (carry)
            {
                // increment b[0] and decrement b[len-1]
                b._digits[0] = 0;
                int i = len - 1;
                do
                {
                    var sum = b._digits[i] + 1;
                    b._digits[i] = sum;
                    carry = sum == 0;
                    i--;
                } while (carry && i >= 0);
            }
            else
                b._digits[0] = 0;
        }
        else
            Array.Copy(_digits, 0, b._digits, 0, len);

        var shiftAmtFrac = shiftAmtBits % 64;
        if (shiftAmtFrac != 0)
        {
            b._digits[0] <<= shiftAmtFrac;   // no spill-over because 0<=digits[0]<=1 and shiftAmtFrac<=63
            for (var i = 1; i < len; i++)
            {
                b._digits[i - 1] |= (long)((ulong)b._digits[i] >> (64 - shiftAmtFrac));
                b._digits[i] <<= shiftAmtFrac;
            }
        }

        b.Reduce();
    }

    ///<summary>
    ///Packs pairs of 32-bit words into 64-bit words, high word first.
    ///<paramref name="digits"/>.Length must be an even number.
    ///</summary>
    private static long[] ToLongArrayEven(uint[] digits)
    {
        var longDigits = new long[digits.Length / 2];
        for (var i = 0; i < longDigits.Length; i++)
            longDigits[i] = (((long)digits[2 * i]) << 32) | (digits[2 * i + 1] & 0xFFFFFFFFL);
        return longDigits;
    }

    ///<summary>
    ///Splits 64-bit words into 32-bit words, high half first. Only the low 32 bits of
    ///<paramref name="digits"/>[0] are kept, so the result has <c>2 * digits.Length - 1</c> elements.
    ///</summary>
    public static uint[] ToUIntArrayOdd(long[] digits)
    {
        var intDigits = new uint[digits.Length * 2 - 1];
        intDigits[0] = (uint)digits[0];
        for (var i = 1; i < digits.Length; i++)
        {
            intDigits[2 * i - 1] = (uint)((ulong)digits[i] >> 32);
            intDigits[2 * i] = (uint)(digits[i] & -1);
        }

        return intDigits;
    }

    ///<summary>
    ///Returns this number's words as 32-bit words; see <see cref="ToUIntArrayOdd(long[])"/>.
    ///</summary>
    public uint[] ToUIntArrayOdd()
    {
        return ToUIntArrayOdd(_digits);
    }
}

Fast Fourier Transform

/// <summary>
/// <para>Runs a modified forward Fermat Number Transform over <paramref name="value"/> in place.</para>
/// <para>The first transform step is left out because only the upper half of the result is needed;
/// <paramref name="value"/> is taken to be the lower half of the full array, with the upper half all zeros.</para>
/// The elements are treated as a matrix of roughly sqrt(length) by sqrt(length): columns are transformed,
/// twiddle factors are applied, then rows are transformed.
/// </summary>
private static void DFT(MutableModFn[] value, int omega)
{
    // matrix shape: the row count is 2^floor(log2(length) / 2)
    var log2Length = 31 - value.Length.NumberOfLeadingZeros();
    var rowCount = 1 << (log2Length >> 1);
    var columnCount = value.Length / rowCount;

    // step 1: transform every column, i.e. the elements
    // A[col], A[col+columnCount], ..., A[col+(rowCount-1)*columnCount]
    var column = 0;
    while (column < columnCount)
    {
        DFTDirect(value, omega, rowCount, rowCount, columnCount, column, columnCount);
        column++;
    }

    // step 2: multiply by powers of omega
    ApplyDFTWeights(value, omega, rowCount, columnCount);

    // step 3 is folded into step 1 through the stride being a multiple of the row length

    // step 4: transform every row, i.e. the elements
    // A[row*columnCount], A[row*columnCount+1], ..., A[row*columnCount+columnCount-1]
    var row = 0;
    while (row < rowCount)
    {
        DFTDirect(value, omega, columnCount, 0, rowCount, row * columnCount, 1);
        row++;
    }
}

/// <summary>
/// Performs a DFT on the strided sub-vector of <paramref name="value"/> that starts at
/// <paramref name="idxOffset"/> and has <paramref name="len"/> elements spaced <paramref name="stride"/> apart.
/// Two butterfly levels are combined per pass (radix-4); one remaining level is handled at the end.
/// </summary>
private static void DFTDirect(MutableModFn[] value, int omega, int len, int expOffset, int expScale, int idxOffset, int stride)
{
    // use 2*len because this is a half DFT and n must correspond to the full DFT length
    var n = 31 - (2 * len).NumberOfLeadingZeros();
    var level = 1;   // starts at 1 rather than 0 for the same reason
    MutableModFn scratch = new MutableModFn(value[0].Length);

    // One butterfly: scratch = value[lower] * 2^shift, then
    // value[lower] = value[upper] - scratch and value[upper] = value[upper] + scratch.
    void Butterfly(int upper, int lower, int shift)
    {
        value[lower].ShiftLeft(shift, scratch);
        value[upper].CopyTo(value[lower]);
        value[upper].Add(scratch);
        value[lower].Subtract(scratch);
    }

    // runLength = number of consecutive coefficients sharing the same sign (add/sub) and exponent
    var runLength = len >> 1;
    while (runLength > 1)
    {
        var blockSize = runLength << 1;
        for (var blockStart = 0; blockStart < len; blockStart += blockSize)
        {
            var shiftNextA = GetDFTExponent(n, level + 1, blockStart + expOffset, omega) * expScale;               // for level+2
            var shiftThis = GetDFTExponent(n, level, blockStart + expOffset, omega) * expScale;                    // for level+1
            var shiftNextB = GetDFTExponent(n, level + 1, blockStart + runLength + expOffset, omega) * expScale;   // for level+2

            // the four quarter-blocks are stride*runLength/2 elements apart
            var q0 = stride * blockStart + idxOffset;
            var q1 = stride * blockStart + stride * runLength / 2 + idxOffset;
            var q2 = q0 + stride * runLength;
            var q3 = q1 + stride * runLength;

            for (var k = runLength - 1; k >= 0; k -= 2)
            {
                // first of the two combined levels
                Butterfly(q0, q2, shiftThis);
                Butterfly(q1, q3, shiftThis);

                // second of the two combined levels
                Butterfly(q0, q1, shiftNextA);
                Butterfly(q2, q3, shiftNextB);

                q0 += stride;
                q1 += stride;
                q2 += stride;
                q3 += stride;
            }
        }

        level += 2;
        runLength >>= 2;
    }

    // with an odd number of levels, one single level is still left to do
    if (runLength > 0)
    {
        for (var blockStart = 0; blockStart < len; blockStart += 2 * runLength)
        {
            var shift = GetDFTExponent(n, level, blockStart + expOffset, omega) * expScale;
            var upper = stride * blockStart + idxOffset;
            var lower = upper + stride * runLength;   // partner is stride*runLength elements away

            for (var k = runLength - 1; k >= 0; k--)
            {
                Butterfly(upper, lower, shift);
                upper += stride;
                lower += stride;
            }
        }
    }
}

/// <summary>
/// <para>Returns the power to which to raise omega in a DFT.</para>
/// When <paramref name="omega"/> is 4 the exponent is doubled, so callers such as
/// <see cref="DFTDirect"/> and <see cref="IDFTDirect"/> can always treat omega as 2.
/// </summary>
private static int GetDFTExponent(int n, int v, int idx, int omega)
{
    // s = the v (out of n) high bits of idx, in reverse order
    var highBits = (uint)idx >> (n - v);
    var reversed = (int)(ReverseBits(highBits) >> (32 - v));

    // exponent = 2^(n-1-v) * s
    var exponent = reversed << (n - v - 1);

    // omega = 4 means twice the shift amount
    return omega == 4 ? exponent << 1 : exponent;
}

/// <summary>
/// Multiplies every element of the <paramref name="rows"/> by <paramref name="cols"/> matrix by its
/// power of omega (twiddle factor) using left shifts. Used by Bailey's algorithm.
/// </summary>
private static void ApplyDFTWeights(MutableModFn[] value, int omega, int rows, int cols)
{
    var v = 31 - rows.NumberOfLeadingZeros() + 1;
    // omega = 4 doubles every shift amount
    var omegaFactor = omega == 4 ? 2 : 1;

    for (var row = 0; row < rows; row++)
    {
        var rowStart = row * cols;
        for (var col = 0; col < cols; col++)
        {
            var element = value[rowStart + col];
            var shifted = new MutableModFn(element.Length);
            var shiftAmt = GetBaileyShiftAmount(row, col, rows, v) * omegaFactor;

            // shift into a fresh number, then copy the result back over the element
            element.ShiftLeft(shiftAmt, shifted);
            shifted.CopyTo(element);
        }
    }
}

/// <summary>
/// Returns the twiddle shift for matrix position (<paramref name="i"/>, <paramref name="j"/>):
/// the bit-reversed value of <c>i + rows</c>, reduced to its top <paramref name="v"/> bits,
/// multiplied by <paramref name="j"/>.
/// </summary>
private static int GetBaileyShiftAmount(int i, int j, int rows, int v)
{
    var reversedRow = (int)(ReverseBits((uint)(i + rows)) >> (32 - v));
    var shiftAmount = reversedRow * j;
    return shiftAmount;
}

/// <summary>
/// <para>Runs a modified inverse Fermat Number Transform over <paramref name="value"/> in place.</para>
/// <para>The last step (subtracting the upper half from the lower half) is left out;
/// <paramref name="value"/> is taken to be the upper half of the full array, with the lower half all zeros.</para>
/// The elements are treated as a matrix of roughly sqrt(length) by sqrt(length): rows are transformed,
/// twiddle factors are removed, then columns are transformed.
/// </summary>
private static void IDFT(MutableModFn[] value, int omega)
{
    // matrix shape: the row count is 2^floor(log2(length) / 2)
    var log2Length = 31 - value.Length.NumberOfLeadingZeros();
    var rowCount = 1 << (log2Length >> 1);
    var columnCount = value.Length / rowCount;

    // step 1: inverse-transform every row, i.e. the elements
    // A[row*columnCount], A[row*columnCount+1], ..., A[row*columnCount+columnCount-1]
    var row = 0;
    while (row < rowCount)
    {
        IDFTDirect(value, omega, columnCount, 0, rowCount, row * columnCount, 1);
        row++;
    }

    // step 2: divide by powers of omega
    ApplyIDFTWeights(value, omega, rowCount, columnCount);

    // step 3 is folded into step 4 through the stride being a multiple of the row length

    // step 4: inverse-transform every column, i.e. the elements
    // A[col], A[col+columnCount], ..., A[col+(rowCount-1)*columnCount]
    var column = 0;
    while (column < columnCount)
    {
        IDFTDirect(value, omega, rowCount, rowCount, columnCount, column, columnCount);
        column++;
    }
}

/// <summary>
/// Performs an inverse DFT on the strided sub-vector of <paramref name="value"/> that starts at
/// <paramref name="idxOffset"/> and has <paramref name="len"/> elements spaced <paramref name="stride"/> apart.
/// Two butterfly levels are combined per pass (radix-4); one remaining level is handled at the end.
/// </summary>
private static void IDFTDirect(MutableModFn[] value, int omega, int len, int expOffset, int expScale, int idxOffset, int stride)
{
    // use 2*len because this is a half DFT and n must correspond to the full DFT length
    var n = 31 - (len << 1).NumberOfLeadingZeros();
    var level = 31 - len.NumberOfLeadingZeros();
    var saved = new MutableModFn(value[0].Length);

    // One inverse butterfly: value[upper] = (value[upper] + value[lower]) / 2 and
    // value[lower] = (old value[upper] - value[lower]) shifted right by `shift` bits.
    void InverseButterfly(int upper, int lower, int shift)
    {
        value[upper].CopyTo(saved);
        value[upper].Add(value[lower]);
        value[upper].ShiftRight(1, value[upper]);
        saved.Subtract(value[lower]);
        saved.ShiftRight(shift, value[lower]);
    }

    // runLength = number of consecutive coefficients sharing the same sign (add/sub) and exponent
    var runLength = 1;
    while (runLength <= len / 4)
    {
        for (var blockStart = 0; blockStart < len; blockStart += 4 * runLength)
        {
            var shiftThisA = GetDFTExponent(n, level, blockStart + expOffset, omega) * expScale + 1;                   // for this level
            var shiftNext = GetDFTExponent(n, level - 1, blockStart + expOffset, omega) * expScale + 1;                // for level-1
            var shiftThisB = GetDFTExponent(n, level, blockStart + runLength * 2 + expOffset, omega) * expScale + 1;   // for this level

            // the four quarter-blocks are stride*runLength elements apart
            var q0 = stride * blockStart + idxOffset;
            var q1 = stride * blockStart + stride * runLength + idxOffset;
            var q2 = q0 + stride * runLength * 2;
            var q3 = q1 + stride * runLength * 2;

            for (var k = runLength - 1; k >= 0; k--)
            {
                // current level
                InverseButterfly(q0, q1, shiftThisA);
                InverseButterfly(q2, q3, shiftThisB);

                // level below
                InverseButterfly(q0, q2, shiftNext);
                InverseButterfly(q1, q3, shiftNext);

                q0 += stride;
                q1 += stride;
                q2 += stride;
                q3 += stride;
            }
        }

        level -= 2;
        runLength *= 4;
    }

    // with an odd number of levels, one single level is still left to do
    if (runLength <= len / 2)
    {
        for (var blockStart = 0; blockStart < len; blockStart += 2 * runLength)
        {
            var shift = GetDFTExponent(n, level, blockStart + expOffset, omega) * expScale + 1;
            var upper = stride * blockStart + idxOffset;
            var lower = upper + stride * runLength;   // partner is stride*runLength elements away

            for (var k = runLength - 1; k >= 0; k--)
            {
                InverseButterfly(upper, lower, shift);
                upper += stride;
                lower += stride;
            }
        }
    }
}

/// <summary>
/// Divides every element of the <paramref name="rows"/> by <paramref name="cols"/> matrix by its
/// power of omega (twiddle factor) using right shifts.
/// </summary>
private static void ApplyIDFTWeights(MutableModFn[] value, int omega, int rows, int cols)
{
    var v = 31 - rows.NumberOfLeadingZeros() + 1;
    // omega = 4 doubles every shift amount
    var omegaFactor = omega == 4 ? 2 : 1;

    for (var row = 0; row < rows; row++)
    {
        var rowStart = row * cols;
        for (var col = 0; col < cols; col++)
        {
            var element = value[rowStart + col];
            var shifted = new MutableModFn(element.Length);
            var shiftAmt = GetBaileyShiftAmount(row, col, rows, v) * omegaFactor;

            // shift into a fresh number, then copy the result back over the element
            element.ShiftRight(shiftAmt, shifted);
            shifted.CopyTo(element);
        }
    }
}

/// <summary>
/// Converts each <see cref="MutableModFn"/> in <paramref name="value"/> to its
/// <see cref="uint"/> array form via <see cref="MutableModFn.ToUIntArrayOdd()"/>.
/// </summary>
/// <returns>A jagged array with one <see cref="uint"/> array per input element, in the same order.</returns>
private static uint[][] ToIntArray(MutableModFn[] value)
{
    var converted = new uint[value.Length][];
    var position = 0;
    foreach (var element in value)
    {
        converted[position] = element.ToUIntArrayOdd();
        position++;
    }

    return converted;
}

/// <summary>
/// Multiplies each element of <paramref name="left"/> by the element of <paramref name="right"/>
/// at the same index, using <see cref="MutableModFn.Multiply(MutableModFn)"/>.
/// The products are stored in <paramref name="left"/>.
/// </summary>
private static void MultiplyElements(MutableModFn[] left, MutableModFn[] right)
{
    var position = 0;
    foreach (var factor in left)
    {
        factor.Multiply(right[position]);
        position++;
    }
}

/// <summary>
/// Squares every element of <paramref name="value"/> in place
/// using <see cref="MutableModFn.Square"/>.
/// </summary>
private static void SquareElements(MutableModFn[] value)
{
    foreach (var element in value)
    {
        element.Square();
    }
}

/// <summary>
/// <para>Adds <paramref name="right"/> to <paramref name="left"/> after shifting <paramref name="right"/>
/// left by <paramref name="numElements"/> array elements.</para>
/// <para>Both numbers are unsigned and stored most significant word first.</para>
/// The sum is written into <paramref name="left"/>. Words of <paramref name="right"/> that land
/// outside the range of <paramref name="left"/>, and any carry out of <paramref name="left"/>[0], are dropped.
/// </summary>
private static void AddShifted(uint[] left, uint[] right, int numElements)
{
    var leftPos = left.Length - 1 - numElements;
    var rightPos = right.Length - 1;
    var carryIn = 0UL;

    // add word by word over the positions where both arrays have a word
    for (var remaining = Math.Min(leftPos, rightPos); remaining >= 0; remaining--)
    {
        // a 64-bit total keeps the carry in bit 32
        var total = (ulong)left[leftPos] + right[rightPos] + carryIn;
        left[leftPos] = unchecked((uint)total);
        carryIn = total >> 32;
        leftPos--;
        rightPos--;
    }

    // ripple a leftover carry through the higher words of left
    var carrying = carryIn != 0;
    while (carrying && leftPos >= 0)
    {
        left[leftPos]++;
        carrying = left[leftPos] == 0;
        leftPos--;
    }
}

/// <summary>
/// <para>Computes (<paramref name="left"/> + <paramref name="right"/>) mod 2^<paramref name="numBits"/>
/// for unsigned numbers stored as <see cref="uint"/> arrays with the least significant element last.</para>
/// <para>The result replaces the contents of <paramref name="left"/>; every element of
/// <paramref name="left"/> above the low <paramref name="numBits"/> bits is cleared.</para>
/// </summary>
/// <param name="left">First operand and destination; needs at least ceil(numBits / 32) elements.</param>
/// <param name="right">Second operand; needs at least ceil(numBits / 32) elements.</param>
/// <param name="numBits">Width of the modulus in bits.</param>
private static void AddModPow2(uint[] left, uint[] right, int numBits)
{
    var words = (numBits + 31) >> 5;
    var dst = left.Length - 1;
    var src = right.Length - 1;
    var carry = 0UL;

    // Add the low `words` elements; the carry is whatever spills past bit 31 of the 64-bit total.
    for (var k = 0; k < words; k++)
    {
        var total = (ulong)left[dst] + right[src] + carry;
        left[dst] = unchecked((uint)total);
        carry = total >> 32;
        dst--;
        src--;
    }

    // Trim the most significant processed element to the modulus width.
    // When numBits is a multiple of 32 the shift count is 32, which C# reduces to 0, so the mask keeps all bits.
    if (words > 0)
    {
        left[dst + 1] &= uint.MaxValue >> (32 - (numBits % 32));
    }

    // Everything above the modulus is zeroed.
    while (dst >= 0)
    {
        left[dst] = 0;
        dst--;
    }
}

/// <summary>
/// <para>Computes (<paramref name="left"/> - <paramref name="right"/>) mod 2^<paramref name="numBits"/>
/// for unsigned numbers stored as <see cref="uint"/> arrays with the least significant element last.</para>
/// <para>The result replaces the contents of <paramref name="left"/>; every element of
/// <paramref name="left"/> above the low <paramref name="numBits"/> bits is cleared.</para>
/// </summary>
/// <param name="left">Minuend and destination; needs at least ceil(numBits / 32) elements.</param>
/// <param name="right">Subtrahend; needs at least ceil(numBits / 32) elements.</param>
/// <param name="numBits">Width of the modulus in bits.</param>
private static void SubModPow2(uint[] left, uint[] right, int numBits)
{
    var numElements = (numBits + 31) >> 5;
    var borrow = false;
    var aIdx = left.Length - 1;
    var bIdx = right.Length - 1;
    for (var i = numElements - 1; i >= 0; i--)
    {
        // Subtract in signed 64-bit arithmetic so a negative result directly signals a borrow.
        // Comparing sign bits with unsigned operands is wrong: signBit(a) - signBit(b) wraps to
        // 0xFFFFFFFF instead of -1 and the borrow is lost whenever a < 2^31 <= b.
        var diff = (long)left[aIdx] - right[bIdx] - (borrow ? 1 : 0);
        borrow = diff < 0;
        left[aIdx] = unchecked((uint)diff);
        aIdx--;
        bIdx--;
    }

    // Trim the most significant processed element when the modulus is not element-aligned.
    var partialBits = numBits % 32;
    if (numElements > 0 && partialBits != 0)
        left[aIdx + 1] &= uint.MaxValue >> (32 - partialBits);

    // Everything above the modulus is zeroed.
    for (; aIdx >= 0; aIdx--)
        left[aIdx] = 0;
}

Unsigned Schönhage–Strassen Multiplication

// Operand length, in uint elements, below which MultiplySchönhageStrassenNonegative hands the work to MultiplyNonegative.
// NOTE(review): the name refers to Karatsuba squaring, but in this section the constant only gates the
// Schönhage–Strassen path — confirm against its other uses before renaming.
private static readonly int KARATSUBA_SQUARE_THRESHOLD = 128;

/// <summary>
/// <para>Schönhage–Strassen multiplication of two non-negative big numbers.</para>
/// <para>Operands and result are <see cref="uint"/> arrays in which the first element holds the most
/// significant 32 bits and the last element holds the least significant 32 bits.</para>
/// <para>If either operand has fewer than <see cref="KARATSUBA_SQUARE_THRESHOLD"/> elements the call is
/// forwarded to <c>MultiplyNonegative</c>.</para>
/// </summary>
/// <param name="left">First factor.</param>
/// <param name="right">Second factor. When it is element-wise equal to <paramref name="left"/> the squaring path is taken.</param>
/// <returns>
/// On the Schönhage–Strassen path, an array of 2^(m-5) + 1 elements (m as computed below) after
/// <c>MutableModFn.Reduce</c> has been applied to it; otherwise whatever <c>MultiplyNonegative</c> returns.
/// </returns>
public static uint[] MultiplySchönhageStrassenNonegative(uint[] left, uint[] right)
{
    // Small operands: delegate to the fallback multiplication.
    if (left.Length < KARATSUBA_SQUARE_THRESHOLD ||
        right.Length < KARATSUBA_SQUARE_THRESHOLD)
        return MultiplyNonegative(left, right);

    // Equal operands allow one transform plus squaring instead of two transforms plus multiplication.
    var square = left.SequenceEqual(right);

    // set M to the number of binary digits in left or right, whichever is greater
    var M = Math.Max(left.Length << 5, right.Length << 5);

    // find the lowest m such that m>=log2(2M)
    // NOTE(review): NumberOfLeadingZeros is an extension method defined outside this section;
    // presumably it counts the leading zero bits of a 32-bit int — confirm.
    var m = 32 - (2 * M - 1 - 1).NumberOfLeadingZeros();

    var n = (m >> 1) + 1;

    // split left and right into pieces 1<<(n-1) bits long; assume n>=6 so pieces start and end at int boundaries
    var even = (m & 1) == 0;
    var numPieces = even ? 1 << n : 1 << (n + 1);
    var pieceSize = 1 << (n - 1 - 5);   // in ints

    // zi mod 2^(n+2): build u and v from left and right, allocating 3n+5 bits in u and v per n+2 bits from left and right, resp.
    var numPiecesA = (left.Length + pieceSize) / pieceSize;
    var u = new uint[(numPiecesA * (3 * n + 5) + 31) / 32];
    var uBitLength = 0;
    // The second loop condition stops before AppendBits would read past the start of left.
    for (var i = 0; i < numPiecesA && i * pieceSize < left.Length; i++)
    {
        AppendBits(u, uBitLength, left, i * pieceSize, n + 2);
        uBitLength += 3 * n + 5;
    }
    uint[] gamma;
    if (square)
        gamma = SquareToomCook3((u, true)).Item1;   // gamma = u * u
    else
    {
        var numPiecesB = (right.Length + pieceSize) / pieceSize;
        var v = new uint[(numPiecesB * (3 * n + 5) + 31) / 32];
        var vBitLength = 0;
        for (var i = 0; i < numPiecesB && i * pieceSize < right.Length; i++)
        {
            AppendBits(v, vBitLength, right, i * pieceSize, n + 2);
            vBitLength += 3 * n + 5;
        }
        // Recursive call; it takes the MultiplyNonegative fallback when u or v is below the threshold.
        gamma = MultiplySchönhageStrassenNonegative(u, v);   // gamma = u * v
    }
    var gammai = SplitBits(gamma, 3 * n + 5);
    var halfNumPcs = numPieces / 2;

    // zi shares its element arrays with gammai, so the in-place updates below also change gammai[i].
    var zi = new uint[gammai.Length][];
    for (var i = 0; i < gammai.Length; i++)
        zi[i] = gammai[i];
    // Fold the pieces lying halfNumPcs, 2*halfNumPcs and 3*halfNumPcs further along into zi[i]
    // with alternating signs (-, +, -), each step reduced mod 2^(n+2).
    for (var i = 0; i < gammai.Length - halfNumPcs; i++)
        SubModPow2(zi[i], gammai[i + halfNumPcs], n + 2);
    for (var i = 0; i < gammai.Length - 2 * halfNumPcs; i++)
        AddModPow2(zi[i], gammai[i + 2 * halfNumPcs], n + 2);
    for (var i = 0; i < gammai.Length - 3 * halfNumPcs; i++)
        SubModPow2(zi[i], gammai[i + 3 * halfNumPcs], n + 2);

    // zr mod Fn
    var ai = Split(left, halfNumPcs, pieceSize, (1 << (n - 6)) + 1);   // assume n>=6
    MutableModFn[] bi = null;
    if (!square)
        bi = Split(right, halfNumPcs, pieceSize, (1 << (n - 6)) + 1);
    var omega = even ? 4 : 2;
    if (square)
    {
        DFT(ai, omega);
        SquareElements(ai);
    }
    else
    {
        DFT(ai, omega);
        DFT(bi, omega);
        MultiplyElements(ai, bi);
    }
    // The element-wise products live in ai; c is just another name for the same array.
    var c = ai;
    IDFT(c, omega);
    var cInt = ToIntArray(c);

    var z = new uint[(1 << (m - 5)) + 1];
    // calculate zr mod Fm from zr mod Fn and zr mod 2^(n+2), then add to z
    // note (carried over from the Java original): a plain array is used for z because the
    // MutableBigInteger-based addShifted was observed to be much slower there
    for (var i = 0; i < halfNumPcs; i++)
    {
        // Pieces beyond the end of zi are treated as zero.
        var eta = i >= zi.Length ? new uint[(n + 2 + 31) / 32] : zi[i];

        // zi = delta = (zi-c[i]) % 2^(n+2)
        SubModPow2(eta, cInt[i], n + 2);

        // z += zr<<shift = [ci + delta*(2^2^n+1)] << [i*2^(n-1)]
        var shift = i * (1 << (n - 1 - 5));   // assume n>=6
        AddShifted(z, cInt[i], shift);
        AddShifted(z, eta, shift);
        AddShifted(z, eta, shift + (1 << (n - 5)));
    }

    // NOTE(review): Reduce is defined outside this section; presumably it reduces z modulo 2^(2^m) + 1 in place — confirm.
    MutableModFn.Reduce(z);   // assume m>=5
    return z;
}

/// <summary>
/// <para>Copies the low <paramref name="rightBitLength"/> bits of the number that ends
/// <paramref name="rightStart"/> elements before the end of <paramref name="right"/> into
/// <paramref name="left"/>, placing them at bit position <paramref name="leftBitLength"/>
/// (counted from the least significant end of <paramref name="left"/>).</para>
/// <para>Both arrays store the least significant element last. <paramref name="left"/> is modified in place.</para>
/// </summary>
/// <param name="left">Destination array.</param>
/// <param name="leftBitLength">Number of bits already occupied in <paramref name="left"/>.</param>
/// <param name="right">Source array.</param>
/// <param name="rightStart">Offset, in elements from the end of <paramref name="right"/>, of the first element to read.</param>
/// <param name="rightBitLength">Number of bits to copy.</param>
private static void AppendBits(uint[] left, int leftBitLength, uint[] right, int rightStart, int rightBitLength)
{
    var dst = left.Length - 1 - (leftBitLength >> 5);
    var offset = leftBitLength % 32;
    var src = right.Length - 1 - rightStart;

    // Whole 32-bit source elements first.
    for (var wholeWords = rightBitLength >> 5; wholeWords > 0; wholeWords--)
    {
        if (offset > 0)
        {
            // The source element straddles two destination elements.
            left[dst] |= right[src] << offset;
            dst--;
            left[dst] = right[src] >> (32 - offset);
        }
        else
        {
            left[dst] = right[src];
            dst--;
        }

        src--;
    }

    // Then the remaining partial element, if any.
    var tailBits = rightBitLength % 32;
    if (tailBits > 0)
    {
        var bits = right[src] & (uint.MaxValue >> (32 - tailBits));
        left[dst] |= bits << offset;

        // Spill the part that did not fit into the next, more significant, destination element.
        if (offset + tailBits > 32)
        {
            left[dst - 1] = bits >> (32 - offset);
        }
    }
}

/// <summary>
/// <para>Cuts <paramref name="value"/> into consecutive groups of <paramref name="bitLength"/> bits,
/// starting at the least significant bit.</para>
/// <para>Each group is returned as its own <see cref="uint"/> array of ceil(bitLength / 32) elements with the
/// least significant element last; the final group may contain fewer than <paramref name="bitLength"/> bits.</para>
/// </summary>
/// <param name="value">Number to split, least significant element last.</param>
/// <param name="bitLength">Width of each piece in bits.</param>
/// <returns>The pieces, ordered from least significant to most significant.</returns>
private static uint[][] SplitBits(uint[] value, int bitLength)
{
    var totalBits = value.Length * 32;
    var wordsPerPiece = (bitLength + 31) / 32;
    var pieces = new uint[(totalBits + bitLength - 1) / bitLength][];

    // Read cursor: element index (moving towards index 0) and bit offset inside that element.
    var srcWord = value.Length - 1;
    var srcBit = 0;

    for (var p = 0; p < pieces.Length; p++)
    {
        var piece = new uint[wordsPerPiece];
        pieces[p] = piece;

        // Write cursor, starting at the least significant element of the piece.
        var dstWord = wordsPerPiece - 1;
        var dstBit = 0;

        var pending = Math.Min(bitLength, totalBits - p * bitLength);
        while (pending > 0)
        {
            // Copy as many bits as fit without crossing an element boundary on either side.
            var take = Math.Min(pending, Math.Min(32 - srcBit, 32 - dstBit));
            var chunk = (value[srcWord] >> srcBit) & (uint.MaxValue >> (32 - take));
            piece[dstWord] |= chunk << dstBit;
            pending -= take;

            srcBit += take;
            if (srcBit >= 32)
            {
                srcBit -= 32;
                srcWord--;
            }

            dstBit += take;
            if (dstBit >= 32)
            {
                dstBit -= 32;
                dstWord--;
            }
        }
    }

    return pieces;
}

/// <summary>
/// <para>Cuts <paramref name="value"/> into pieces of <paramref name="sourcePieceSize"/> <see cref="uint"/>
/// elements, starting at the least significant end, packs each piece two elements per <see cref="long"/>
/// into the tail of a fresh <see cref="long"/> array of <paramref name="targetPieceSize"/> elements,
/// and wraps that array in a <see cref="MutableModFn"/>.</para>
/// <para>The leftover most significant elements (fewer than <paramref name="sourcePieceSize"/>, possibly none)
/// form one more piece; any remaining slots up to <paramref name="numPieces"/> are filled with zero-valued
/// <see cref="MutableModFn"/> instances.</para>
/// </summary>
/// <param name="value">Number to split, least significant element last.</param>
/// <param name="numPieces">Length of the returned array.</param>
/// <param name="sourcePieceSize">Piece size in <see cref="uint"/> elements.</param>
/// <param name="targetPieceSize">Length of each piece's <see cref="long"/> array.</param>
/// <returns>The pieces, least significant first.</returns>
private static MutableModFn[] Split(uint[] value, int numPieces, int sourcePieceSize, int targetPieceSize)
{
    var pieces = new MutableModFn[numPieces];
    var next = 0;

    // Full pieces, walking from the least significant end of value towards index 0.
    for (var start = value.Length - sourcePieceSize; start >= 0; start -= sourcePieceSize)
    {
        var packed = new long[targetPieceSize];
        var firstSlot = targetPieceSize - sourcePieceSize / 2;
        for (var k = 0; k < sourcePieceSize; k += 2)
        {
            var high = (long)value[start + k] << 32;
            var low = value[start + k + 1] & 0xFFFFFFFFL;
            packed[firstSlot + k / 2] = high | low;
        }

        pieces[next] = new MutableModFn(packed);
        next++;
    }

    // Leftover elements at the most significant end: value[0 .. leftover - 1].
    var leftover = value.Length % sourcePieceSize;
    var tail = new long[targetPieceSize];
    var tailSlot = targetPieceSize - leftover / 2;
    if (leftover % 2 == 0)
    {
        for (var k = 0; k < leftover; k += 2)
        {
            tail[tailSlot + k / 2] = ((long)value[k] << 32) | (value[k + 1] & 0xFFFFFFFFL);
        }
    }
    else
    {
        // Odd count: element pairs are offset by one, so each source pair spans two adjacent longs.
        for (var k = 0; k < leftover - 2; k += 2)
        {
            tail[tailSlot + k / 2] = (long)value[k + 1] << 32;
            tail[tailSlot + k / 2 - 1] |= value[k] & 0xFFFFFFFFL;
        }

        // the remaining half-long
        tail[targetPieceSize - 1] |= value[leftover - 1] & 0xFFFFFFFFL;
    }

    pieces[next] = new MutableModFn(tail);

    // Pad with zero values.
    for (next++; next < numPieces; next++)
    {
        pieces[next] = new MutableModFn(targetPieceSize);
    }

    return pieces;
}

BigInteger to/from (uint[], bool)

用(uint[], bool)來表示有符號大數,其中uint[]是大數的絕對值,bool為false時是負數。

/// <summary>
/// Converts a sign-and-magnitude pair into a <see cref="BigInteger"/>.
/// </summary>
/// <param name="value">
/// <c>Item1</c> is the magnitude with the most significant element first;
/// <c>Item2</c> is <c>true</c> for a non-negative number and <c>false</c> for a negative one.
/// </param>
/// <returns>The magnitude, negated when <c>Item2</c> is <c>false</c>.</returns>
private BigInteger ValueOf((uint[], bool) value)
{
    // Leading zero elements contribute nothing; each remaining element is appended as the next 32 low bits.
    var magnitude = value.Item1
        .SkipWhile(word => word == 0u)
        .Aggregate(BigInteger.Zero, (acc, word) => (acc << 32) | word);

    if (value.Item2)
    {
        return magnitude;
    }

    return -magnitude;
}


/// <summary>
/// Converts a <see cref="BigInteger"/> into a sign-and-magnitude pair.
/// </summary>
/// <param name="value">Number to convert.</param>
/// <returns>
/// <c>Item1</c>: the absolute value as <see cref="uint"/> elements, most significant first, sized from
/// <see cref="BigInteger.GetByteCount()"/> rounded up to whole elements;
/// <c>Item2</c>: <c>true</c> when <paramref name="value"/> is zero or positive.
/// </returns>
private (uint[], bool) ToTuple(BigInteger value)
{
    var magnitude = BigInteger.Abs(value);
    var wordCount = (magnitude.GetByteCount() + 3) / 4;
    var words = new uint[wordCount];

    // Peel off 32 bits at a time, filling the array from its least significant (last) element.
    var index = wordCount;
    while (index-- > 0)
    {
        words[index] = (uint)(magnitude & uint.MaxValue);
        magnitude >>= 32;
    }

    return (words, value.Sign >= 0);
}

Signed Schönhage–Strassen Multiplication

/// <summary>
/// <para>Signed Schönhage–Strassen multiplication on sign-and-magnitude pairs.</para>
/// <para>In each pair, <c>Item1</c> is the magnitude with the most significant 32 bits in the first element and
/// the least significant 32 bits in the last element, and <c>Item2</c> is <c>false</c> for a negative number.</para>
/// </summary>
/// <param name="left">First factor.</param>
/// <param name="right">Second factor.</param>
/// <returns>
/// The product. A zero operand is returned unchanged; when one operand has magnitude one, the other
/// operand's magnitude array is reused in the result without copying.
/// </returns>
public static (uint[], bool) MultiplySchönhageStrassen((uint[], bool) left, (uint[], bool) right)
{
    // Zero times anything is zero.
    if (IsZero(left))
    {
        return left;
    }

    if (IsZero(right))
    {
        return right;
    }

    // The product is non-negative exactly when both factors carry the same sign flag.
    var sameSign = left.Item2 == right.Item2;

    // Multiplying by +/-1 only affects the sign.
    if (IsAbsOne(left))
    {
        return (right.Item1, sameSign);
    }

    if (IsAbsOne(right))
    {
        return (left.Item1, sameSign);
    }

    var magnitude = MultiplySchönhageStrassenNonegative(left.Item1, right.Item1);
    return (magnitude, sameSign);
}

測試

/// <summary>
/// Multiplies 100 pairs of random non-negative numbers with
/// <c>MultiplySchönhageStrassenNonegative</c> and compares each product with <see cref="BigInteger"/> multiplication.
/// </summary>
[TestMethod]
public void MultiplySchönhageStrassenNonegativeTest()
{
    var remaining = 100;
    while (remaining-- > 0)
    {
        var a = BigInteger.Abs(RandomBigInteger());
        var b = BigInteger.Abs(RandomBigInteger());

        var product = MultiplySchönhageStrassenNonegative(ToIntArray(a), ToIntArray(b));

        Assert.AreEqual(a * b, ValueOf(product));
    }
}

/// <summary>
/// Multiplies 100 pairs of random signed numbers with
/// <c>MultiplySchönhageStrassen</c> and compares each product with <see cref="BigInteger"/> multiplication.
/// </summary>
[TestMethod]
public void MultiplySchönhageStrassenTest()
{
    var remaining = 100;
    while (remaining-- > 0)
    {
        var a = RandomBigInteger();
        var b = RandomBigInteger();

        var product = MultiplySchönhageStrassen(ToTuple(a), ToTuple(b));

        Assert.AreEqual(a * b, ValueOf(product));
    }
}


/// <summary>
/// <para>Produces a random signed <see cref="BigInteger"/> built from 600 to 2047 random bytes.</para>
/// <para>The size matters: the multiplication under test falls back to the plain algorithm when an operand
/// has fewer than 128 <see cref="uint"/> elements (512 bytes). Operands of 300–499 bytes never crossed
/// that threshold, so the Schönhage–Strassen path was never executed by the tests.</para>
/// </summary>
/// <returns>A random number; positive or negative depending on the top bit of the last random byte.</returns>
private BigInteger RandomBigInteger()
{
    // Seeded per call from a fresh Guid so consecutive calls do not share a time-based seed.
    var ran = new Random(Guid.NewGuid().GetHashCode());

    // At least 600 bytes = 150 uint elements, comfortably above the 128-element threshold
    // even if a few high-order bytes happen to be pure sign extension.
    var bytes = new byte[ran.Next(600, 2048)];
    ran.NextBytes(bytes);

    return new BigInteger(bytes);
}