1. 程式人生 > >luogu P2664 樹上游戲(點分治)

luogu P2664 樹上游戲(點分治)

點分治真是一個好東西。可惜我不會
這種要求所有路經的題很可能是點分治。
然後我就不會了。。
既然要用點分治,就想,點分治有哪些優點?它可以\(O(nlogn)\)遍歷分治樹的所有子樹。
那麼現在的問題就是,如可快速(\(O(n)\)或O\((nlogn)\))求以一個點為根的時候,子樹之間的貢獻(當然還有根節點的)。
我們注意到一件事,就是一棵子樹中一個點對其他子樹的點產生貢獻當且僅當這個點的顏色在它到根的路徑上第一次出現(或者說只算上這些貢獻答案正確),且貢獻為以這個點為根的子樹大小。(不考慮其它子樹的顏色)
這個有什麼用,我們可以遍歷兩遍子樹,第一遍預處理出所有子樹對其它子樹的貢獻(如上邊一段所說把貢獻統計),第二次遍歷每一顆子樹先把這顆樹的貢獻去掉,統計所有其它的樹對這顆樹的貢獻。
那麼具體該怎麼做?

void calc(int u){
    dfs1(u,0);
    ans[u]+=sum;
    for(int i=head[u];i;i=e[i].nxt){
        int v=e[i].to;
        if(vis[v])continue;
        cnt[a[u]]++;
        sum-=size[v];color[a[u]]-=size[v];
        change(v,u,-1);
        cnt[a[u]]--;
        tot=size[u]-size[v];
        dfs2(v,u);
        cnt[a[u]]++;
        sum+=size[v];color[a[u]]+=size[v];
        change(v,u,1);
        cnt[a[u]]--;
    }
    clear(u,0);
}

首先dfs1是統計貢獻的,用sum記錄貢獻和,color[i]記錄第i種顏色的貢獻。
然後根的答案就可以累加了。
那麼如可判斷一個顏色第一次出現?可以記錄一個cnt[i]記錄第i種顏色在到根的路徑上出現多少次。當cnt[i]等於1的時候統計貢獻。
然後

        cnt[a[u]]++;
        sum-=size[v];color[a[u]]-=size[v];
        change(v,u,-1);
        cnt[a[u]]--;

用來消除子樹貢獻。dfs2統計其它子樹對這顆子樹的貢獻。

void dfs2(int u,int f){
    cnt[a[u]]++;
    if(cnt[a[u]]==1){
        sum-=color[a[u]];
        num++;
    }
    ans[u]+=sum+num*tot;
    for(int i=head[u];i;i=e[i].nxt){
        int v=e[i].to;
        if(v==f||vis[v])continue;
        dfs2(v,u);
    }
    if(cnt[a[u]]==1){
        sum+=color[a[u]];
        num--;
    }
    cnt[a[u]]--;
}

如果這顆子樹中出現一個顏色,並且它是第一次出現,那麼減去所有子樹的color[a[u]],加上其它子樹的節點總數,因為每一條到其它子樹的路徑都會產生貢獻,這也是我們一開始不考慮貢獻對其他子樹影響的原因,因為遍歷子樹的時候會把這些重複的貢獻減去。
更具體還是看程式碼。

// luogu-judger-enable-o2
#include<iostream>
#include<cstring>
#include<cstdio>
#include<cmath>
#include<algorithm>
using namespace std;
#define int long long
const int N=101000;
int Cnt,head[N];
int g[N],size[N],cnt[N],a[N],sum,color[N],tot,num,root,all,vis[N],ans[N],n;
struct edge{
    int to,nxt;
}e[N*2];
void add_edge(int u,int v){
    Cnt++;
    e[Cnt].nxt=head[u];
    e[Cnt].to=v;
    head[u]=Cnt;
}
int read(){
    int sum=0,f=1;char ch=getchar();
    while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
    while(ch>='0'&&ch<='9'){sum=sum*10+ch-'0';ch=getchar();}
    return sum*f;
}
void getroot(int u,int f){
    g[u]=0;size[u]=1;
    for(int i=head[u];i;i=e[i].nxt){
        int v=e[i].to;
        if(v==f||vis[v])continue;
        getroot(v,u);
        size[u]+=size[v];
        g[u]=max(g[u],size[v]);
    }
    g[u]=max(g[u],all-size[u]);
    if(g[u]<g[root])root=u;
}
void dfs1(int u,int f){
    cnt[a[u]]++;
    size[u]=1;
    for(int i=head[u];i;i=e[i].nxt){
        int v=e[i].to;
        if(v==f||vis[v])continue;
        dfs1(v,u);
        size[u]+=size[v];
    }
    if(cnt[a[u]]==1){
        sum+=size[u];
        color[a[u]]+=size[u];
    }
    cnt[a[u]]--;
}
void clear(int u,int f){
    cnt[a[u]]++;
    for(int i=head[u];i;i=e[i].nxt){
        int v=e[i].to;
        if(v==f||vis[v])continue;
        clear(v,u);
    }
    if(cnt[a[u]]==1){
        sum-=size[u];
        color[a[u]]-=size[u];
    }
    cnt[a[u]]--;
}
void dfs2(int u,int f){
    cnt[a[u]]++;
    if(cnt[a[u]]==1){
        sum-=color[a[u]];
        num++;
    }
    ans[u]+=sum+num*tot;
    for(int i=head[u];i;i=e[i].nxt){
        int v=e[i].to;
        if(v==f||vis[v])continue;
        dfs2(v,u);
    }
    if(cnt[a[u]]==1){
        sum+=color[a[u]];
        num--;
    }
    cnt[a[u]]--;
}
void change(int u,int f,int k){
    cnt[a[u]]++;
    if(cnt[a[u]]==1){
        sum+=k*size[u];color[a[u]]+=k*size[u];
    }
    for(int i=head[u];i;i=e[i].nxt){
        int v=e[i].to;
        if(v==f||vis[v])continue;
        change(v,u,k);
    }
    cnt[a[u]]--;
}
void calc(int u){
    dfs1(u,0);
    ans[u]+=sum;
    for(int i=head[u];i;i=e[i].nxt){
        int v=e[i].to;
        if(vis[v])continue;
        cnt[a[u]]++;
        sum-=size[v];color[a[u]]-=size[v];
        change(v,u,-1);
        cnt[a[u]]--;
        tot=size[u]-size[v];
        dfs2(v,u);
        cnt[a[u]]++;
        sum+=size[v];color[a[u]]+=size[v];
        change(v,u,1);
        cnt[a[u]]--;
    }
    clear(u,0);
}
void work(int u){
    calc(u);
    vis[u]=1;
    for(int i=head[u];i;i=e[i].nxt){
        int v=e[i].to;
        if(vis[v])continue;
        root=0,all=size[v];
        getroot(v,0);
        work(root);
    }
}
signed main(){
    n=read();
    for(int i=1;i<=n;i++)a[i]=read();
    for(int i=1;i<n;i++){
        int u=read(),v=read();
        add_edge(u,v);add_edge(v,u);
    }
    g[0]=n+10;root=0;all=n;
    getroot(1,0);work(root);
    for(int i=1;i<=n;i++)printf("%lld\n",ans[i]);
    return 0;
}