1. 程式人生 > >淺談快速沃爾什變換

淺談快速沃爾什變換

name puts 算法 多項式 lin 題目 namespace -i pla

快速沃爾什變換(fwt)

\(fwt\)是一種快速計算位運算卷積的算法,一般包括按位或卷積,按位與卷積和異或卷積。

按位或(or)卷積

對於多項式\(A,B,C\),定義\(\oplus\)為卷積符號,即\(A\oplus B = C\)

那麽,按位或卷積就是:
\[ C_k=\sum_{i~or~j=k}A_i\cdot B_j \]
類比於\(FFT\),現在,我們的任務就是找到一種變換,記這種變換為\(fwt(A)\),則要滿足\(fwt(A)\times fwt(B)=fwt(C)\),其中\(\times\)表示每一位相乘,且\(A\oplus B=C\)

經過前人的大力研究,可以發現:

\[ fwt(A)_i=\sum_{j~or~i=i}A_j \]
是滿足性質的,證明很簡單,直接帶進去可得:
\[ \begin{align} fwt(C)_k&=\sum_{j~or~k=k}\sum_{a~or~b=k}A_a\cdot B_b\&=\sum_{a~or~k=k}A_a\cdot \sum_{b~or~k=k}B_b\&=fwt(A)_k\cdot fwt(B)_k \end{align} \]
即得證。

那麽,考慮怎樣快速的進行\(fwt\)變換。

然後有一個這樣的式子:
\[ fwt(A)= \begin{cases} (fwt(A_1),fwt(A_1)+fwt(A_2))&n>0\A_0&n=0 \end{cases} \]


其中,\((A,B)\)表示把兩個多項式的系數拼起來,感性理解一下就好了。

\(A_1\)表示多項式前半段,\(A_2\)表示後半段。

\(n=0\)的時候顯然,我們只需要關心上面那個是為什麽就好了。

對於前半段的第\(i\)項,\(i\)的最高位肯定是\(0\),那麽後半段顯然對他沒有影響,前半段的影響就是\(fwt(A_1)_i\)

對於後半段的第\(i\)項,\(i\)的最高位是\(1\),所以最高位取\(0\)時是\(fwt(A_1)_i\),取\(1\)時是\(fwt(A_2)_i\),所以一共就是\(fwt(A_1)+fwt(A_2)\)

然後這玩意形式其實和\(FFT\)差不太多,復雜度也是\(O(n\log n)\)

代碼:

void fwt_or(int *r) {
    for(int i=1;i<n;i<<=1)
        for(int j=0;j<n;j+=(i<<1))
            for(int k=0;k<i;k++)
                r[i+j+k]=(r[i+j+k]+r[j+k])%mod;
}

按位與(and)卷積

和上面差不多的,定義:
\[ fwt(A)_i=\sum_{j\&i=i}A_i \]
證明也差不多,這裏不贅述了。

那麽,算的話就是:
\[ fwt(A)= \begin{cases} (fwt(A_1)+fwt(A_2),fwt(A_2))&n>0\A_0&n=0 \end{cases} \]
只要考慮按位與的性質,高位為\(1\)時只能選高位為\(1\)的,否則都能選。

代碼:

void fwt_and(int *r) {
    for(int i=1;i<n;i<<=1)
        for(int j=0;j<n;j+=(i<<1))
            for(int k=0;k<i;k++)
                r[j+k]=(r[i+j+k]+r[j+k])%mod;
}

異或(xor)卷積

這裏的定義就不是很相同了。

定義:
\[ fwt(A)_i=\sum_{j=0}^{n}(-1)^{cnt(i\&j)}A_j \]
其中,\(i\&j\)表示按位與,\(cnt(x)\)表示\(x\)二進制下\(1\)的個數。(這到底是怎麽想到的。。)

帶進去交換下枚舉順序可得:
\[ \begin{align} fwt(C)_i&=\sum_{j=0}^{n}(-1)^{cnt(i\&j)}C_j\&=\sum_{j=0}^{n}(-1)^{cnt(i\&j)}\sum_{a\oplus b=j}A_aB_b\&=\sum_{a=0}^{n}A_a\sum_{b=0}^{n}B_b(-1)^{cnt(i\&(a\oplus b))} \end{align} \]
我們考慮下指數上的那一塊東西:\(cnt(i\&(a\oplus b))\),分情況討論下這個與\(cnt(i\&a)+cnt(i\&b)\)的關系:(由於多位和一位沒有區別,這裏只討論一位)

\(i\)\(0\),顯然這一位不計入答案,不管。

\(a,b\)都為\(1\)的話,\(a\oplus b=0\),不計入答案,但是註意到這裏是\((-1)\)的指數,其實\((-1)^0=(-1)^2\),不妨看做是\(2\),那麽這兩個相等。

\(a,b\)有一個為\(1\),前後顯然相等,都為\(1\)

\(a,b\)都為\(0\),顯然也相等,都為\(0\)

所以式子可以改寫成這樣:
\[ \begin{align} fwt(C)_i&=\sum_{a=0}^{n}A_a\sum_{b=0}^{n}B_b(-1)^{cnt(i\&a)+cnt(i\&b)}\&=\sum_{a=0}^{n}(-1)^{cnt(i\&a)}A_a\sum_{b=0}^{n}(-1)^{cnt(i\&b)}B_b\&=fwt(A)_i\cdot fwt(B)_i \end{align} \]
所以,證畢。

