1. 程式人生 > >【演算法學習筆記】快速傅立葉變換

【演算法學習筆記】快速傅立葉變換

# 快速傅立葉變換 快速傅立葉變換(Fast Fourier Transform, FTT)在ACM/OI中最主要的應用是計算多項式乘法。 ## 多項式的係數表示和點值表示 假設$f(x)$為$x$的$n$階多項式,則其可以表示為: $$f(x)=\sum_{i=0}^na_ix^i$$ 這裡的$n+1$個係數$\{a_0,a_1,\cdots,a_n\}$就稱為多項式$f(x)$的係數表示。 另一方面,我們也可以把$f(x)$看成是一個關於$x$的函式,我們可以取$n+1$個不同的$x_i$,用$\{(x_0,f(x_0)),(x_1,f(x_1)),\cdots(x_n,f(x_n))\}$這$n+1$個數值對來唯一確定$f(x)$,這種表示形式就稱為多項式$f(x)$的點值表示。 ## 點值表示與多項式乘法的關係 假設我們現在要求的是$F(x)=f(x)\cdot g(x)$,如果我們已知$f(x)$和$g(x)$的點值表示,那麼我們可以非常容易地得到$F(x)$的點值表示為 $$\{(x_0,f(x_0)g(x_0)),(x_1,f(x_1)g(x_1)),\cdots,(x_n,f(x_n)g(x_n))\}$$ 注意這裡的$n$實際上要取到$f(x)$和$g(x)$的階數之和。 現在的關鍵問題是,如何快速將這一點值表示轉換為係數表示。 ## FFT的實現 為了解決這一問題,我們首先考慮其逆問題,也即:如何從係數表示快速計算點值表示。 ### FFT 暴力計算$n$對點值的總時間複雜度為$O(n^2)$。如何優化呢?我們希望我們選擇的$n$個$x_i$之間存在一定的關係,使得我們可以複用$x_i^k$的計算結果。那麼,應該如何選擇呢? 前人的經驗告訴我們,可以選擇單位復根$\omega_n^i$。它有三個重要的性質: $$\omega_n^n=1$$ $$\omega_n^i=\omega_{2n}^{2i}$$ $$\omega_{2n}^{n+i}=-\omega_{2n}^i$$ 利用上述這三個性質,我們可以實現計算過程的簡化。 不妨考慮一個最高階為7階的多項式 $$f(x)=a_0+a_1x^1+a_2x^2+a_3x^3+a_4x^4+a_5x^5+a_6x^6+a_7x^7$$ 可以把奇偶項分別處理 $$ \begin{aligned} f(x) &=(a_0+a_2x^2+a_4x^4+a_6x^6)+x(a_1+a_3x^2+a_5x^4+a_7x^6) \\ &=G(x^2)+xH(x^2) \end{aligned} $$ 從而 $$ \text{DFT}(f(x))=\text{DFT}(G(x^2))+x\text{DFT}(H(x^2)) $$ 這時把單位復根$\omega_n^k$($k Code(C++)
#include <cmath>
#include <complex>
#include <iostream>
#define MAXN (1 << 22)
using namespace std;
typedef complex cd;
const cd I{0, 1};
cd tmp[MAXN], a[MAXN], b[MAXN];
void fft(cd *f, int n, int rev) {
    if (n == 1) return;
    for (int i = 0; i < n; ++i) tmp[i] = f[i];
    for (int i = 0; i < n; ++i) {
        if (i & 1) f[n / 2 + i / 2] = tmp[i];
        else
            f[i / 2] = tmp[i];
    }
    cd *g = f, *h = f + n / 2;
    fft(g, n / 2, rev), fft(h, n / 2, rev);
    cd omega = exp(I * (2 * M_PI / n * rev)), now = 1;
    for (int k = 0; k < n / 2; ++k) {
        tmp[k] = g[k] + now * h[k];
        tmp[k + n / 2] = g[k] - now * h[k];
        now *= omega;
    }
    for (int i = 0; i < n; ++i) f[i] = tmp[i];
}
int main() {
    int n, m;
    cin >> n >> m;
    int k = 1 << (32 - __builtin_clz(n + m + 1));
    for (int i = 0; i <= n; ++i) cin >> a[i];
    for (int j = 0; j <= m; ++j) cin >> b[j];
    fft(a, k, 1);
    fft(b, k, 1);
    for (int i = 0; i < k; ++i) a[i] *= b[i];
    fft(a, k, -1);
    for (int i = 0; i < k; ++i) a[i] /= k;
    for (int i = 0; i < n + m + 1; ++i) cout << (int)round(a[i].real()) << " ";
}
上述遞迴方法的常數較大,不能通過洛谷P3803的最後兩個測試點。 為了改寫非遞迴方法,我們引入蝴蝶變換的概念。 ### 蝴蝶變換 繼續使用前面的例子,經過第一步分治,將原來的係數分為兩組: $$\{a_0,a_2,a_4,a_6\},\{a_1,a_3,a_5,a_7\}$$ 繼續進行第二步分治,得到四組係數: $$\{a_0,a_4\},\{a_2,a_6\},\{a_1,a_5\},\{a_3,a_7\}$$ 最後一步分治,得到八組係數: $$\{a_0\},\{a_4\},\{a_2\},\{a_6\},\{a_1\},\{a_5\},\{a_3\},\{a_7\}$$ 所謂蝴蝶變換,指的就是從${a_0,a_1,\cdots,a_{n-1}}$這一原始係數序列,變換得到最後一步分治後的係數序列。 觀察後可以發現,在蝴蝶變換的最終結果中,係數下標的二進位制表示恰好是其所在位置二進位制表示的逆序,因此,可以利用這一規律來求取蝴蝶變換的結果。 直接利用規律來計算的複雜度是$O(n\log n)$,如果從小到大遞推實現,複雜度則為$O(n)$。 ### FFT的非遞迴實現 下面給出了洛谷P3803的非遞迴實現。
Code(C++)
#include <cmath>
#include <complex>
#include <iostream>
#define MAXN (1 << 22)
using namespace std;
typedef complex cd;
const cd I{0, 1};
cd a[MAXN], b[MAXN];
void change(cd *f, int n) {
    int i, j, k;
    for (int i = 1, j = n / 2; i < n - 1; i++) {
        if (i < j) swap(f[i], f[j]);
        k = n / 2;
        while (j >= k) {
            j = j - k;
            k = k / 2;
        }
        if (j < k) j += k;
    }
}
void fft(cd *f, int n, int rev) {
    change(f, n);
    for (int len = 2; len <= n; len <<= 1) {
        cd omega = exp(I * (2 * M_PI / len * rev));
        for (int j = 0; j < n; j += len) {
            cd now = 1;
            for (int k = j; k < j + len / 2; ++k) {
                cd g = f[k], h = now * f[k + len / 2];
                f[k] = g + h, f[k + len / 2] = g - h;
                now *= omega;
            }
        }
    }
    if (rev == -1)
        for (int i = 0; i < n; ++i) f[i] /= n;
}
int main() {
    int n, m;
    cin >> n >> m;
    int k = 1 << (32 - __builtin_clz(n + m + 1));
    for (int i = 0; i <= n; ++i) cin >> a[i];
    for (int j = 0; j <= m; ++j) cin >> b[j];
    fft(a, k, 1);
    fft(b, k, 1);
    for (int i = 0; i < k; ++i) a[i] *= b[i];
    fft(a, k, -1);
    for (int i = 0; i < n + m + 1; ++i) cout << (int)round(a[i].real()) << " ";
}
## 學習資源 ### [Matters Computational](https://www.springer.com/gp/book/9783642147630) - 第二十一章 快速傅立葉變換 ## 練習題 裸FFT並不可怕,本身FFT的碼量並不算大,背一背也不是多大的事,關鍵是如何看出一道題目是FFT。 ### [SPOJ - ADAMATCH](https://www.spoj.com/problems/ADAMATCH/) 如果暴力列舉子串,時間複雜度為$O(|r|^2)$,顯然不行。如何降低複雜度呢? #### **提示一** 首先考慮字母`'A'`。不妨把字串為`'A'`的位置設為$1$,其餘位置設為$0$。看起來似乎可以進行多項式乘法,但乘法的結果似乎沒有明顯的意義。 #### 提示二 如果把`r`串逆序呢?看看此時乘積的每一項有怎樣的含義。 #### 參考程式碼(C++) ```cpp #include #include #include #include #include #define MAXN (1 << 22) using namespace std; typedef complex cd; const cd I{0, 1}; cd a[MAXN], b[MAXN]; void change(cd *f, int n) { for (int i = 1, j = n / 2; i < n - 1; i++) { if (i < j) swap(f[i], f[j]); int k = n / 2; while (j >= k) { j = j - k; k = k / 2; } if (j < k) j += k; } } void fft(cd *f, int n, int rev) { change(f, n); for (int len = 2; len <= n; len <<= 1) { cd omega = exp(I * (2 * M_PI / len * rev)); for (int j = 0; j < n; j += len) { cd now = 1; for (int k = j; k < j + len / 2; ++k) { cd g = f[k], h = now * f[k + len / 2]; f[k] = g + h, f[k + len / 2] = g - h; now *= omega; } } } if (rev == -1) for (int i = 0; i < n; ++i) f[i] /= n; } int main() { string s, r; cin >> s >> r; int n = s.size(), m = r.size(); int k = 1 << (32 - __builtin_clz(n + m + 1)); vector cnt(k); for (char c : "ACGT") { memset(a, 0, sizeof(a)); memset(b, 0, sizeof(b)); for (int i = 0; i < n; ++i) a[i] = s[i] == c; for (int i = 0; i < m; ++i) b[i] = r[m - i - 1] == c; fft(a, k, 1); fft(b, k, 1); for (int i = 0; i < k; ++i) a[i] *= b[i]; fft(a, k, -1); for (int i = 0; i < k; ++i) cnt[i] += (int)round(a[i].real()); } int ans = m; for (int i = m - 1; i < n; ++i) ans = min(ans, m - cnt[i]); cout << ans; } ``` ### [SPOJ - TSUM](https://www.spoj.com/problems/TSUM/) 如果暴力列舉,時間複雜度為$O(n^3)$,顯然不行。如何降低複雜度呢? #### 提示一 加法可以變為多項式的乘法。 #### 提示二 如何去除包含重複元素的項? #### 參考程式碼(C++) ```cpp #include #include #include #include #define MAXN 131072 #define OFFSET 20000 using namespace std; typedef complex cd; const cd I{0, 1}; void change(vector &f, int n) { for (int i = 1, j = n / 2; i < n - 1; i++) { if (i < j) swap(f[i], f[j]); int k = n / 2; while (j >= k) { j = j - k; k = k / 2; } if (j < k) j += k; } } void fft(vector &f, int n, int rev) { change(f, n); for (int len = 2; len <= n; len <<= 1) { cd omega = exp(I * (2 * M_PI / len * rev)); for (int j = 0; j < n; j += len) { cd now = 1; for (int k = j; k < j + len / 2; ++k) { cd g = f[k], h = now * f[k + len / 2]; f[k] = g + h, f[k + len / 2] = g - h; now *= omega; } } } if (rev == -1) for (int i = 0; i < n; ++i) f[i] /= n; } int main() { int n; cin >> n; vector a(MAXN), a2(MAXN); vector a3(MAXN); for (int i = 0; i < n; ++i) { int m; cin >> m; a[m + OFFSET] = cd{1, 0}; a2[(m + OFFSET) << 1] = cd{1, 0}; a3[(m + OFFSET) * 3] = 1; } vector tot(a), b(a); fft(tot, MAXN, 1); fft(b, MAXN, 1); fft(a2, MAXN, 1); for (int i = 0; i < MAXN; ++i) tot[i] *= b[i] * b[i], a2[i] *= b[i]; fft(tot, MAXN, -1); fft(a2, MAXN, -1); for (int i = 0; i < MAXN; ++i) { int cnt1 = round(tot[i].real()); // ABC, with permutation int cnt2 = round(a2[i].real()); // AAB, no permutation int cnt3 = a3[i]; // AAA int cnt = (cnt1 - cnt2 * 3 + cnt3 * 2) / 6; if (cnt > 0) cout << i - OFFSET * 3 << " : " << cnt << endl; } } ```