1. 程式人生 > >小學生都能看懂的FFT!!!

小學生都能看懂的FFT!!!

long 不同 def 如果能 裏的 文件 補充 運算 其余

小學生都能看懂的FFT!!!


前言

在創新實踐重心偷偷看了一天FFT資料後,我終於看懂了一點。為了給大家提供一份簡單易懂的學習資料,同時也方便自己以後復習,我決定動手寫這份學習筆記。

食用指南:
本篇受眾:如標題所示,另外也面向同我一樣高中起步且非常菜的OIer。真正的dalao請無視。
本篇目標:讓大家(和不知道什麽時候把FFT忘了的我)在沒有數學基礎的情況下,以最快的速度了解並 會寫 FFT。因此本篇將采用盡可能通俗易懂的語言,且略過大部分數學證明,在嚴謹性上可能有欠缺。但如果您發現了較大的邏輯漏洞,歡迎在評論裏指正!

最後……來個版權聲明吧。本文作者胡小兔,博客地址http://rabbithu.cnblogs.com。暫未許可在任何其他平臺轉載。

你一定聽說過FFT,它的高逼格名字讓人望而卻步——“快速傅裏葉變換”。
你可能知道它可以\(O(n \log n)\)求高精度乘法,你想學,可是面對一堆公式,你無從下手。
那麽歡迎閱讀這篇教程!

[Warning] 本文涉及復數(虛數)的一小部分內容,這可能是最難的部分,但只要看下去也不是非常難,請不要看到它就中途退出啊QAQ。

什麽是FFT?

快速傅裏葉變換(FFT)是一種能在\(O(n \log n)\)的時間內將一個多項式轉換成它的點值表示的算法。

補充資料:什麽是點值表示

\(A(x)\)是一個\(n - 1\)次多項式,那麽把\(n\)個不同的\(x\)代入,會得到\(n\)

\(y\)\(n\)\((x, y)\)唯一確定了該多項式,即只有一個多項式能同時滿足“代入這些\(x\),得到的分別是這些\(y\)”。
由多項式可以求出其點值表示,而由點值表示也可以求出多項式。

(並不想證明,十分想看證明的同學請前往“參考資料”部分)。

註:下文如果不加特殊說明,默認所有\(n\)為2的整數次冪。如果一個多項式次數不是\(n\)的整數次冪,可以在後面補0。

為什麽要使用FFT?

FFT可以用來加速多項式乘法(平時非常常用的高精度大整數乘法就是最終把\(x = 10\)代入的多項式乘法)。

假設有兩個\(n-1\)次多項式\(A(x)\)\(B(x)\),我們的目標是——把它們乘起來。

普通的多項式乘法是\(O(n^2)\)的——我們要枚舉\(A(x)\)中的每一項,分別與\(B(x)\)中的每一項相乘,來得到一個新的多項式\(C(x)\)

但有趣的是,兩個用點值表示的多項式相乘,復雜度是\(O(n)\)的!具體方法:\(C(x_i) = A(x_i) \times B(x_i)\),所以\(O(n)\)枚舉\(x_i\)即可。

要是我們把兩個多項式轉換成點值表示,再相乘,再把新的點值表示轉換成多項式豈不就可以\(O(n)\)解決多項式乘法了!

……很遺憾,顯然,把多項式轉換成點值表示的樸素算法是\(O(n^2)\)的。另外,即使你可能不會——把點值表示轉換為多項式的樸素“插值算法”也是\(O(n^2)\)的。

難道大整數乘法就只能是\(O(n^2)\)嗎?!不甘心的同學可以發現,大整數乘法復雜度的瓶頸可能在“多項式轉換成點值表示”這一步(以及其反向操作),只要完成這一步就可以\(O(n)\)求答案了。如果能優化這一步,豈不美哉?

傅裏葉:這個我會!

離散傅裏葉變換(快速傅裏葉變換的樸素版)

傅裏葉發明了一種辦法:規定點值表示中的\(n\)\(x\)\(n\)個模長為\(1\)復數

——等等,先別看到復數就走!

補充資料:什麽是復數

如果你學過復數,這段不用看了;
如果你學過向量,請把復數理解成一個向量;
如果你啥都沒學過,請把復數理解成一個平面直角坐標系上的點。

