FFT快速傅立葉變換以O(NlogN)的時間複雜度實現大數乘
阿新 • • 發佈:2019-02-18
任意一個整數均能表示成An*10^(n-1) + An-1*10^(n-2) + ... + A2*10^2 + A1*10 + A0的形式,視10為自變數X,則化為一個多項式。兩數相乘轉化為兩多項式相乘。以係數表示法表示的多項式相乘其複雜度為N^2,若採用點值表示法,結合適當的點的選取,能實現O(NlogN)的演算法。
若一個多項式的最高次為N-1,那麼取N個點對(xi, yi)就能夠唯一確定這個多項式,其中各xi相異,yi為對應自變數xi的多項式的值。此時選取N次單位復根即可。具體定理見演算法導論第30章。
此時過程很簡單,先用FFT分別將兩個輸入的多項式由係數表示法轉為點值表示法,即得到係數向量a=(a 0,a 1, ... ,a n-1)對應的離散傅立葉變換(DFT)y=(y
0,y1, ... ,y n-1),此複雜度為O(NlogN),再相乘,此複雜度僅為O(N),最後再用FFT將結果轉化回係數表示法,輸出即可。
另外演算法用到了二進位制平攤反轉置換(亦叫位反轉置換)以將原本的遞迴過程轉為迭代。
#include <cstdio> #include <cstdlib> #include <cstring> #include <cmath> #include <iostream> #include <algorithm> using namespace std; #define pi acos(-1.0) struct complex{ double r,i; complex(double real=0.0,double image=0.0) { r=real; i=image; } inline complex operator + (const complex a) {return complex(r+a.r,i+a.i);} inline complex operator - (const complex a) {return complex(r-a.r,i-a.i);} inline complex operator * (const complex a) {return complex(r*a.r-i*a.i,r*a.i+i*a.r);} }; void brc(complex *y, int L) { int i, j, k; for (i=1,j=L>>1; i<L-1; ++i) { // 二進位制平攤反轉置換 O(NlogN) if (i < j) swap(y[i], y[j]); k = L>>1; while (j >= k) { j -= k; k >>= 1; } j += k; } } void FFT(complex *y, int L,double isI) { //isI為1是DFT,-1則IDFT register int h,i,j,k; complex u,t; brc(y,L); for(h=2;h<=L;h<<=1)//層數自底向上 { //初始化單位復根 complex Wn( cos(isI*2*pi/h) , sin(isI*2*pi/h) ); for(j=0;j<L;j+=h) // 原序列被分成了L/h段h長序列 { complex w(1,0); for(k=j;k<j+h/2;k++) //蝴蝶操作 { u=y[k]; t=w*y[k+h/2]; //按層配對 y[k]=u+t; y[k+h/2]=u-t; w=w*Wn; //更新旋轉因子 } } } if(isI==-1) for(i=0;i<L;i++) y[i].r/=L; } const int N = 50024; int ans[N<<2]; complex a[N<<2], b[N<<2]; char num1[N], num2[N]; int main() { int i,j; while(~scanf("%s%s",num1,num2)) { memset(ans,0,sizeof(ans)); int len1=strlen(num1),len2=strlen(num2),L=1; int ll=len1+len2-1; while(L<ll) L<<=1; for(i=len1-1,j=0;i>=0;--i,++j) { a[j]=complex(num1[i]-'0',0); //高位實為多項式的低位 } for(i=len2-1,j=0;i>=0;--i,++j) { b[j]=complex(num2[i]-'0',0); } for(i=len1;i<L;++i) a[i]=complex(0,0);//補0 for(i=len2;i<L;++i) b[i]=complex(0,0); FFT(a,L,1); FFT(b,L,1); for(i=0;i<L;++i) { a[i]=a[i]*b[i]; } FFT(a,L,-1); for(i=0;i<L;++i) ans[i]=a[i].r+0.5; for(i=0;i<L;++i) { ans[i+1]+=ans[i]/10; //進位 ans[i]%=10; } int p=L; while(!ans[p] && p) --p; while(p>=0) printf("%d", ans[p--]); puts(""); } return 0; }