1. 程式人生 > >bzoj 3992 [SDOI2015]序列統計——NTT(迴圈卷積&&快速冪)

bzoj 3992 [SDOI2015]序列統計——NTT(迴圈卷積&&快速冪)

題目:https://www.lydsy.com/JudgeOnline/problem.php?id=3992

有轉移次數、模M餘數、方案數三個值,一看就是係數的地方放一個值、指數的地方放一個值、做卷積的次數表示一個值(應該是表示轉移次數)。

可以餘數和方案數都要求相乘,指數只能相加,怎麼辦?

然後看題解,原來可以用M的原根的冪來表示餘數那個資訊!因為原根的幾次冪和%M剩餘類可以一一對應(除了%M==0!!!),所以用原根的冪表示%M餘幾,兩個餘數相乘就變成原根的指數相加了!把該餘數對應的原根的指數放在多項式指數的位置,就可以NTT啦!

原根是 1~P-1 次冪的值%P各不相同的,所以可以用 0次項~M-2次項 或者 1次項~M-1 次項來表示。

n的範圍要求快速冪。但不是把點值拿出來之後對點值快速冪一番再用點值還原回係數,因為每次卷積那個多項式的長度都要翻倍,所以最後n個點的個數就不夠了。

所以要快速冪中每次卷積了一下後把它翻倍的長度手動迴圈一番變回原長M。這樣就行啦!

注意資料範圍!!!求的那個 x 不能為0,而給出的元素可以為0!而原根的那些冪都不會為0!(仔細一想,只有0或M的倍數才會%M==0)考慮到那個 x 不會為0、而數列中放入一個0的話值就變成0了,所以給出0以後要認為沒有那個元素!!!!!

快速冪時,ans的初值應該像1一樣;也就是一個多項式卷積它之後還是該多項式本身。想一想,就是在0次項賦1、其他項賦0即可。

發現>(M<<1)的項的值一定是0;所以迴圈的時候可以直接減掉1個(M-1)而不用模什麼的。

#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#define ll long long
using namespace std;
const int N=8005; const ll mod=1004535809;
int n,m,M,pn,s[N],zb[N],pri[N],len,r[N<<2];
int a[N<<2],ans[N<<2];
int rdn()
{
  int ret=0;bool fx=1
;char ch=getchar(); while(ch>'9'||ch<'0'){if(ch=='-')fx=0;ch=getchar();} while(ch>='0'&&ch<='9') ret=(ret<<3)+(ret<<1)+ch-'0',ch=getchar(); return fx?ret:-ret; } void upd(int &x,int md){x>=md?x-=md:0;} int pw(int x,int k,int md) {int ret=1;while(k){if(k&1)ret=(ll)ret*x%md;x=(ll)x*x%md;k>>=1;}return ret;} int gtrt() { int k=M-1,tot=0; for(int i=2;i*i<=k;i++) if(k%i==0){pri[++tot]=i;while(k%i==0)k/=i;} if(k>1)pri[++tot]=k; for(int g=2;;g++) { int i; for(i=1;i<=tot;i++) if(pw(g,(M-1)/pri[i],M)==1)break; if(i>tot)return g; } } void ntt(int *a,bool fx) { for(int i=0;i<len;i++) if(i<r[i])swap(a[i],a[r[i]]); for(int R=2;R<=len;R<<=1) { int m=R>>1; int Wn=pw(3,(mod-1)/R,mod); fx?Wn=pw(Wn,mod-2,mod):0; for(int i=0;i<len;i+=R) for(int j=0,w=1;j<m;j++,w=(ll)w*Wn%mod) { int tmp=(ll)w*a[i+m+j]%mod; a[i+m+j]=a[i+j]+mod-tmp; upd(a[i+m+j],mod); a[i+j]=a[i+j]+tmp; upd(a[i+j],mod); } } if(!fx)return; int inv=pw(len,mod-2,mod); for(int i=0;i<len;i++)a[i]=(ll)a[i]*inv%mod; } int main() { n=rdn(); M=rdn(); pn=rdn(); m=rdn(); for(int i=1;i<=m;i++)s[i]=rdn(); int rt=gtrt(); for(int i=1,k=rt;i<M;i++,k=k*rt%M) zb[k]=i; len=1; for(;len<=M<<1;len<<=1); for(int i=0;i<len;i++)r[i]=(r[i>>1]>>1)+((i&1)?len>>1:0); for(int i=1;i<=m;i++)if(s[i])a[zb[s[i]]]=1;////if ans[0]=1;/// while(n) { ntt(a,0); if(n&1) { ntt(ans,0); for(int i=0;i<len;i++)ans[i]=(ll)ans[i]*a[i]%mod; ntt(ans,1); for(int i=1;i<M;i++)//pos which >(M<<1) surely have no value ans[i]+=ans[i+M-1],ans[i+M-1]=0,upd(ans[i],mod); } for(int i=0;i<len;i++)a[i]=(ll)a[i]*a[i]%mod; ntt(a,1); for(int i=1;i<M;i++) a[i]+=a[i+M-1],a[i+M-1]=0,upd(a[i],mod); n>>=1; } printf("%d\n",ans[zb[pn]]); return 0; }