復數具有一個實部和一個虛部,正如一個向量(或點)有一個橫坐標和一個縱坐標。例如復數\(3 + 2i\),實部是\(3\),虛部是\(2\)\(i = \sqrt{-1}\)。可以把它想象成向量\((3, 2)\)或點\((3, 2)\)

但復數比一個向量或點更妙的地方在於——復數也是一種數,它可以像我們熟悉的實數那樣進行加減乘除等運算,還可以代入多項式\(A(x)\)——顯然你不能把一個向量或點作為\(x\)代入進去。

復數相乘的規則:模長相乘,幅角相加。模長就是這個向量的模長(或是這個點到原點的距離);幅角就是x軸正方向逆時針旋轉到與這個向量共線所途徑的角(或是原點出發、指向x軸正方向的射線逆時針旋轉至過這個點所經過的角)。想學會FFT,“模長相乘”暫時不需要了解過多,但“幅角相加”需要記住。

C++的STL提供了復數模板!
頭文件:#include <complex>
定義: complex<double> x;
運算:直接使用加減乘除。

傅裏葉要用到的\(n\)個復數,不是隨機找的,而是——把單位圓(圓心為原點、1為半徑的圓)\(n\)等分,取這\(n\)個點(或點表示的向量)所表示的虛數,即分別以這\(n\)個點的橫坐標為實部、縱坐標為虛部,所構成的虛數。

技術分享圖片

從點\((1, 0)\)開始(顯然這個點是我們要取的點之一),逆時針將這\(n\)個點從\(0\)開始編號,第\(k\)個點對應的虛數記作\(\omega_n^k\)(根據復數相乘時模長相乘幅角相加可以看出,\(\omega_n^k\)\(\omega_n^1\)\(k\)次方,所以\(\omega_n^1\)被稱為\(n\)單位根)。

根據每個復數的幅角,可以計算出所對應的點/向量。\(\omega_n^k\)對應的點/向量是\((\cos \frac{k}{n}2\pi, \sin \frac{k}{n}2\pi)\),也就是說這個復數是\(\cos \frac{k}{n}2\pi + i\sin \frac{k}{n}2\pi\)

傅裏葉說:把\(n\)個復數\(\omega_n^0, \omega_n^1, \omega_n^2, ..., \omega_n^{n-1}\)代入多項式,能得到一種特殊的點值表示,這種點值表示就叫離散傅裏葉變換吧!

[Warning] 從現在開始,本文個別部分會集中出現數學公式,但是都不是很難,公式恐懼癥患者請堅持!Stay Determined!

補充資料:單位根的性質

性質一:\(\omega_{2n}^{2k} = \omega_{n}^{k}\)

證明:它們對應的點/向量是相同的。

性質二:\(\omega_{n}^{k + \frac{n}{2}} = -\omega_{n}^{k}\)

證明:它們對應的點是關於原點對稱的(對應的向量是等大反向的)。

為什麽要使用單位根作為\(x\)代入

當然是因為離散傅裏葉變換有著特殊的性質啦。

[Warning] 下面有一些證明,如果不想看,請跳到加粗的“一個結論”部分。

\((y_0, y_1, y_2, ..., y_{n - 1})\)為多項式\(A(x) = a_0 + a_1x + a_2x^2 +...+a_{n-1}x^{n-1}\)的離散傅裏葉變換。

現在我們再設一個多項式\(B(x) = y_0 + y_1x + y_2x^2 +...+y_{n-1}x^{n-1}\),現在我們把上面的\(n\)個單位根的倒數,即\(\omega_{n}^{0}, \omega_{n}^{-1}, \omega_{n}^{-2}, ..., \omega_{n}^{-(n - 1)}\)作為\(x\)代入\(B(x)\), 得到一個新的離散傅裏葉變換\((z_0, z_1, z_2, ..., z_{n - 1}\))。

\[ \begin{align*} z_k &= \sum_{i = 0}^{n - 1} y_i(\omega_n^{-k})^i \\ &= \sum_{i = 0}^{n - 1}(\sum_{j = 0}^{n - 1} a_j(\omega_n^i)^j)(\omega_n^{-k})^i \ &= \sum_{j = 0}^{n - 1}a_j(\sum_{i = 0}^{n - 1}(\omega_n^{j - k})^i) \end{align*} \]

