1. 程式人生 > >LOJ #2537. 「PKUWC 2018」Minimax (線段樹合併 優化dp)

LOJ #2537. 「PKUWC 2018」Minimax (線段樹合併 優化dp)

題意

\(C\) 有一棵 \(n\) 個結點的有根樹,根是 \(1\) 號結點,且每個結點最多有兩個子結點。

定義結點 \(x\) 的權值為:

1.若 \(x\) 沒有子結點,那麼它的權值會在輸入裡給出,保證這類點中每個結點的權值互不相同

2.若 \(x\) 有子結點,那麼它的權值有 \(p_x\) 的概率是它的子結點的權值的最大值,有 \(1-p_x\) 的概率是它的子結點的權值的最小值。

現在小 \(C\) 想知道,假設 \(1\) 號結點的權值有 \(m\) 種可能性,權值第 \(i\)的可能性的權值是 \(V_i\) ,它的概率為 \(D_i(D_i>0)\) ,求:

\[\displaystyle \sum _{i=1} ^ {m} i \cdot V_i \cdot D_i^2\]

你需要輸出答案對 \(998244353\) 取模的值。

對於 \(40\%\) 的資料,有 \(1\leq n\leq 5000\)

對於 \(100\%\) 的資料,有 \(1\leq n\leq 3\times 10^5, 1\leq w_i\leq 10^9\)

題解

首先考慮 \(O(n^2)\)dp , 令 \(dp_{u,i}\)\(u\) 號點 , 取到排名為 \(i\) 權值的概率 .

這個應該比較容易轉移 , 考慮列舉一個兒子取的值 , 然後對於它的貢獻 就分為它最小和它最大的兩種去計數就行了 .

然後這個用個 前\字尾和 優化就能達到 \(O(n^2)\) 複雜度了 (程式碼在最後)

\(ch[u][0/1]\)\(u\) 的左/右兒子 , 方程就是 : \[\displaystyle dp_{u,i} = \sum_{son=0}^{1} dp[ch[u][son]] \times Pre[i-1][son \oplus 1] \times p[u] + dp[ch[u][son]] \times Suf[i + 1][son \oplus 1](1 - p[u])\]

然後考慮優化 , 類似於這種狀態數與 \(size_u\) 有關的 dp .

常常可以考慮 線段樹合併 or 啟發式合併 來優化時間複雜度 .

一開始想直接 啟發式合併 線上段樹上操作 發現細節好多 而且不好維護 ... 然後就棄掉了 看了一波 LOJ 最短程式碼 qwq

誒 好像很好寫啊 , 原來直接線段樹合併就行了 . qwq

考慮維護一顆線段樹 , 每個點維護兩個值 \(sumv, mult\) 代表 區間和 以及 區間乘法的標記 .

然後每個葉子 代表一個 dp 值 , 然後每個區間就可以維護這段區間的 dp 值之和 .

我們一邊合併一邊算到當前區間 , 對於兩個線段樹 dp 值存在的貢獻 \(sumx, sumy\) (也就是前面方程中需要乘上後面的兩個東西) .

如果當前區間只有一個子樹 , 打下乘法標記 , 直接返回就行了 . 否則繼續遞迴下去合併解決 .

時間複雜度就是 $ O(\sum_{i=1}^{n} minsize) = O(n \log n)$ .

這是因為每個點合併上去 大小至少翻倍 . 意味著每個點最多被計算 \((\log n)\) 次 , 最後複雜度就是 \(O(n \log n)\) .

程式碼

\[40pts\]

#include <bits/stdc++.h>
#define For(i, l, r) for(register int i = (l), i##end = (int)(r); i <= i##end; ++i)
#define Fordown(i, r, l) for(register int i = (r), i##end = (int)(l); i >= i##end; --i)
#define Set(a, v) memset(a, v, sizeof(a))
using namespace std;

inline bool chkmin(int &a, int b) {return b < a ? a = b, 1 : 0;}
inline bool chkmax(int &a, int b) {return b > a ? a = b, 1 : 0;}

inline int read() {
    int x = 0, fh = 1; char ch = getchar();
    for (; !isdigit(ch); ch = getchar() ) if (ch == '-') fh = -1;
    for (; isdigit(ch); ch = getchar() ) x = (x << 1) + (x << 3) + (ch ^ 48);
    return x * fh;
}

void File() {
#ifdef zjp_shadow
    freopen ("2537.in", "r", stdin);
    freopen ("2537.out", "w", stdout);
#endif
}

typedef long long ll;
const ll Mod = 998244353;
ll fpm(ll x, int power) {
    ll res = 1;
    for (; power; power >>= 1, (x *= x) %= Mod)
        if (power & 1) (res *= x) %= Mod;
    return res;
}

typedef long long ll;
const int N = 5100, inv = fpm(10000, Mod - 2);
int n, fa[N], ch[N][2], tot[N], val[N], rk[N], Leaf;
ll dp[N][N], p[N], Pre[N][2], Suf[N][2];

#define ls(o) ch[o][0]
#define rs(o) ch[o][1]
void Dp(int u) {
    if (!u) return ; Dp(ls(u)); Dp(rs(u));
    if (!tot[u]) { dp[u][rk[u]] = 1; }
    else if (tot[u] == 1) {
        For (i, 1, Leaf) dp[u][i] = dp[ls(u)][i];
    } else {
        For (son, 0, 1) {
            For (i, 1, Leaf)
                Pre[i][son] = (Pre[i - 1][son] + dp[ch[u][son]][i]) % Mod;
            Fordown (i, Leaf, 1)
                Suf[i][son] = (Suf[i + 1][son] + dp[ch[u][son]][i]) % Mod;
        }
        For (i, 1, Leaf) For (son, 0, 1) {
            (dp[u][i] += dp[ch[u][son]][i] * Pre[i - 1][son ^ 1] % Mod * p[u] % Mod 
             + dp[ch[u][son]][i] * Suf[i + 1][son ^ 1] % Mod * (Mod + 1 - p[u]) % Mod) %= Mod;
        }
    }
}

