1. 程式人生 > >洛谷 4245 【模板】任意模數NTT

洛谷 4245 【模板】任意模數NTT

題目:https://www.luogu.org/problemnew/show/P4245

大概是用3個模數分別做一遍,用中國剩餘定理合併。

前兩個合併起來變成一個 long long 的模數,再要和第三個合併的話就爆 long long ,所以可以用一種讓兩個模數的乘積不出現的方法:https://blog.csdn.net/qq_35950004/article/details/79477797

 x*m1+a1 = -y*m2 + a2  <==>  x*m1+y*m2 = a2-a1  <==>  x*m1 = a2-a1 (mod m2)  <==> x=(a2-a1)*m1^{-1} (mod m2)

然後根據該部落格裡的證明,在mod m2意義下算出來的 x 就是真的 x 。這樣的話答案就是 x*m1+a1 ,可以在快速乘的過程中對題目中給的模數取模,就不會爆 long long 啦。

注意輸入的 a[ ] 和 b[ ] 不能 ntt( ,0, ) 之後再 ntt( ,1, ) 回來,因為值已經模了剛才那個模數了;所以要多開一些陣列。

注意輸入進 mul 裡的 a 和 b 應該是正的,不然沒法 b>>=1 之類的。

#include<iostream>
#include<cstdio>
#include<cstring>
#include
<algorithm> #define ll long long using namespace std; const int N=1e5+5; int m[3]={998244353,1004535809,469762049}; int n0,n1,mod,len,r[N<<2],a[3][N<<2],b[3][N<<2],c[3][N<<2]; ll M=(ll)m[0]*m[1],d[N<<1]; int rdn() { int ret=0;bool fx=1;char ch=getchar(); while(ch>'
9'||ch<'0'){if(ch=='-')fx=0;ch=getchar();} while(ch>='0'&&ch<='9') ret=ret*10+ch-'0',ch=getchar(); return fx?ret:-ret; } void upd(ll &x,ll md){x>=md?x-=md:0;} void upd(int &x,ll md){x>=md?x-=md:0;} ll mul(ll a,ll b,ll md) { a%=md; b%=md;// ll ret=0;while(b){if(b&1ll)ret+=a,upd(ret,md);a+=a;upd(a,md);b>>=1ll;}return ret; } ll pw(ll x,ll k,ll md) {ll ret=1;while(k){if(k&1ll)ret=mul(ret,x,md);x=mul(x,x,md);k>>=1ll;}return ret;} void ntt(int *a,bool fx,int md) { for(int i=0;i<len;i++) if(i<r[i])swap(a[i],a[r[i]]); for(int R=2;R<=len;R<<=1) { int m=R>>1; int Wn=pw(3,(md-1)/R,md); fx?Wn=pw(Wn,md-2,md):0; for(int i=0;i<len;i+=R) for(int j=0,w=1;j<m;j++,w=(ll)w*Wn%md) { int tmp=(ll)w*a[i+m+j]%md; a[i+m+j]=a[i+j]+md-tmp; upd(a[i+m+j],md); a[i+j]=a[i+j]+tmp; upd(a[i+j],md); } } if(!fx)return; int inv=pw(len,md-2,md); for(int i=0;i<len;i++) a[i]=(ll)a[i]*inv%md; } int main() { n0=rdn()+1; n1=rdn()+1; mod=rdn(); for(int i=0;i<n0;i++)a[0][i]=a[1][i]=a[2][i]=rdn(); for(int i=0;i<n1;i++)b[0][i]=b[1][i]=b[2][i]=rdn(); for(len=1;len<=n0+n1;len<<=1); for(int i=0;i<len;i++)r[i]=(r[i>>1]>>1)+((i&1)?len>>1:0); for(int i=0;i<3;i++)//don't ntt(a,1,m[i]) for it can't return(already mod) { ntt(a[i],0,m[i]); ntt(b[i],0,m[i]); for(int j=0;j<len;j++)c[i][j]=(ll)a[i][j]*b[i][j]%m[i]; ntt(c[i],1,m[i]); } ll inv=pw(m[0],m[1]-2,m[1]),t; for(int i=0,lm=n0+n1-1;i<lm;i++) { t=mul((c[1][i]-c[0][i])%m[1]+m[1],inv,m[1]); d[i]=(mul(t,m[0],M)+c[0][i])%M; } inv=pw(M,m[2]-2,m[2]); for(int i=0,lm=n0+n1-1;i<lm;i++) { t=mul((c[2][i]-d[i])%m[2]+m[2],inv,m[2]); d[i]=(mul(t,M,mod)+d[i])%mod; printf("%lld ",d[i]); } puts(""); return 0; }