這個\(\sum_{i = 0}^{n - 1}(\omega_n^{j - k})^i\)是可求的:當\(j - k = 0\)時,它等於\(n\); 其余時候,通過等比數列求和可知它等於\(\frac{(\omega_n^{j - k})^n - 1}{\omega_n^{j - k} - 1}\)

那麽,\(z_k\)就等於\(na_k\), 即:
\[a_i = \frac{c_i}{n}\]

一個結論

把多項式\(A(x)\)的離散傅裏葉變換結果作為另一個多項式\(B(x)\)的系數,取單位根的倒數即\(\omega_{n}^{0}, \omega_{n}^{-1}, \omega_{n}^{-2}, ..., \omega_{n}^{-(n - 1)}\)作為\(x\)代入\(B(x)\),得到的每個數再除以n,得到的是\(A(x)\)的各項系數。這實現了傅裏葉變換的逆變換——把點值表示轉換成多項式系數表示,這就是離散傅裏葉變換神奇的特殊性質。

快速傅裏葉變換

雖然傅裏葉發明了神奇的變換,能把多項式轉換成點值表示又轉換回來,但是……它仍然是暴力代入的做法,復雜度仍然是\(O(n^2)\)啊!(傅裏葉:我都沒見過計算機,我幹啥要優化復雜度……)

於是,快速傅裏葉變換應運而生。它是一種分治的傅裏葉變換。

[Warning] 下面有較多公式。看起來很嚇人,但是並不復雜。請堅持看完。

快速傅裏葉變換的數學證明

仍然,我們設\(A(x) = a_0 + a_1x + a_2x^2 +...+a_{n-1}x^{n-1}\),現在為了求離散傅裏葉變換,要把一個\(x = \omega_n^k\)代入。

考慮將\(A(x)\)的每一項按照下標的奇偶分成兩部分:

\[A(x) = (a_0 + a_2x^2 + ... + a_{n - 2}x^{n - 2}) + (a_1x + a_3x^3 + ... + a_{n-1}x^{n-1})\]

設兩個多項式:

\[A_1(x) = a_0 + a_2x + ... + a_{n - 2}x^{\frac{n}{2} - 1}\]
\[A_2(x) = a_1 + a_3x + ... + a_{n - 1}x^{\frac{n}{2} - 1}\]

則:

\[A(x) = A_1(x^2) + xA_2(x^2)\]

假設\(k < \frac{n}{2}\),現在要把\(x = \omega_n^k\)代入:

\[\begin{align*} A(\omega_n^k) &= A_1(\omega_n^{2k}) + \omega_n^kA_2(\omega_n^{2k}) \&= A_2(\omega_{\frac{n}{2}}^{k}) + \omega_n^kA_2(\omega_{\frac{n}{2}}^{k}) \end{align*}\]

那麽對於\(A(\omega_n^{k + \frac{n}{2}})\)

\[\begin{align*} A(\omega_n^{k + \frac{n}{2}}) &= A_1(\omega_n^{2k + n}) + \omega_n^{k + \frac{n}{2}}A_2(\omega_n^{2k + n}) \&= A_2(\omega_{\frac{n}{2}}^{k} \times \omega_n^n) + \omega_n^{k + \frac{n}{2}} A_2(\omega_{\frac{n}{2}}^{k} \times \omega_n^n) \&= A_2(\omega_{\frac{n}{2}}^{k}) - \omega_n^kA_2(\omega_{\frac{n}{2}}^{k}) \end{align*}\]

所以,如果我們知道兩個多項式\(A_1(x)\)\(A_2(x)\)分別在\((\omega_{\frac{n}{2}}^{0}, \omega_{\frac{n}{2}}^{1}, \omega_{\frac{n}{2}}^{2}, ... , \omega_{\frac{n}{2}}^{\frac{n}{2} - 1}\))的點值表示,就可以\(O(n)\)求出\(A(x)\)\(\omega_n^0, \omega_n^1, \omega_n^2, ..., \omega_n^{n-1}\)處的點值表示了。而\(A_1(x)\)\(A_2(x)\)都是規模縮小了一半的子問題。分治邊界是\(n = 1\),此時直接return。