int main () {
    File();
    n = read();
    For (i, 1, n) fa[i] = read(), ch[fa[i]][tot[fa[i]] ++] = i;
    For (i, 1, n)
        if (!tot[i]) rk[i] = val[++ Leaf] = read();
        else p[i] = 1ll * read() * inv % Mod;

    sort(val + 1, val + Leaf + 1);
    For (i, 1, n) rk[i] = lower_bound(val + 1, val + Leaf + 1, rk[i]) - val;

    Dp(1);

    ll ans = 0;
    For (i, 1, n)
        (ans += 1ll * i * val[i] % Mod * dp[1][i] % Mod * dp[1][i] % Mod) %= Mod;
    printf ("%lld\n", ans);
    return 0;
}

\[100pts\]

#include <bits/stdc++.h>
#define For(i, l, r) for(register int i = (l), i##end = (int)(r); i <= i##end; ++i)
#define Fordown(i, r, l) for(register int i = (r), i##end = (int)(l); i >= i##end; --i)
#define Set(a, v) memset(a, v, sizeof(a))
using namespace std;

inline bool chkmin(int &a, int b) {return b < a ? a = b, 1 : 0;}
inline bool chkmax(int &a, int b) {return b > a ? a = b, 1 : 0;}

inline int read() {
    int x = 0, fh = 1; char ch = getchar();
    for (; !isdigit(ch); ch = getchar() ) if (ch == '-') fh = -1;
    for (; isdigit(ch); ch = getchar() ) x = (x << 1) + (x << 3) + (ch ^ 48);
    return x * fh;
}

void File() {
#ifdef zjp_shadow
    freopen ("2537.in", "r", stdin);
    freopen ("2537.out", "w", stdout);
#endif
}

typedef long long ll;
const ll Mod = 998244353;
ll fpm(ll x, int power) {
    ll res = 1;
    for (; power; power >>= 1, (x *= x) %= Mod)
        if (power & 1) (res *= x) %= Mod;
    return res;
}

typedef long long ll;

const int N = 3e5 + 1e3, inv = fpm(10000, Mod - 2);
int val[N], rk[N];

#define ls(o) ch[o][0]
#define rs(o) ch[o][1]
#define lson ls(o), l, mid
#define rson rs(o), mid + 1, r
const int Maxnode = 6e6 + 1e3;
#define Mult(o, val) (sumv[o] *= (val)) %= Mod, (mult[o] *= (val)) %= Mod;
struct Segment_Tree {
    int rt[Maxnode], ch[Maxnode][2], Size; ll sumv[Maxnode], mult[Maxnode];

    inline void push_up(int o) { sumv[o] = (sumv[ls(o)] + sumv[rs(o)]) % Mod; }

    inline void push_down(int o) { 
        if (mult[o] <= 1) return ; Mult(ls(o), mult[o]); Mult(rs(o), mult[o]); mult[o] = 1;
    }

    void Update(int &o, int l, int r, int up, ll uv) {
        if (!o) o = (++ Size); mult[o] = 1;
        if (l == r) { (sumv[o] += uv) %= Mod; return ; } int mid = (l + r) >> 1;
        push_down(o); if (up <= mid) Update(lson, up, uv); else Update(rson, up, uv); push_up(o);
    }

    int Merge(int x, int y, ll sumx, ll sumy, ll probmax, ll probmin) {
        if (!y) { Mult(x, sumy); return x; }
        if (!x) { Mult(y, sumx); return y; }
        push_down(x); push_down(y);
        ll x0 = sumv[ls(x)], x1 = sumv[rs(x)], y0 = sumv[ls(y)], y1 = sumv[rs(y)];
        ls(x) = Merge(ls(x), ls(y), (sumx + probmin * x1) % Mod, (sumy + probmin * y1) % Mod, probmax, probmin);
        rs(x) = Merge(rs(x), rs(y), (sumx + probmax * x0) % Mod, (sumy + probmax * y0) % Mod, probmax, probmin);
        push_up(x); return x;
    }

    inline ll Calc(int o, int l, int r) {
        if (l == r) return 1ll * l * val[l] % Mod * sumv[o] % Mod * sumv[o] % Mod;
        int mid = (l + r) >> 1; push_down(o);
        return (Calc(lson) + Calc(rson)) % Mod;
    }
} T;

int n, fa[N], ch[N][2], tot[N], Leaf;
ll p[N];

void Dp(int u) {
    if (!u) return ; Dp(ls(u)); Dp(rs(u));
    if (!tot[u]) T.Update(T.rt[u], 1, Leaf, rk[u], 1);
    else if (tot[u] == 1) T.rt[u] = T.rt[ls(u)];
    else T.rt[u] = T.Merge(T.rt[ls(u)], T.rt[rs(u)], 0, 0, p[u], (Mod + 1 - p[u]) % Mod);
}

int main () {
    File();
    n = read();
    For (i, 1, n) fa[i] = read(), ch[fa[i]][tot[fa[i]] ++] = i;
    For (i, 1, n)
        if (!tot[i]) rk[i] = val[++ Leaf] = read();
        else p[i] = 1ll * read() * inv % Mod;

    sort(val + 1, val + Leaf + 1);
    For (i, 1, n) rk[i] = lower_bound(val + 1, val + Leaf + 1, rk[i]) - val;

    Dp(1);

    printf ("%lld\n", T.Calc(T.rt[1], 1, Leaf));
    return 0;
}