1. 程式人生 > >【LOJ】#2320. 「清華集訓 2017」生成樹計數

【LOJ】#2320. 「清華集訓 2017」生成樹計數

rac res 然而 除了 加法 wap OS 代碼 reg

題解

我,理解題解,用了一天
我,卡常數,又用了一天
到了最後,我才發現,我有個加法取模,寫的是while(c >= MOD) c -= MOD
我把while改成if,時間,少了
六倍。
六倍。
六倍!!!!

maya我又用第一次T的代碼改掉了while,我第一次T的代碼也A了= =

那我,改單位復根,FFT循環展開,分治內部循環展開,為了啥= =

好吧,但是我最後上榜了。。。LOJ第四的樣子。。

\(\prod_{i = 1}^{N} d_{i}^{m}\sum_{i = 1}^{N}d_{i}^{m}\)
每個聯通塊的連出去的邊,有\(a_{i}\)種可能,所以式子是這樣的
\(\prod_{i = 1}^{N} a_{i}^{d_{i}}d_{i}^{m}\sum_{i = 1}^{N}d_{i}^{m}\)


\(\sum_{i = 1}^{N} d_{i}^{m}\prod_{j = 1}^{N}a_{j}^{d_{j}}d_{j}^{m}\)
我們考慮一下一個最暴力的dp(我的第一反應,啥,這怎麽是dp?

當然是dp出一個prufer序列啦

\(f[i][j]\)表示考慮了前i個數,填了j個格子,沒有多乘任何一個\(d_{i}^{m}\)
\(g[i][j]\)表示考慮了前i個數,填了j個格子,已經乘了一個\(d_{i}^{m}\)

dp的時候,我們只需要枚舉下一個數填了多少個格子就好
復雜度\(O(n^{3})\)

但是顯然過不掉

我們再考慮一個……套路!因為m出奇的小?
乘方轉斯特林數
想一下\(x^{m}\)

的組合意義,相當於m種顏色塗在x個格子裏,每個種顏色只能塗一個格子,但是每個格子可以塗很多顏色

再看一下原式子,把它變成這樣的形式
\(\sum_{i = 1}^{N} a_{i}^{d_{i}}d_{i}^{2m}\prod_{j = 1,j != i}^{N}a_{j}^{d_{j}}d_{j}^{m}\)
然後,我們把乘方拆開,怎麽拆呢
考慮到上述的組合意義
我們也就是求所有prufer序列的染色方式,在每一次決策的時候將某個點用2m種顏色染,我們在序列後面(腦補)出1 - N,這樣prufer序列裏數字出現的個數就正好是點度

設已經被染色的格子是j,那麽再放入一個新的數,染色個數為k的時候,需要乘上
\(\binom{n - 2 - j}{k}((S(m,k)k! + S(m,k + 1)(k + 1)!)\)


為什麽還有k + 1,因為k代表的是這n - 2個prufer序裏面的這類點染色的個數,我們還有後面腦補出的1 - N的位置,所以要+1
同時還要乘上\(a^{k}\)因為同一個位置的點有a種選法

如果是2m種顏色染色,把m換成2m就可以

我們觀察一下這個組合數,我們會發現我們可以把\(\frac{1}{k!}\)分離出來,每次統計貢獻的時候加上,剩下的等全部算完答案後,如果染色的個數為\(j\),那麽答案再乘上\(\frac{(n - 2)!}{(n - 2 - j)!}\)
同時,我們除了染色的位置還有很多空位,每個空位都有\(\sum_{i = 1}^{n}a_{i}\)種情況,如果有\(j\)個位置被染色了,那麽答案還要再乘上
\((\sum_{i = 1}^{n}a_{i})^{n - 2 - j}\)
這樣的復雜度是\(n^2m\)

然而這個方程本質是個卷積,可以上分治FFT

他說不卡常……確實不卡常,我的錯誤神奇得太離譜了= =

正常模樣的代碼

#include <iostream>
#include <cstdio>
#include <vector>
#include <algorithm>
#include <cmath>
#include <cstring>
#include <map>
//#define ivorysi
#define pb push_back
#define space putchar(‘ ‘)
#define enter putchar(‘\n‘)
#define mp make_pair
#define pb push_back
#define fi first
#define se second
#define mo 974711
#define MAXN 30005
using namespace std;
typedef long long int64;
typedef double db;
template<class T>
void read(T &res) {
    res = 0;char c = getchar();T f = 1;
    while(c < ‘0‘ || c > ‘9‘) {
    if(c == ‘-‘) f = -1;
    c = getchar();
    }
    while(c >= ‘0‘ && c <= ‘9‘) {
    res = res * 10 + c - ‘0‘;
    c = getchar();
    }
    res *= f;
}
template<class T>
void out(T x) {
    if(x < 0) {putchar(‘-‘);x = -x;}
    if(x >= 10) {
    out(x / 10);
    }
    putchar(‘0‘ + x % 10);
}
const int MOD = 998244353,G = 3,L = (1 << 20);
int fac[MAXN],S[75][75],invfac[MAXN],N,M,a[MAXN],P[MAXN][75],F1[MAXN][75],F2[MAXN][75];
int f[2][MAXN],g[2][MAXN],W[L + 5];
int mul(int a,int b) {return 1LL * a * b % MOD;}
int inc(int a,int b) {a = a + b;if(a >= MOD) a -= MOD;return a;}
void update(int &x,int y) {x = inc(x,y);}
int fpow(int x,int c) {
    int res = 1,t = x;
    while(c) {
        if(c & 1) res = mul(res,t);
        t = mul(t,t);
        c >>= 1;
    }
    return res;
}
struct poly {
    vector<int> a;
    poly() {a.clear();}
    friend void NTT(poly &f,int T,int on) {
        f.a.resize(T);
        for(int i = 1 , j = T / 2; i < T - 1; ++i) {
            if(i < j) swap(f.a[i],f.a[j]);
            int k = T / 2;
            while(j >= k) {j -= k;k >>= 1;}
            j += k;
        }
        for(int h = 2 ; h <= T ; h <<= 1) {
            int wn = W[(L + on * L / h) % L];
            for(int k = 0 ; k < T ; k += h) {
                int w = 1;
                for(int j = k ; j < k + h / 2 ; ++j) {
                    int u = f.a[j],t = mul(f.a[j + h / 2],w);
                    f.a[j] = inc(u,t);
                    f.a[j + h / 2] = inc(u,MOD - t);
                    w = mul(w,wn);
                }
            }
        }
        if(on == -1) {
            int InvT = fpow(T,MOD - 2);
            for(int i = 0 ; i < T ; ++i) f.a[i] = mul(f.a[i],InvT);
        }
    }
    friend poly operator + (const poly &f,const poly &g) {
        int T = max(f.a.size(),g.a.size());
        poly h;h.a.resize(T);
        for(int i = 0 ; i < T ; ++i) h.a[i] = inc(f.a[i],g.a[i]);
        return h;
    }
    friend poly operator * (poly f,poly g) {
        int T = 1,t = f.a.size() + g.a.size();
        while(T <= t) T <<= 1;
        NTT(f,T,1);NTT(g,T,1);
        poly h;h.a.resize(T);
        for(int i = 0 ; i < T ; ++i) h.a[i] = mul(f.a[i],g.a[i]);
        NTT(h,T,-1);
        for(int i = T - 1 ; i >= 0 ; --i) {
            if(!h.a[i]) h.a.pop_back();
            else break;
        }
        if(h.a.size() > N - 1) h.a.resize(N - 1);
        return h;
    }
};
void Init() {
    read(N);read(M);
    int T = max(N,2 * M + 1);
    fac[0] = 1;
    for(int i = 1 ; i <= T ; ++i) fac[i] = mul(fac[i - 1],i);
    invfac[T] = fpow(fac[T],MOD - 2);
    for(int i = T - 1; i >= 0 ; --i) invfac[i] = mul(invfac[i + 1],i + 1);
    S[0][0] = 1;
    for(int i = 1 ; i <= 70 ; ++i) {
        for(int j = 1 ; j <= i ; ++j) {
            S[i][j] = inc(S[i - 1][j - 1],mul(S[i - 1][j],j));
        }
    }
    
    for(int i = 1 ; i <= N ; ++i) read(a[i]);
    for(int i = 1 ; i <= N ; ++i) {
        P[i][0] = 1;
        for(int j = 1 ; j <= 70 ; ++j) {
            P[i][j] = mul(P[i][j - 1],a[i]);
        }
    }
    W[0] = 1,W[1] = fpow(G,(MOD - 1) / L);
    for(int i = 2 ; i < L ; ++i) W[i] = mul(W[i - 1],W[1]);
}
pair<poly,poly> Solve(int L,int R) {
    if(L == R) {
        poly s,t;s.a.resize(M + 1);t.a.resize(2 * M + 1);
        for(int i = 0 ; i <= 2 * M ; ++i) {
            if(i <= M) 
                s.a[i] = mul(mul(inc(mul(S[M][i],fac[i]),mul(S[M][i + 1],fac[i + 1])),invfac[i]),P[L][i]);
            t.a[i] = mul(mul(inc(mul(S[2 * M][i],fac[i]),mul(S[2 * M][i + 1],fac[i + 1])),invfac[i]),P[L][i]);
        }
        return mp(s,t);
    }
    int mid = (L + R) >> 1;
    pair<poly,poly> S = Solve(L,mid),T = Solve(mid + 1,R);
    return mp(S.fi * T.fi,S.se * T.fi + S.fi * T.se);
}
int main() {
#ifdef ivorysi
    freopen("f1.in","r",stdin);
#endif
    Init();
    pair<poly,poly> res = Solve(1,N);
    poly g = res.se;g.a.resize(N - 1);
    int sum = 0,p = 1,ans = 0,t = 1;
    for(int i = 1 ; i <= N ; ++i) sum = inc(sum,a[i]),p = mul(p,a[i]);
    for(int i = N - 2 ; i >= 0; --i) {
        ans = inc(ans,mul(mul(g.a[i],t),invfac[N - 2 - i]));
        t = mul(t,sum);
    }
    ans = mul(ans,mul(fac[N - 2],p));
    out(ans);putchar(\n);
    return 0;
}

卡常之後的代碼

#include <iostream>
#include <cstdio>
#include <vector>
#include <algorithm>
#include <cmath>
#include <cstring>
#include <map>
//#define ivorysi
#define pb push_back
#define space putchar(‘ ‘)
#define enter putchar(‘\n‘)
#define mp make_pair
#define pb push_back
#define fi first
#define se second
#define mo 974711
#define MAXN 400005
#define RG register
using namespace std;
typedef long long int64;
typedef double db;
template<class T>
void read(T &res) {
    res = 0;char c = getchar();T f = 1;
    while(c < ‘0‘ || c > ‘9‘) {
    if(c == ‘-‘) f = -1;
    c = getchar();
    }
    while(c >= ‘0‘ && c <= ‘9‘) {
    res = res * 10 + c - ‘0‘;
    c = getchar();
    }
    res *= f;
}
template<class T>
void out(T x) {
    if(x < 0) {putchar(‘-‘);x = -x;}
    if(x >= 10) {
    out(x / 10);
    }
    putchar(‘0‘ + x % 10);
}
const int MOD = 998244353,L = (1 << 16);
int fac[MAXN],S[105][105],invfac[MAXN],a[MAXN],N,M,P[MAXN][105];
int W[L + 5],IW[L + 5],top,fl[MAXN],fr[MAXN],gl[MAXN],gr[MAXN],f[MAXN],g[MAXN];
vector<int> F[105],G[105];
inline int mul(RG const int &a,RG const int &b) {return (int64)a * b % MOD;}
inline int inc(RG const int &a,RG const int &b) {RG int c = a + b;if(c >= MOD) c -= MOD;return c;}
int fpow(RG int x,RG int c) {
    RG int res = 1,t = x;
    while(c) {
    if(c & 1) res = mul(res,t);
    t = mul(t,t);
    c >>= 1;
    }
    return res;
}
void NTT(RG int *f,RG int T,RG int on,RG int *w) {
    RG int tmp,*wm,*ai,*ami;
    for(RG int i = 1 , j = T / 2; i < T - 1 ; ++i) {
    if(i < j) {tmp = f[i];f[i] = f[j];f[j] = tmp;}
    RG int k = T / 2;
    while(j >= k) {
        j -= k;
        k >>= 1;
    }
    j += k;
    }
    
    #define work(j) {tmp = mul(ami[j],wm[j]);ami[j] = inc(ai[j],MOD - tmp);ai[j] = inc(ai[j],tmp); }
    for(RG int h = 2,m = 1; h <= T ; m = h,h <<= 1) {
    wm = w + m;
    if(m < 8) {
        for(RG int i = 0 ; i < T ; i += h) {
        ai = f + i,ami = f + m + i;
        for(RG int j = 0 ; j < m ; ++j) work(j);
        }
    }
    else {
        for(RG int i = 0 ; i < T ; i += h) {
        ai = f + i,ami = f + m + i;
        for(RG int j = 0 ; j < m ; j += 8) {
            work(j);
            work(j + 1);
            work(j + 2);
            work(j + 3);
            work(j + 4);
            work(j + 5);
            work(j + 6);
            work(j + 7);
        }
        }
    }
    }
    if(on < 0) {
    RG int InvT = fpow(T,MOD - 2);
        #define C(x,y) {f[x] = mul(f[x],y);}
    if(T < 8) {for(RG int i = 0 ; i < T ; ++i) C(i,InvT);}
    else {
        for(RG int i = 0 ; i < T ; i += 8) {
        C(i,InvT);
        C(i + 1,InvT);
        C(i + 2,InvT);
        C(i + 3,InvT);
        C(i + 4,InvT);
        C(i + 5,InvT);
        C(i + 6,InvT);
        C(i + 7,InvT);
        }
    }
    }
}
void Init() {
    read(N);read(M);
    RG int T = max(N,2 * M + 1);
    fac[0] = 1;
    for(RG int i = 1 ; i <= T ; ++i) fac[i] = mul(fac[i - 1],i);
    invfac[T] = fpow(fac[T],MOD - 2);
    for(RG int i = T - 1; i >= 0 ; --i) invfac[i] = mul(invfac[i + 1],i + 1);
    S[0][0] = 1;
    for(RG int i = 1 ; i <= 70 ; ++i) {
    for(RG int j = 1 ; j <= i ; ++j) {
        S[i][j] = inc(S[i - 1][j - 1],mul(S[i - 1][j],j));
    }
    }
    for(RG int j = 1 ; j <= 2 * M ; ++j) {
        S[M][j] = mul(S[M][j],fac[j]);
        S[2 * M][j] = mul(S[2 * M][j],fac[j]); 
    }
    for(RG int j = 0 ; j <= 2 * M ; ++j) {
        S[M][j] = mul(inc(S[M][j],S[M][j + 1]),invfac[j]);
        S[2 * M][j] = mul(inc(S[2 * M][j],S[2 * M][j + 1]),invfac[j]);
    }
    for(RG int i = 1 ; i <= N ; ++i) read(a[i]);
    for(RG int i = 1 ; i <= N ; ++i) {
    P[i][0] = 1;
    for(RG int j = 1 ; j <= 70 ; ++j) {
        P[i][j] = mul(P[i][j - 1],a[i]);
    }
    }
    RG int half = L / 2;
    RG int t1 = fpow(3,(MOD - 1) / L),t2 = fpow(t1,MOD - 2);
    W[half] = 1;IW[half] = 1;
    for(RG int i = 1 ; i < half; ++i) W[i + half] = mul(W[i + half - 1],t1),IW[i + half] = mul(IW[i + half - 1],t2);
    for(RG int i = half - 1 ; i >= 0 ; --i) W[i] = W[i << 1],IW[i] = IW[i << 1];
}
void Solve(RG int L,RG int R) {
    
    if(L == R) {
    ++top;
    F[top].resize(M + 1);G[top].resize(2 * M + 1);
    for(RG int i = 0 ; i <= 2 * M ; ++i) {
        if(i <= M) 
        F[top][i] = mul(S[M][i],P[L][i]);
        G[top][i] = mul(S[2 * M][i],P[L][i]);
    }
    return ;
    }
    RG int mid = (L + R) >> 1;
    Solve(L,mid);int Ld = top; 
    Solve(mid + 1,R);int Rd = top;
    top -= 2;
    RG int s1 = F[Ld].size(),s2 = F[Rd].size(),s3 = G[Ld].size(),s4 = G[Rd].size();
    RG int t = max(max(s1 + s2,s1 + s4),s3 + s2);
    RG int K = 1;while(K <= t) K <<= 1;
    for(int i = 0 ; i < s1 ; ++i) fl[i] = F[Ld][i];
    for(int i = 0 ; i < s2 ; ++i) fr[i] = F[Rd][i];
    for(int i = 0 ; i < s3 ; ++i) gl[i] = G[Ld][i];
    for(int i = 0 ; i < s4 ; ++i) gr[i] = G[Rd][i];
    fill(fl + s1,fl + K,0);
    fill(fr + s2,fr + K,0);
    fill(gl + s3,gl + K,0);
    fill(gr + s4,gr + K,0);
    NTT(fl,K,1,W);NTT(fr,K,1,W);NTT(gl,K,1,W);NTT(gr,K,1,W);
#define Calc1(i) {f[i] = mul(fl[i],fr[i]);}
#define Calc2(i) {g[i] = inc(mul(fl[i],gr[i]),mul(gl[i],fr[i]));}
    if(K < 8) {
    for(int i = 0 ; i < K ; ++i) {
        Calc1(i);Calc2(i);
    }
    }
    else {
    for(RG int i = 0 ; i < K ; i += 8) {
        Calc1(i);Calc1(i + 1);
        Calc1(i + 2);Calc1(i + 3);
        Calc1(i + 4);Calc1(i + 5);
        Calc1(i + 6);Calc1(i + 7);
        Calc2(i);Calc2(i + 1);
        Calc2(i + 2);Calc2(i + 3);
        Calc2(i + 4);Calc2(i + 5);
        Calc2(i + 6);Calc2(i + 7);
    }
    }
    NTT(f,K,-1,IW);NTT(g,K,-1,IW);
    ++top;
    F[top].clear();G[top].clear();
    t = min(N - 2,K - 1);
    while(t >= 0) {
    if(!f[t]) --t;
    else break;
    }
    for(RG int i = 0 ; i <= t ; ++i) F[top].pb(f[i]);
    t = min(N - 2,K - 1);
    while(t >= 0) {
    if(!g[t]) --t;
    else break;
    }
    for(RG int i = 0 ; i <= t ; ++i) G[top].pb(g[i]);
}
int main() {
#ifdef ivorysi
    freopen("f1.in","r",stdin);
#endif
    Init();
    Solve(1,N);
    G[top].resize(N - 1);
    RG int ans = 0,sum = 0,p = 1,t = 1;
    for(RG int i = 1 ; i <= N ; ++i) sum = inc(sum,a[i]),p = mul(p,a[i]);
    for(RG int i = N - 2 ; i >= 0 ; --i) {
    ans = inc(ans,mul(mul(G[top][i],t),invfac[N - 2 - i]));
    t = mul(t,sum);
    }
    ans = mul(mul(ans,p),fac[N - 2]);
    out(ans);enter;
    //out(clock());enter;
    return 0;
}

【LOJ】#2320. 「清華集訓 2017」生成樹計數