1. 程式人生 > >FFT快速傅立葉變換以O(NlogN)的時間複雜度實現大數乘

FFT快速傅立葉變換以O(NlogN)的時間複雜度實現大數乘

任意一個整數均能表示成An*10^(n-1) + An-1*10^(n-2) + ... + A2*10^2 + A1*10 + A0的形式,視10為自變數X,則化為一個多項式。兩數相乘轉化為兩多項式相乘。以係數表示法表示的多項式相乘其複雜度為N^2,若採用點值表示法,結合適當的點的選取,能實現O(NlogN)的演算法。

若一個多項式的最高次為N-1,那麼取N個點對(xi, yi)就能夠唯一確定這個多項式,其中各xi相異,yi為對應自變數xi的多項式的值。此時選取N次單位復根即可。具體定理見演算法導論第30章。

此時過程很簡單,先用FFT分別將兩個輸入的多項式由係數表示法轉為點值表示法,即得到係數向量a=(a 0,a 1, ... ,a n-1)對應的離散傅立葉變換(DFT)y=(y 0,y1, ... ,y n-1),此複雜度為O(NlogN),再相乘,此複雜度僅為O(N),最後再用FFT將結果轉化回係數表示法,輸出即可。

另外演算法用到了二進位制平攤反轉置換(亦叫位反轉置換)以將原本的遞迴過程轉為迭代。

#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <cmath>
#include <iostream>
#include <algorithm>
using namespace std;
#define pi acos(-1.0)
struct complex{
    double r,i;
    complex(double real=0.0,double image=0.0)
    {
        r=real;
        i=image;
    }
    inline complex operator + (const complex a)
    {return complex(r+a.r,i+a.i);}
    inline complex operator - (const complex a)
    {return complex(r-a.r,i-a.i);}
    inline complex operator * (const complex a)
    {return complex(r*a.r-i*a.i,r*a.i+i*a.r);}
};
void brc(complex *y, int L) {
    int i, j, k;
    for (i=1,j=L>>1; i<L-1; ++i) { // 二進位制平攤反轉置換 O(NlogN)
        if (i < j) swap(y[i], y[j]);
        k = L>>1;
        while (j >= k) {
            j -= k;
            k >>= 1;
        }
        j += k;
    }
}
void FFT(complex *y, int L,double isI)
{
    //isI為1是DFT,-1則IDFT
    register int h,i,j,k;
    complex u,t;
    brc(y,L);
    for(h=2;h<=L;h<<=1)//層數自底向上
    {
        //初始化單位復根
        complex Wn( cos(isI*2*pi/h) , sin(isI*2*pi/h) );
        for(j=0;j<L;j+=h)  // 原序列被分成了L/h段h長序列
        {
            complex w(1,0);
            for(k=j;k<j+h/2;k++) //蝴蝶操作
            {
                u=y[k];
                t=w*y[k+h/2]; //按層配對
                y[k]=u+t;
                y[k+h/2]=u-t;
                w=w*Wn; //更新旋轉因子
            }
        }
    }
    if(isI==-1)
        for(i=0;i<L;i++)
            y[i].r/=L;
}
const int N = 50024;
int ans[N<<2];
complex a[N<<2], b[N<<2];
char num1[N], num2[N];

int main()
{
    int i,j;
    while(~scanf("%s%s",num1,num2))
    {
        memset(ans,0,sizeof(ans));
        int len1=strlen(num1),len2=strlen(num2),L=1;
        int ll=len1+len2-1;
        while(L<ll) L<<=1;
        for(i=len1-1,j=0;i>=0;--i,++j)
        {
            a[j]=complex(num1[i]-'0',0); //高位實為多項式的低位
        }
        for(i=len2-1,j=0;i>=0;--i,++j)
        {
            b[j]=complex(num2[i]-'0',0);
        }
        for(i=len1;i<L;++i) a[i]=complex(0,0);//補0
        for(i=len2;i<L;++i) b[i]=complex(0,0);
        FFT(a,L,1);
        FFT(b,L,1);
        for(i=0;i<L;++i)
        {
            a[i]=a[i]*b[i];
        }
        FFT(a,L,-1);
        for(i=0;i<L;++i) ans[i]=a[i].r+0.5;
        for(i=0;i<L;++i)
        {
            ans[i+1]+=ans[i]/10; //進位
            ans[i]%=10;
        }
        int p=L;
        while(!ans[p] && p) --p;
        while(p>=0) printf("%d", ans[p--]);
        puts("");
    }
    return 0;
}