那麽,快速做這個的式子:
\[ fwt(A)= \begin{cases} (fwt(A_1)+fwt(A_2),fwt(A_1)-fwt(A_2))&n>0\A_0&n=0 \end{cases} \]
具體的,考慮前一半的時候,最高位為\(0\),直接加起來就好了。

對於後一半,最高位為\(1\),如果選的數最高位也為\(1\)\(cnt\)就多了\(1\),也就是整體多乘了個\(-1\),所以就是\(fwt(A_1)-fwt(A_2)\)

代碼:

void fwt_xor(int *r) {
    for(int i=1;i<n;i<<=1)
        for(int j=0;j<n;j+=(i<<1))
            for(int k=0;k<i;k++) {
                int x=r[j+k],y=r[i+j+k];
                r[j+k]=(x+y)%mod,r[i+j+k]=(x-y)%mod;
            }
}

逆沃爾什變換

知道了上面的,這玩意其實就很簡單了。

對於按位或,就是知道了\(fwt(A_1)\)\(fwt(A_1)+fwt(A_2)\),求出兩個分別是多少,直接減一下就完了:
\[ ifwt(A)=(ifwt(A_1),ifwt(A_2)-ifwt(A_1)) \]
對於按位與,也差不多:
\[ ifwt(A)=(ifwt(A_1)-ifwt(A_2),ifwt(A_2)) \]
對於異或,是知道\(fwt(A_1)+fwt(A_2)\)\(fwt(A_1)-fwt(A_2)\),那麽加起來除以\(2\)就是第一個,減一下除以\(2\)就是第二個,即:
\[ ifwt(A)=(\frac{ifwt(A_1)+ifwt(A_2)}{2},\frac{ifwt(A_1)-ifwt(A_2)}{2}) \]

模板

給一個模板大全吧,題目來自luogu4717。

#include<bits/stdc++.h>
using namespace std;
 
void read(int &x) {
    x=0;int f=1;char ch=getchar();
    for(;!isdigit(ch);ch=getchar()) if(ch=='-') f=-f;
    for(;isdigit(ch);ch=getchar()) x=x*10+ch-'0';x*=f;
}
 
void print(int x) {
    if(x<0) putchar('-'),x=-x;
    if(!x) return ;print(x/10),putchar(x%10+48);
}
void write(int x) {if(!x) putchar('0');else print(x);putchar('\n');}

const int maxn = 2e5+10;
const int mod = 998244353;
const int inv2 = 499122177;

int bit,n,a[maxn],b[maxn],c[maxn],ina[maxn],inb[maxn];

void fwt_or(int *r,int op) {
    for(int i=1;i<n;i<<=1)
        for(int j=0;j<n;j+=(i<<1))
            for(int k=0;k<i;k++)
                if(op==1) r[i+j+k]=(r[i+j+k]+r[j+k])%mod;
                else r[i+j+k]=(r[i+j+k]-r[j+k])%mod;
}

void fwt_and(int *r,int op) {
    for(int i=1;i<n;i<<=1)
        for(int j=0;j<n;j+=(i<<1))
            for(int k=0;k<i;k++)
                if(op==1) r[j+k]=(r[i+j+k]+r[j+k])%mod;
                else r[j+k]=(r[j+k]-r[i+j+k])%mod;
}

void fwt_xor(int *r,int op) {
    for(int i=1;i<n;i<<=1)
        for(int j=0;j<n;j+=(i<<1))
            for(int k=0;k<i;k++) {
                int x=r[j+k],y=r[i+j+k];
                if(op==1) r[j+k]=(x+y)%mod,r[i+j+k]=(x-y)%mod;
                else r[j+k]=1ll*(x+y)*inv2%mod,r[i+j+k]=1ll*(x-y)*inv2%mod;
            }
}

int main() {
    read(bit);n=1<<bit;
    for(int i=0;i<n;i++) read(ina[i]);
    for(int i=0;i<n;i++) read(inb[i]);
    // or
    memcpy(a,ina,sizeof ina);memcpy(b,inb,sizeof inb);
    fwt_or(a,1),fwt_or(b,1);for(int i=0;i<n;i++) a[i]=1ll*a[i]*b[i]%mod;
    fwt_or(a,-1);for(int i=0;i<n;i++) printf("%d ",(a[i]+mod)%mod);puts("");
    // and 
    memcpy(a,ina,sizeof ina);memcpy(b,inb,sizeof inb);
    fwt_and(a,1),fwt_and(b,1);for(int i=0;i<n;i++) a[i]=1ll*a[i]*b[i]%mod;
    fwt_and(a,-1);for(int i=0;i<n;i++) printf("%d ",(a[i]+mod)%mod);puts("");
    // xor
    memcpy(a,ina,sizeof ina);memcpy(b,inb,sizeof inb);
    fwt_xor(a,1),fwt_xor(b,1);for(int i=0;i<n;i++) a[i]=1ll*a[i]*b[i]%mod;
    fwt_xor(a,-1);for(int i=0;i<n;i++) printf("%d ",(a[i]+mod)%mod);puts("");
    return 0;
}

淺談快速沃爾什變換