1. 程式人生 > >BZOJ 3653: 談笑風生(離線, 長鏈剖分, 字尾和)

BZOJ 3653: 談笑風生(離線, 長鏈剖分, 字尾和)

題意

給你一顆有 \(n\) 個點並且以 \(1\) 為根的樹。共有 \(q\) 次詢問,每次詢問兩個引數 \(p, k\) 。詢問有多少對點 \((p, a, b)\) 滿足 \(p,a,b\) 為三個不同的點,\(p, a\) 都為 \(b\) 的祖先,且 \(p\)\(a\) 的距離不能超過 \(k\)

\(n\le 300000 , q\le 300000\) 不要求強制線上。

題解

\(dep[u]\) 為點 \(u\) 的深度,\(sz[u]\)\(u\) 的子樹大小(除去 \(u\) 本身)

首先我們考慮兩種情況:

  1. \(a\)\(p\) 的祖先,那麼這部分貢獻很好計算,就是 \(\min\{dep[p] - 1,k\} \times sz[u]\)
  2. \(a\)\(p\) 的子樹內,那麼這部分貢獻就是 \(\displaystyle \sum_{dis(p,a) \le k} sz[a]\)

我們現在只要考慮第二部分貢獻怎麼求。

不難發現,這些點的深度就是 \([dep[p], dep[p]+k]\) 這個範圍內的。

那麼我們可以對於每個點用個 主席樹 來儲存這些資訊,可以線上回答詢問。

那麼離線的話,可以考慮用 線段樹合併 維護它每個子樹的資訊。

具體來說,這些都是對於每個 \(dep\) 維護它的 \(sz\) 的和,然後查區間和就行了。

然而這些時空複雜度都是 \(O(n \log n)\) ,其實還有更好的做法。

為什麼我發現了呢qwq?

因為 fatesky 做這道題線段樹合併做法的時候,Wearry 說可以 長鏈剖分 那就是 \(O(n)\) 的啦。

我們令 \(\displaystyle maxdep[u]=\max_{v \in child[u]} \{dep[v\}\) 也就是它子樹中的最大深度。

具體來說,長鏈剖分就是把每個點兒子中 \(maxdep\) 最大的那個當做重兒子。重兒子與父親連的邊叫做重邊。一連串重邊不間斷連到一起就叫做重鏈。

然後我們就有一條性質。

性質1 : 重鏈長度之和是 \(O(n)\) 的。

這個很顯然啦,因為總共只有 \(O(n)\) 級別的邊。

有了這個我們就可以解決一系列 關於深度的動態規劃

問題了,對於這列問題常常都可以做到 \(O(n)\) 的複雜度。

具體操作就是,每次暴力繼承重兒子的 \(dp\) 狀態,然後輕兒子暴力合併上去。

不難發現這個複雜度是 \(O(\sum\) 重鏈長 \()\) \(= O(n)\) 的。

繼承的時候常常需要移位,並且把當前節點貢獻算入,並且這個 \(dp\) 需要動態空間才能實現。

對於這道題我們考慮維護一個字尾和,也就是對於 \(u\) 子樹中的 \(v\)\(dep[v] \ge k\) 的所有 \(sz[v]\) 的和。

不難發現字尾和是很好合並的,這個的複雜度只需要 \(O(\min maxdep[v])\)

每次新增一個點 \(sz[u]\) 對於 \(dep[u]\) 的貢獻只會對一個點的貢獻產生影響,這個複雜度是 \(O(1)\) 的。

程式碼實現的話,就可以用一個 std :: vector ,按深度從大到小 ( \(maxdep[u] \to dep[u]\) )儲存每個點的資訊,因為這樣最方便繼承重兒子狀態(每次加入狀態只在整個 vector 末端新增一個元素)

其實可以動態開記憶體,順著做,但我似乎學不來

常數似乎有點大,沒比 \(O(n \log n)\) 快多少,vector 用多了... Wearry 到是優化了點常數到了 \(4000+ ms\)

話說這個很像原來 DOFY 講過的那道 ?

程式碼

#include <bits/stdc++.h>

#define For(i, l, r) for(register int i = (l), i##end = (int)(r); i <= i##end; ++i)
#define Fordown(i, r, l) for(register int i = (r), i##end = (int)(l); i >= i##end; --i)
#define Set(a, v) memset(a, v, sizeof(a))
#define Cpy(a, b) memcpy(a, b, sizeof(a))
#define debug(x) cout << #x << ": " << x << endl
#define DEBUG(...) fprintf(stderr, __VA_ARGS__)

using namespace std;

typedef long long ll;

inline bool chkmin(int &a, int b) {return b < a ? a = b, 1 : 0;}
inline bool chkmax(int &a, int b) {return b > a ? a = b, 1 : 0;}

inline int read() {
    int x = 0, fh = 1; char ch = getchar();
    for (; !isdigit(ch); ch = getchar()) if (ch == '-') fh = -1;
    for (; isdigit(ch); ch = getchar()) x = (x << 1) + (x << 3) + (ch ^ 48);
    return x * fh;
}

void File() {
#ifdef zjp_shadow
    freopen ("3653.in", "r", stdin);
    freopen ("3653.out", "w", stdout);
#endif
}

const int N = 3e5 + 1e3;

struct Ask { int k, id; } ; vector<Ask> V[N];

vector<int> G[N]; int sz[N], maxdep[N], dep[N], sonmaxdep[N], son[N], rt[N];

vector<ll> sum[N]; int n, q; ll ans[N], Size = 0;

void Dfs_Init(int u, int fa = 0) {

    maxdep[u] = dep[u] = dep[fa] + 1;

    For (i, 0, G[u].size() - 1) {
        register int v = G[u][i];
        if (v ^ fa) Dfs_Init(v, u), chkmax(maxdep[u], maxdep[v]);
    }

}

void Dfs(int u, int fa = 0) {

    For (i, 0, G[u].size() - 1) {
        int v = G[u][i];
        if (v == fa) continue ;
        Dfs(v, u); sz[u] += sz[v];
        if (maxdep[v] > maxdep[son[u]]) son[u] = v;
    }
    rt[u] = rt[son[u]]; if (!rt[u]) rt[u] = ++ Size;

    int len = (int)sum[rt[u]].size();
    ll Last = len ? sum[rt[u]][len - 1] : 0;
    sum[rt[u]].push_back(Last);

    if (son[u]) {
        For (i, 0, G[u].size() - 1) {
            int v = G[u][i]; if (v == fa || v == son[u]) continue ;
            For (j, 0, sum[rt[v]].size() - 1) {
                int nowdep = (maxdep[son[u]] - maxdep[v]) + j;
                sum[rt[u]][nowdep] += sum[rt[v]][j];
            }
            sum[rt[u]][len] += sum[rt[v]][sum[rt[v]].size() - 1];
        }
    }

    For (i, 0, V[u].size() - 1) {
        Ask now = V[u][i];
        ans[now.id] = sum[rt[u]][len];
        if (len > now.k) ans[now.id] -= sum[rt[u]][len - now.k - 1];
        ans[now.id] += 1ll * min(dep[u] - 1, now.k) * sz[u];
    }

    sum[rt[u]][len] += sz[u]; ++ sz[u];

}

int main () {

    File();

    n = read(); q = read();

    For (i, 1, n - 1) {
        int u = read(), v = read();
        G[u].push_back(v);
        G[v].push_back(u);
    }

    For (i, 1, q) {
        int p = read(), k = read();
        V[p].push_back((Ask) {k, i});
    }

    Dfs_Init(1); Dfs(1);

    For (i, 1, q)
        printf ("%lld\n", ans[i]);

    return 0;
}