1. 程式人生 > >【模板篇】NTT和三模數NTT

【模板篇】NTT和三模數NTT

就會 CP 可能 復數 inline gin Go algo +=

之前寫過FFT的筆記. 我們知道FFT是在復數域上進行的變換.
而且經過數學家的證明, DFT是復數域上唯一滿足循環卷積性質的變換.

而我們在OI中, 經常遇到對xxxx取模的題目, 這就啟發我們可不可以在模運算的意義下找一個這樣的變換.
然後我們發現有個神奇的東西, 原根\(g\), 這東西在模意義下相當於單位復根\(-e^{\frac{2\pi i}{n}}\).

所以我們預處理一下\(g\)的冪和逆元, 然後改一下fft的代碼就出現了快速數論變換ntt
懶得寫了 直接上代碼:

void getwn(){ //預處理原根的冪和逆元
    int x=qpow(3,p-2);
    for(int
i=0;i<20;++i){ wn[i]=qpow(3,(p-1)/(1<<i)); inv[i]=qpow(x,(p-1)/(1<<i)); } } void ntt(int *y,bool f){ rev(y); //翻轉代碼和fft無異 for(int m=2,id=1;m<=n;m<<=1,++id){ //id用來記錄轉到第幾下了 for(int k=0;k<n;k+=m){ int w=1,wm=f?wn[id]:inv[id]; //如果是dft就用冪, idft就用冪的逆元
for(int j=0;j<m/2;++j){ //這裏跟fft一樣, 不過要對p取模 int u=y[k+j]%p,t=1ll*w*y[k+j+m/2]%p; y[k+j]=u+t; if(y[k+j]>p) y[k+j]-=p; y[k+j+m/2]=u-t; if(y[k+j+m/2]<0) y[k+j+m/2]+=p; w=1ll*w*wm%p; } } } if
(!f){ int x=qpow(n,p-2); for(int i=0;i<n;++i) y[i]=1ll*y[i]*x%p; } }

好像差不多呢~ 不過這樣就要求我們找一個原根好求的數. 比如著名的uoj數: 998244353 還有1004535809和469762049等, 這三個數原根都是3~
好像因為當時一看到模數不是1e9+7一般就會想到ntt, vfk為了防止這一點, 模數統一采用998244353, 現在看看收效不錯.

不過 有些喪心病狂的人就是要用1e9+7作為ntt的模數, 甚至還出現了可以不模質數的情況!
那我們怎麽解決任意模數ntt呢? 我們可以采用拆系數ntt或者三模數ntt. 這裏介紹一下三模數ntt.
對於一般的數據範圍, \(n\leq10^5, a_i\leq10^9\), 這樣可能會到10^5*10^{9^2}=10^{23}級別.
所以我們可以找三個乘積\(>10^{23}\)的ntt-friendly的數, 然後分別ntt再想辦法合並.
我們假如答案是ans, 那我們做三次ntt後就能得到如下三個柿子.
\[ \left\{\begin{matrix} ans\equiv a_1(\mod m_1)\\ ans\equiv a_2(\mod m_2)\\ ans\equiv a_3(\mod m_3) \end{matrix}\right. \]
我們把前兩個柿子通過中國剩余定理合並, 就可以得到
\[ \left\{\begin{matrix} ans\equiv A(\mod M)\\ ans\equiv a_3(\mod m_3) \end{matrix}\right. \]
其中, \(M=m_1*m_2\)
這樣我們設\(ans=kM+A\),
\[ kM+A\equiv a_3(\mod m_3) \k=(a_3-A)*M^{-1} (\mod m_3) \]
這樣我們求出\(k\)然後代回到\(ans=kM+A\)就可以求對任意模數取模的結果了.

中國剩余定理合並的時候直接乘是可以爆long long的, 所以我們要用到\(O(1)\)快速乘~

下面上一波代碼: luogu4245 【模板】MTT
哎呀覺得自己碼風有點醜啊qwq

#include <cstdio>
#include <cstring>
#include <algorithm>
typedef long long LL;
const int N=600020,p0=469762049,p1=998244353,p2=1004535809;
const LL M=1ll*p0*p1;
int wn[20],nw[20],rev[N],n,lg,p;
int qpow(int a,int b,int p,int s=1){
    for(;b;b>>=1,a=1ll*a*a%p)
        if(b&1) s=1ll*s*a%p;
    return s;
}
LL mul(LL a,LL b,LL p){ a%=p; b%=p;
    return (a*b-(LL)((long double)a*b/p)*p+p)%p;
}
void calcw(int p){
    int x=qpow(3,p-2,p);
    for(int i=0;i<20;++i){
        wn[i]=qpow(3,(p-1)/(1<<i),p);
        nw[i]=qpow(x,(p-1)/(1<<i),p);
    }
}
void init(){
    for(int i=0;i<n;++i)
        rev[i]=(rev[i>>1]>>1)|((i&1)<<lg);
}
void ntt(int *y,bool f,int p){ calcw(p);
    for(int i=0;i<n;++i) if(i<rev[i]) std::swap(y[i],y[rev[i]]);
    for(int m=2,id=1;m<=n;m<<=1,++id){
        for(int k=0;k<n;k+=m){
            int w=1,wm=f?wn[id]:nw[id];
            for(int j=0;j<m>>1;++j){
                int &a=y[k+j]; int &b=y[k+j+m/2];
                int u=a%p,t=1ll*w*b%p;
                a=u+t; if(a>p) a-=p;
                b=u-t; if(b<0) b+=p;
                w=1ll*w*wm%p;
            }
        }
    } int x=qpow(n,p-2,p);
    if(!f) for(int i=0;i<n;++i) y[i]=1ll*y[i]*x%p;
}
char c1[N],c2[N]; int a[N],b[N],c[N],d[N],ans[3][N];
int main(){
    int l1,l2; scanf("%d%d%d",&l1,&l2,&p);
    for(int i=0;i<=l1;++i) scanf("%d",&a[i]),a[i]%=p;
    for(int i=0;i<=l2;++i) scanf("%d",&b[i]),b[i]%=p;
    for(n=1;n<l1||n<l2;n<<=1,++lg); n<<=1; init();
    std::copy(a,a+n,c); std::copy(b,b+n,d);
    ntt(c,1,p0); ntt(d,1,p0);
    for(int i=0;i<n;++i) ans[0][i]=1ll*c[i]*d[i]%p0;
    std::copy(a,a+n,c); std::copy(b,b+n,d);
    ntt(c,1,p1); ntt(d,1,p1);
    for(int i=0;i<n;++i) ans[1][i]=1ll*c[i]*d[i]%p1;
    std::copy(a,a+n,c); std::copy(b,b+n,d);
    ntt(c,1,p2); ntt(d,1,p2);   
    for(int i=0;i<n;++i) ans[2][i]=1ll*c[i]*d[i]%p2;
    ntt(ans[0],0,p0); ntt(ans[1],0,p1); ntt(ans[2],0,p2);
    for(int i=0;i<n;++i){
        LL A=mul(1ll*ans[0][i]*p1%M,qpow(p1%p0,p0-2,p0),M)
            +mul(1ll*ans[1][i]*p0%M,qpow(p0%p1,p1-2,p1),M);
        if(A>M) A-=M;
        LL k=((ans[2][i]-A)%p2+p2)%p2*qpow(M%p2,p2-2,p2)%p2;
        a[i]=1ll*(k%p)*(M%p)%p+A%p;
        if(a[i]>p) a[i]-=p;
    }
    for(int i=0;i<=l1+l2;++i) printf("%d ",a[i]);
}

【模板篇】NTT和三模數NTT