快速傅裏葉變換的實現

寫個遞歸就可以實現一個FFT了!

cp omega(int n, int k){
    return cp(cos(2 * PI * k / n), sin(2 * PI * k / n));
}
void fft(cp  *a, int n, bool inv){
    if(n == 1) return;
    static cp buf[N];
    int m = n / 2;
    for(int i = 0; i < m; i++){ //將每一項按照奇偶分為兩組
        buf[i] = a[2 * i];
        buf[i + m] = a[2 * i + 1];
    }
    for(int i = 0; i < n; i++)
        a[i] = buf[i];
    fft(a, m, inv); //遞歸處理兩個子問題
    fft(a + m, m, inv);
    for(int i = 0; i < m; i++){ //枚舉x,計算A(x)
        cp x = omega(n, i); 
        if(inv) x = conj(x); 
        //conj是一個自帶的求共軛復數的函數,精度較高。當復數模為1時,共軛復數等於倒數
        buf[i] = a[i] + x * a[i + m]; //根據之前推出的結論計算
        buf[i + m] = a[i] - x * a[i + m];
    }
    for(int i = 0; i < n; i++)
        a[i] = buf[i];
}

inv表示這次用的單位根是否要取倒數。

至此你已經會寫fft了!但是這個fft還是1.0版本,比較慢(可能同時還比較長?),親測可能會比加了一些優化的fft慢了4倍左右……

那麽我們來學習一些優化吧!

優化fft

非遞歸fft

在進行fft時,我們要把各個系數不斷分組並放到兩側,那麽一個系數原來的位置和最終的位置有什麽規律呢?

初始位置:0 1 2 3 4 5 6 7
第一輪後:0 2 4 6|1 3 5 7
第二輪後:0 4|2 6|1 5|3 7
第三輪後:0|4|2|6|1|5|3|7

“|”代表分組界限。

可以發現(這你都能發現?),一個位置a上的數,最後所在的位置是“a二進制翻轉得到的數”,例如6(011)最後到了3(110),1(001)最後到了4(100)。

那麽我們可以據此寫出非遞歸版本fft:先把每個數放到最後的位置上,然後不斷向上還原,同時求出點值表示。

代碼:

cp a[N], b[N], omg[N], inv[N];

void init(){
    for(int i = 0; i < n; i++){
        omg[i] = cp(cos(2 * PI * i / n), sin(2 * PI * i / n));
        inv[i] = conj(omg[i]);
    }
}
void fft(cp *a, cp *omg){
    int lim = 0;
    while((1 << lim) < n) lim++;
    for(int i = 0; i < n; i++){
        int t = 0;
        for(int j = 0; j < lim; j++)
            if((i >> j) & 1) t |= (1 << (lim - j - 1));
        if(i < t) swap(a[i], a[t]); // i < t 的限制使得每對點只被交換一次(否則交換兩次相當於沒交換)
    }
    static cp buf[N];
    for(int l = 2; l <= n; l *= 2){
        int m = l / 2;
        for(int j = 0; j < n; j += l)
            for(int i = 0; i < m; i++){
                buf[j + i] = a[j + i] + omg[n / l * i] * a[j + i + m];
                buf[j + i + m] = a[j + i] - omg[n / l * i] * a[j + i + m];
        }
        for(int j = 0; j < n; j++)
            a[j] = buf[j];
    }
}

可以預處理\(\omega_n^k\)\(\omega_n^{-k}\),分別存在omg和inv數組中。調用fft時,如果無需取倒數,則傳入omg;如果需要取倒數,則傳入inv。

蝴蝶操作

這個優化有著一個高大上的名字——“蝴蝶操作”。我第一次看到這個名字時就嚇跑了——尤其是看到那種帶示意圖的蝴蝶操作解說時。

但是你完全無需跑!這是一個很簡單的優化,它可以丟掉上面代碼裏的那個buf數組。

我們為什麽需要buf數組?因為我們要做這兩件事:

a[j + i] = a[j + i] + omg[n / l * i] * a[j + i + m]
a[j + i + m] = a[j + i] + omg[n / l * i] * a[j + i + m]

但是我們又要求這兩行不能互相影響,所以我們需要buf數組。

