1. 程式人生 > >hihocoder 1479 三等分 樹型dp

hihocoder 1479 三等分 樹型dp

描述

小Hi最近參加了一場比賽,這場比賽中小Hi被要求將一棵樹拆成3份,使得每一份中所有節點的權值和相等。

比賽結束後,小Hi發現雖然大家得到的樹幾乎一模一樣,但是每個人的方法都有所不同。於是小Hi希望知道,對於一棵給定的有根樹,在選取其中2個非根節點並將它們與它們的父親節點分開後,所形成的三棵子樹的節點權值之和能夠兩兩相等的方案有多少種。

兩種方案被看做不同的方案,當且僅當形成方案的2個節點不完全相同。

輸入

每個輸入檔案包含多組輸入,在輸入的第一行為一個整數T,表示資料的組數。

每組輸入的第一行為一個整數N,表示給出的這棵樹的節點數。

接下來N行,依次描述結點1~N,其中第i行為兩個整數Vi和Pi,分別描述這個節點的權值和其父親節點的編號。

父親節點編號為0的節點為這棵樹的根節點。

對於30%的資料,滿足3<=N<=100

對於100%的資料,滿足3<=N<=100000, |Vi|<=100, T<=10

輸出

對於每組輸入,輸出一行Ans,表示方案的數量。

樣例輸入
2
3
1 0
1 1
1 2
4
1 0
1 1
1 2
1 3
樣例輸出
1
0

統計所形成的三棵子樹的節點權值之和能夠兩兩相等的方案,等價於在這樹上取兩個不同且非根結點,形成三棵子樹後子樹節點權值之和兩兩相等。

樹型dp求解,每個節點維護res(以這個節點為根的子樹節點權值和),cnt(以這個節點為根的子樹權值等於sum/3的節點個數)。

一開始能夠想到的一個A節點res=sum/3,那麼在這棵子樹外再找一個res=sum/3的B節點進行組合不就得出一種方案了嗎。可是這裡面是分兩類的

1.      B不是A的祖先,那麼後來列舉B的時候A又被算了一次。記為2*s1

2.      B是A的祖先,其實這種情況是錯誤的,因為A、B分別取出後,A子樹res=sum/3,B子樹res=0(因為A子樹本來就是B的一部分啊),這種方案是錯誤的要除去,且記為p

還有一種情況是一個節點res=sum*2/3,那麼這個節點與其子樹內不包括它自己,任意一個res=sum/3的節點相組合就是一種方案,記為s2。

因為不能選root,所以cnt對res[root]=sum/3情況不予考慮

2*s1+p=cnt[root]^2 - (res[x]=sum/3&&x!=root)

P= (cnt[x]>0&&res[x]=sum/3&&x!=root)

S2= (res[x]=sum*2/3&&x!=root)

最後ans=s1+s2

Dp時維護好資料最後求解即可

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int maxn=1e5+8;
vector<int> g[maxn];
int root;
ll sum;
ll res[maxn];
ll cnt[maxn];
ll a[maxn];
ll s2,s3,sig;

void dfs(int x,int p){
    res[x]=a[x];
    cnt[x]=0;
    if(g[x].size()<=1){
        if(res[x]==sum)cnt[x]=1;
        sig+=cnt[x];
        //cout<<"xx="<<x<<" "<<sum<<" "<<res[x]<<" "<<cnt[x]<<endl;
        return ;
    }
    for(int i=0;i<g[x].size();i++){
        int u=g[x][i];
        if(u==p)continue;
        dfs(u,x);
        res[x]+=res[u];
        cnt[x]+=cnt[u];
    }
    if(res[x]==sum&&x!=root)cnt[x]+=1;
    if(res[x]==sum&&x!=root)sig+=cnt[x];
    if(cnt[x]>0&&res[x]==sum&&x!=root)s3+=(cnt[x]-1);
    if(res[x]==2*sum&&x!=root)s2+=(res[x]==sum?cnt[x]-1:cnt[x]);
}

int main()
{
    int T;
    scanf("%d",&T);
    while(T--){
        int n;
        scanf("%d",&n);
        for(int i=0;i<n+7;i++)g[i].clear();
        sum=0;
        for(int i=1;i<=n;i++){
            int v,p;
            scanf("%d%d",&v,&p);
            a[i]=v;
            g[i].push_back(p);
            g[p].push_back(i);
            if(p==0)root=i;
            sum+=v;
        }
        if(sum%3){printf("0\n");continue;}
        sum/=3;
        s2=s3=sig=0;
        dfs(root,0);
//        for(int i=1;i<=n;i++){
//            printf("id==%d res==%I64d cnt==%I64d\n",i,res[i],cnt[i]);
//        }
//        cout<<" s2=="<<s2<<" s3=="<<s3<<" sig=="<<sig<<endl;
        ll ans=((cnt[root]*cnt[root]-sig)-s3)/2+s2;
        cout<<ans<<endl;
    }
    return 0;
}