Divide by Zero 2018 and Codeforces Round #474 G. Bandit Blues DP+斯特林數+分治FFT
阿新 • • 發佈:2018-12-13
Description 給你三個正整數 n,a,b,定義A為一個排列中是字首最大值的數的個數,定義B為一個排列中是字尾最大值的數的個數,求長度為n的排列中滿足A = a且B = b的排列個數。
Sample Input 1 1 1
Sample Output 1
考慮DP,設f[i][j]為前i位有j個不同字首最大值方案數。 我們從大到小插數,對於當前這個數他只有放在第一位才可能有新的字首最大值,可得轉移: 這個玩意是第一類斯特林數。。。 然後我們考慮以n為分界點,其實就相當於有a+b-2個字首最大值,然後你選a-1個數放到左邊,b-1個數放到右邊,可得答案為: 對於第一類斯特林數的求法, f[n][m]就等於x的n次上升冪的第m項係數。 用分治FFT即可解決(漲姿勢)
#include <cstdio> #include <cstring> #include <algorithm> using namespace std; typedef long long LL; const LL mod = 998244353; int _min(int x, int y) {return x < y ? x : y;} int _max(int x, int y) {return x > y ? x : y;} int read() { int s = 0, f = 1; char ch = getchar(); while(ch < '0' || ch > '9') {if(ch == '-') f = -1; ch = getchar();} while(ch >= '0' && ch <= '9') s = s * 10 + ch - '0', ch = getchar(); return s * f; } int R[410000]; LL A[410000], jc[410000]; LL pow_mod(LL a, LL k) { LL ans = 1; while(k) { if(k & 1) (ans *= a) %= mod; (a *= a) %= mod; k /= 2; } return ans; } void NTT(LL y[], int len, int on) { for(int i = 0; i < len; i++) R[i] = (R[i >> 1] >> 1) | ((i & 1) * (len >> 1)); for(int i = 0; i < len; i++) if(i < R[i]) swap(y[i], y[R[i]]); for(int i = 1; i < len; i *= 2) { LL wn = pow_mod(3, (LL)(mod - 1) / (i * 2)); if(on == -1) wn = pow_mod(wn, mod - 2); for(int j = 0; j < len; j += i * 2) { LL w = 1; for(int k = 0; k < i; k++) { LL u = y[j + k], v = y[j + k + i] * w % mod; y[j + k] = (u + v) % mod, y[j + k + i] = (u - v + mod) % mod; w = w * wn % mod; } } } if(on == -1) { LL tmp = pow_mod(len, mod - 2); for(int i = 0; i < len; i++) y[i] = y[i] * tmp % mod; } } void solve(LL *a, LL ln, int l, int r) { if(l == r) {a[0] = l - 1, a[1] = 1; return ;} int mid = (l + r) / 2; LL g1[ln + 10], g2[ln + 10]; memset(g1, 0, sizeof(g1)), memset(g2, 0, sizeof(g2)); solve(g1, ln / 2, l, mid), solve(g2, ln / 2, mid + 1, r); NTT(g1, ln, 1), NTT(g2, ln, 1); for(int i = 0; i < ln; i++) a[i] = g1[i] * g2[i] % mod; NTT(a, ln, -1); } LL C(int n, int m) { LL ans = 1; jc[0] = 1; for(int i = 1; i <= n; i++) jc[i] = jc[i - 1] * i % mod; return jc[n] * pow_mod(jc[m], mod - 2) % mod * pow_mod(jc[n - m], mod - 2) % mod; } int main() { int n = read(), a = read(), b = read(); if(a + b - 1 > n || !a || !b) {puts("0"); return 0;} if(n == 1) {puts("1"); return 0;} LL c = C(a + b - 2, a - 1); int ln; for(ln = 1; ln <= 2 * (n + 1); ln *= 2); solve(A, ln, 1, n - 1); printf("%lld\n", A[a + b - 2] * c % mod); return 0; }