但是如果我們這樣寫:

cp t = omg[n / l * i] * a[j + i + m]
a[j + i + m] = a[j + i] - t
a[j + i] = a[j + i] + t

就可以原地進行了,不需要buf數組。

cp a[N], b[N], omg[N], inv[N];

void init(){
    for(int i = 0; i < n; i++){
        omg[i] = cp(cos(2 * PI * i / n), sin(2 * PI * i / n));
        inv[i] = conj(omg[i]);
    }
}
void fft(cp *a, cp *omg){
    int lim = 0;
    while((1 << lim) < n) lim++;
    for(int i = 0; i < n; i++){
        int t = 0;
        for(int j = 0; j < lim; j++)
            if((i >> j) & 1) t |= (1 << (lim - j - 1));
        if(i < t) swap(a[i], a[t]); // i < t 的限制使得每對點只被交換一次(否則交換兩次相當於沒交換)
    }
    for(int l = 2; l <= n; l *= 2){
    int m = l / 2;
    for(cp *p = a; p != a + n; p += l)
        for(int i = 0; i < m; i++){
            cp t = omg[n / l * i] * p[i + m];
            p[i + m] = p[i] - t;
            p[i] += t;
        }
    }
}

現在,這個fft就比之前的遞歸版快很多了!


到此為止我的FFT筆記就整理完啦。

下面貼一個FFT加速高精度乘法的代碼:

#include <cstdio>
#include <cmath>
#include <cstring>
#include <algorithm>
#include <complex>
#define space putchar(' ')
#define enter putchar('\n')
using namespace std;
typedef long long ll;
template <class T>
void read(T &x){
    char c;
    bool op = 0;
    while(c = getchar(), c < '0' || c > '9')
    if(c == '-') op = 1;
        x = c - '0';
    while(c = getchar(), c >= '0' && c <= '9')
        x = x * 10 + c - '0';
    if(op) x = -x;
}
template <class T>
void write(T x){
    if(x < 0) putchar('-'), x = -x;
    if(x >= 10) write(x / 10);
    putchar('0' + x % 10);
}
const int N = 1000005;
const double PI = acos(-1);
typedef complex <double> cp;
char sa[N], sb[N];
int n = 1, lena, lenb, res[N];
cp a[N], b[N], omg[N], inv[N];
void init(){
    for(int i = 0; i < n; i++){
        omg[i] = cp(cos(2 * PI * i / n), sin(2 * PI * i / n));
        inv[i] = conj(omg[i]);
    }
}
void fft(cp *a, cp *omg){
    int lim = 0;
    while((1 << lim) < n) lim++;
    for(int i = 0; i < n; i++){
        int t = 0;
        for(int j = 0; j < lim; j++)
            if((i >> j) & 1) t |= (1 << (lim - j - 1));
        if(i < t) swap(a[i], a[t]); // i < t 的限制使得每對點只被交換一次(否則交換兩次相當於沒交換)
    }
    for(int l = 2; l <= n; l *= 2){
        int m = l / 2;
    for(cp *p = a; p != a + n; p += l)
        for(int i = 0; i < m; i++){
            cp t = omg[n / l * i] * p[i + m];
            p[i + m] = p[i] - t;
            p[i] += t;
        }
    }
}
int main(){
    scanf("%s%s", sa, sb);
    lena = strlen(sa), lenb = strlen(sb);
    while(n < lena + lenb) n *= 2;
    for(int i = 0; i < lena; i++)
        a[i].real(sa[lena - 1 - i] - '0');
    for(int i = 0; i < lenb; i++)
        b[i].real(sb[lenb - 1 - i] - '0');
    init();
    fft(a, omg);
    fft(b, omg);
    for(int i = 0; i < n; i++)
        a[i] *= b[i];
    fft(a, inv);
    for(int i = 0; i < n; i++){
        res[i] += floor(a[i].real() / n + 0.5);
        res[i + 1] += res[i] / 10;
        res[i] %= 10;
    }
    for(int i = res[lena + lenb - 1] ? lena + lenb - 1: lena + lenb - 2; i >= 0; i--)
        putchar('0' + res[i]);
    enter;
    return 0;
}

小學生都能看懂的FFT!!!