1. 程式人生 > >BZOJ.3992.[SDOI2015]序列統計(DP NTT 原根)

BZOJ.3992.[SDOI2015]序列統計(DP NTT 原根)

題目連結

\(Description\)

給定\(n,m,x\)和集合\(S\)。求\(\prod_{i=1}^na_i\equiv x\ (mod\ m)\)的方案數。其中\(a_i\in S\)
\(n\leq10^9,3\leq m\leq 8000且m是質數,1\leq x\leq m-1\)

\(Solution\)

\(f_{i,j}\)表示當前選了\(i\)個數,乘積模\(m\)\(j\)的方案數,\(g_i=[i\in S]\)
轉移就是,\[f_{i,a*b\%m}=\sum f_{i-1,a}*g_b\]
每次轉移是一樣的,所以可以快速冪計算,即\(f_{2i,a*b\%m}=\sum f_{i,a}*f_{i,b}\)


雖然把\(n\)優化到了\(\log n\),但這樣轉移複雜度還是\(O(m^2)\)的。

我們發現,只要能把\(a*b\%m\)寫成\((a+b)\%m\),就是一個迴圈卷積的形式了。
把乘法變成加法可以想到取對數,同樣在模意義下可以用離散對數

\(m\)的一個原根\(g\)\(g,g^2,...,g^{m-1}\)在模\(m\)意義下互不相同。所以我們可以用滿足\(g^A\equiv a\ (mod\ m)\)的正整數\(A\)來替換掉\(a\)(即\(a\)\(m\)的指標\(I(a)\)),它是唯一的。
那麼\(a*b\equiv g^A*g^B\equiv g^{A+B}\equiv g^{(A+B)\%\varphi(m)}\ (mod\ m)\)


所以轉移就成了:\[f_{i,(A+B)\%(m-1)}=\sum f_{i-1,A}*g_B\]

\(g_{A+B}=\sum f_{i-1,A}*g_B\),那麼\(f_{i,j\%(m-1)}=g_j+g_{j+m-1}\)
可以用\(NTT\)優化。

同樣,每一次的轉移還是一樣的,依舊可以用多項式快速冪。
複雜度\(O(m\log m\log n)\)

當然取\(g^0,g^1,...,g^{m-2}\)也行,因為多項式下標以\(0\)開始方便些。
這樣如果集合裡有\(0\)要特判忽略掉它。
因為是迴圈卷積所以每次快速冪乘的時候都需要把\(f\)求出來,也就是一定要再\(NTT\)

回係數表示,不能一直用點值表示做。(是這樣吧?)

//1876kb    3684ms(怎麼這麼慢啊...)
#include <cstdio>
#include <cctype>
#include <cstring>
#include <algorithm>
#define gc() getchar()
#define G 3
#define invG 334845270
#define mod 1004535809
#define Mod(x) x>=mod&&(x-=mod)
#define Add(x,v) (x+=v)>=mod&&(x-=mod)
#define Mul(x,y) (1ll*x*y%mod)
typedef long long LL;
const int N=32500;

int inv,I[N],x[N],t[N],rev[N];

inline int read()
{
    int now=0;register char c=gc();
    for(;!isdigit(c);c=gc());
    for(;isdigit(c);now=now*10+c-'0',c=gc());
    return now;
}
inline int FP(int x,int k,int p)
{
    int t=1;
    for(; k; k>>=1,x=1ll*x*x%p)
        if(k&1) t=1ll*t*x%p;
    return t;
}
int Get_root(int P)
{
    static int p[10005];
    int cnt=0,t=P-1;
    for(int i=2; i*i<=t; ++i)
        if(!(t%i))
        {
            p[++cnt]=i;
            while(!(t%i)) t/=i;
        }
    if(t>1) p[++cnt]=t;
    for(int x=2; ; ++x)
    {
        bool ok=1;
        for(int i=1; i<=cnt; ++i) if(FP(x,(P-1)/p[i],P)==1) {ok=0; break;}
        if(ok) return x;
    }
    return 1;
}
void Pre(int m)
{
    int g=Get_root(m);
    for(int i=0,pw=1; i<m-1; ++i) I[pw]=i, pw=1ll*pw*g%m;//g^i=pw -> I(pw)=i
}
void NTT(int *a,int lim,int opt)
{
    for(int i=1; i<lim; ++i) if(i<rev[i]) std::swap(a[i],a[rev[i]]);
    for(int i=2; i<=lim; i<<=1)
    {
        int mid=i>>1,Wn=FP(~opt?G:invG,(mod-1)/i,mod);
        for(int j=0; j<lim; j+=i)
            for(int k=j,w=1,t; k<j+mid; ++k,w=1ll*w*Wn%mod)
                a[k+mid]=a[k]-(t=1ll*a[k+mid]*w%mod)+mod, Mod(a[k+mid]),
                a[k]+=t, Mod(a[k]);
    }
    if(opt==-1) for(int i=0,inv=::inv; i<lim; ++i) a[i]=1ll*a[i]*inv%mod;
}
void Mult1(int *f,int n,int lim)//x*x可以少一次NTT啊 
{
    NTT(f,lim,1);
    for(int i=0; i<lim; ++i) f[i]=1ll*f[i]*f[i]%mod;
    NTT(f,lim,-1);
    for(int i=0; i<n; ++i) f[i]=f[i]+f[i+n], Mod(f[i]);//f[i+m-1]
    for(int i=n; i<lim; ++i) f[i]=0;//!
}
void Mult2(int *a,int *b,int *res,int n,int lim)
{
    static int f[N],g[N];
    memset(f,0,sizeof f), memset(g,0,sizeof g);
    memcpy(f,a,n<<2), memcpy(g,b,n<<2);
    NTT(f,lim,1), NTT(g,lim,1);
    for(int i=0; i<lim; ++i) f[i]=1ll*f[i]*g[i]%mod;
    NTT(f,lim,-1);
    for(int i=0; i<n; ++i) res[i]=f[i]+f[i+n], Mod(res[i]);//f[i+m-1] //要保證res的n項以外為0 
}

int main()
{
    static int x[N],res[N];
    int n=read(),m=read(),X=read(),S=read();
    Pre(m), --m;//0~m-1

    int lim=1,l=-1;
    while(lim<=m+m-2) lim<<=1,++l;
    for(int i=1; i<lim; ++i) rev[i]=(rev[i>>1]>>1)|((i&1)<<l);
    inv=FP(lim,mod-2,mod);

    for(int i=1,s; i<=S; ++i) s=read(), s&&(++x[I[s]]);//++f[1][s]
    res[I[1]]=1;//f[0][1]=1
    for(int k=n; k; k>>=1,Mult1(x,m,lim))
        if(k&1) Mult2(res,x,res,m,lim);
    printf("%d\n",res[I[X]]);

    return 0;
}