1. 程式人生 > >Codeforces Round #514 (Div. 2) E. Split the Tree (貪心 + 樹上倍增)

Codeforces Round #514 (Div. 2) E. Split the Tree (貪心 + 樹上倍增)

題目大意:給出一棵有 n 個結點的樹,每個結點都有一個權值 w ,現在要你將這棵樹分成若干條鏈,且每個結點只能屬於一條鏈,分出來的鏈滿足每條鏈上的結點不超過L個,同時這些結點的權值和不超過S。問你最少能把這棵樹分成幾條鏈。

題目思路:由於是要使得鏈儘可能的少,所以分出來的鏈每條鏈上的結點都是要儘可能的多的。

但如果從上往下去將樹分鏈的話操作起來會很麻煩,所以我們可以考慮至下往上去分鏈。

由於葉子結點之間是不會有交集的,所以每個葉子結點必然會分出一條鏈來。然後我們就可以考慮至下往上貪心的做法,由於要使得鏈上結點儘可能的多,所以我們可以通過樹上倍增的做法,預處理出每個結點從下往上形成的鏈最遠能到哪裡。

預處理出這個資訊之後,我們就可以對答案進行求解了,每次就取不屬於已取出鏈的一個結點,將其能形成的最長鏈變成一條新的鏈即可。

具體實現看程式碼:

#include <bits/stdc++.h>
#define fi first
#define se second
#define lson l,m,rt<<1
#define rson m+1,r,rt<<1|1
#define pb push_back
#define MP make_pair
#define lowbit(x) x&-x
#define clr(a) memset(a,0,sizeof(a))
#define _INF(a) memset(a,0x3f,sizeof(a))
#define FIN freopen("in.txt","r",stdin)
#define IOS ios::sync_with_stdio(false)
#define fuck(x) cout<<"["<<#x<<" "<<(x)<<"]"<<endl
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int, int>pii;
typedef pair<ll, ll>pll;
const int MX = 1e5 + 5;
const double eps = 1e-8;

int n, L, ans;
ll S, dis[MX];
vector<int>G[MX];
int val[MX], ST[MX][20], dep[MX], top[MX], head[MX];
//ST為倍增陣列;dis陣列表示從上往下的權值字首和,這樣求一條鏈的權值和只要減一下即可;
//dep為結點的深度;top表示該結點所能到達的最遠結點;head表示分出來的鏈最上端的結點;

void ST_init() {
    for (int i = 1; i < 20; i++) {
        for (int j = 1; j <= n; j++)
            ST[j][i] = ST[ST[j][i - 1]][i - 1];
    }
}

void dfs1(int u, int fa) {
    dis[u] = dis[fa] + val[u];
    dep[u] = dep[fa] + 1;
    top[u] = u;
    int cnt = L;
    for (int i = 19; i >= 0; i--) {
        int f = ST[top[u]][i];
        if ((1 << i) >= cnt || f == 0) continue;
        if (dis[u] - dis[ST[f][0]] > S) continue;
        cnt -= (1 << i);
        top[u] = f;
    }
    for (auto v : G[u]) dfs1(v, u);
}

void dfs2(int u) {
    int res = -1;
    for (auto v : G[u]) {
        dfs2(v);
        if (head[v] == v) continue;
        if (res == -1 || dep[res] > dep[head[v]]) res = head[v];
    }
    if (res == -1) {
        res = top[u];
        ans++;
    }
    head[u] = res;
}

int main() {
    // FIN;
    scanf("%d%d%lld", &n, &L, &S);
    bool flag = 0;
    for (int i = 1; i <= n; i++) {
        scanf("%d", &val[i]);
        if (val[i] > S) flag = 1;
    }
    if (flag) {
        puts("-1");
        return 0;
    }
    for (int i = 2, p; i <= n; i++) {
        scanf("%d", &p);
        G[p].pb(i);
        ST[i][0] = p;
    }
    ST_init();
    dfs1(1, 0); dfs2(1);
    cout << ans << endl;
    return 0;
}