1. 程式人生 > >洛谷【P2664】樹上游戲

洛谷【P2664】樹上游戲

淺談樹分治:https://www.cnblogs.com/AKMer/p/10014803.html

題目傳送門:https://www.luogu.org/problemnew/show/P2664

對於所有求顏色種類數的問題,我們都可以定義一個方向,使得所有的顏色在最靠這個方向第一次出現的位置有效,而其它位置都是無效的。對於樹分治,我們可以定義這個方向為當前需要遍歷的子樹,反方向就是已經遍歷完的子樹。

對於一個點\(u\),如果從當前重心到他這一條路徑上,該點顏色是第一次出現,那麼它的顏色將給後面的遍歷帶來\(siz[u]\)的貢獻。另外,在遍歷當前子樹時,所有在重心到當前點這條路徑的上的顏色,貢獻都是已經遍歷過的子樹的總結點數。正過來做一遍,反過來做一遍就可以了。對於單獨的從重心到當前點的路徑會被統計兩次,所以要減掉一次。

邊分治重構樹之後不知道怎麼消除新結點的影響,如果有大佬願意教教我請在評論下方回覆。

這題資料貌似比較水,不卡不重構樹的邊分治。

時間複雜度:\(O(nlogn)\)

空間複雜度:\(O(n)\)

點分治版程式碼如下:

#include <cstdio>
#include <algorithm>
using namespace std;
typedef long long ll;

const int maxn=1e5+5;

bool vis[maxn];
ll ans[maxn],res;
int n,tot,mx,rt,N,Siz;
int now[maxn],pre[maxn<<1],son[maxn<<1];
int col[maxn],siz[maxn],cnt[maxn],V[maxn],sum[maxn];

int read() {
    int x=0,f=1;char ch=getchar();
    for(;ch<'0'||ch>'9';ch=getchar())if(ch=='-')f=-1;
    for(;ch>='0'&&ch<='9';ch=getchar())x=x*10+ch-'0';
    return x*f;
}

void add(int a,int b) {
    pre[++tot]=now[a];
    now[a]=tot,son[tot]=b;
}

struct rubbish {
    bool bo[maxn];
    int sta[maxn],top;

    void clear() {
        Siz=res=0;
        while(top) {
            bo[sta[top]]=0;
            cnt[sta[top--]]=0;
        }
    }

    void ins(int id) {
        if(bo[id])return;
        bo[id]=1,sta[++top]=id;
    }
}R;

void find_rt(int fa,int u) {
    int res=0;siz[u]=1;
    for(int p=now[u],v=son[p];p;p=pre[p],v=son[p])
        if(!vis[v]&&v!=fa)find_rt(u,v),siz[u]+=siz[v],res=max(res,siz[v]);
    res=max(res,N-siz[u]);
    if(res<mx)mx=res,rt=u;
}

void dfs(int fa,int u) {
    sum[col[u]]++,res+=(sum[col[u]]==1);
    ans[u]-=res,siz[u]=1;
    for(int p=now[u],v=son[p];p;p=pre[p],v=son[p])
        if(!vis[v]&&v!=fa)dfs(u,v),siz[u]+=siz[v];
    sum[col[u]]--,res-=(sum[col[u]]==0);
}

void query(int fa,int u) {
    sum[col[u]]++;if(sum[col[u]]==1)res-=cnt[col[u]],res+=Siz+1;
    ans[u]+=res;
    for(int p=now[u],v=son[p];p;p=pre[p],v=son[p])
        if(!vis[v]&&v!=fa)query(u,v);
    sum[col[u]]--;if(sum[col[u]]==0)res+=cnt[col[u]],res-=Siz+1;
}

void solve(int fa,int u) {
    sum[col[u]]++;
    if(sum[col[u]]==1) {
        cnt[col[u]]+=siz[u];
        res+=siz[u];R.ins(col[u]);
    }
    for(int p=now[u],v=son[p];p;p=pre[p],v=son[p])
        if(!vis[v]&&v!=fa)solve(u,v);
    sum[col[u]]--;
}

void print() {
    for(int i=1;i<=n;i++)
        printf("%lld ",ans[i]);
    puts("");
}

void work(int u,int size) {
    N=size,mx=rt=n+1,find_rt(0,u);
    u=rt,vis[u]=1,tot=0;
    sum[col[u]]++;res++;
    for(int p=now[u],v=son[p];p;p=pre[p],v=son[p])
        if(!vis[v])V[++tot]=v,dfs(u,v);
    sum[col[u]]--;res--;
    for(int i=1;i<=tot;i++) {
        int v=V[i];
        sum[col[u]]++,res-=cnt[col[u]],res+=Siz+1;
        query(u,v);
        sum[col[u]]--,res+=cnt[col[u]],res-=Siz+1;
        solve(u,v),Siz+=siz[v];
    }R.clear();
    for(int i=tot;i;i--) {
        int v=V[i];
        sum[col[u]]++,res-=cnt[col[u]],res+=Siz+1;
        query(u,v);
        sum[col[u]]--,res+=cnt[col[u]],res-=Siz+1;
        solve(u,v),Siz+=siz[v];     
    }ans[u]+=res+Siz+1-cnt[col[u]];R.clear();
    for(int p=now[u],v=son[p];p;p=pre[p],v=son[p])
        if(!vis[v])work(v,siz[v]);
}

int main() {
    n=read();
    for(int i=1;i<=n;i++)
        col[i]=read();
    for(int i=1;i<n;i++) {
        int a=read(),b=read();
        add(a,b),add(b,a);
    }work(1,n);
    for(int i=1;i<=n;i++)
        printf("%lld\n",ans[i]);
    return 0;
}

不重構樹的邊分治版程式碼如下:

#include <cstdio>
#include <vector>
#include <cstring>
#include <algorithm>
using namespace std;
typedef long long ll;

const int maxn=2e5+5;

bool vis[maxn];
ll ans[maxn],res;
int m,n,tot=1,mx,id,N;
int now[maxn],pre[maxn<<1],son[maxn<<1];
int col[maxn],siz[maxn],cnt[maxn],sum[maxn];

vector<int>to[maxn];
vector<int>::iterator it;

int read() {
    int x=0,f=1;char ch=getchar();
    for(;ch<'0'||ch>'9';ch=getchar())if(ch=='-')f=-1;
    for(;ch>='0'&&ch<='9';ch=getchar())x=x*10+ch-'0';
    return x*f;
}

void add(int a,int b) {
    pre[++tot]=now[a];
    now[a]=tot,son[tot]=b;
}

struct Rubbish {
    bool bo[maxn];
    int sta[maxn],top;

    void clear() {
        res=0;
        while(top)cnt[sta[top]]=bo[sta[top]]=0,top--;
    }
    
    void ins(int id) {
        if(bo[id])return;
        bo[id]=1,sta[++top]=id;
    }
}R;

void find_edge(int fa,int u) {
    siz[u]=1;
    for(int p=now[u],v=son[p];p;p=pre[p],v=son[p])
        if(!vis[p>>1]&&v!=fa) {
            find_edge(u,v),siz[u]+=siz[v];
            if(abs(N-2*siz[v])<mx)
                mx=abs(N-2*siz[v]),id=p>>1;
        }
}

void dfs(int fa,int u) {
    siz[u]=1;
    for(int p=now[u],v=son[p];p;p=pre[p],v=son[p])
        if(!vis[p>>1]&&v!=fa)dfs(u,v),siz[u]+=siz[v];
}

void solve(int fa,int u) {
    sum[col[u]]++;
    if(sum[col[u]]==1) {
        cnt[col[u]]+=siz[u];
        res+=siz[u],R.ins(col[u]);
    }
    for(int p=now[u],v=son[p];p;p=pre[p],v=son[p])
        if(!vis[p>>1]&&v!=fa)solve(u,v);
    sum[col[u]]--;
}

void query(int fa,int u,int num) {
    sum[col[u]]++;
    if(sum[col[u]]==1)res+=num,res-=cnt[col[u]];
    ans[u]+=res;
    for(int p=now[u],v=son[p];p;p=pre[p],v=son[p])
        if(!vis[p>>1]&&v!=fa)query(u,v,num);
    sum[col[u]]--;
    if(sum[col[u]]==0)res-=num,res+=cnt[col[u]];
}

void work(int u,int size) {
    if(size<2)return;
    N=size,mx=id=m+1,find_edge(0,u),vis[id]=1;
    int u1=son[id<<1],u2=son[id<<1|1];
    dfs(0,u1),dfs(0,u2);
    solve(0,u1),query(0,u2,siz[u1]),R.clear();
    solve(0,u2),query(0,u1,siz[u2]),R.clear();
    work(u1,siz[u1]),work(u2,siz[u2]);
}

int main() {
    m=n=read();
    for(int i=1;i<=n;i++)
        col[i]=read();
    for(int i=1;i<n;i++) {
        int a=read(),b=read();
        add(a,b),add(b,a);
    }
    work(1,m);
    for(int i=1;i<=n;i++)printf("%lld\n",ans[i]+1);
    return 0;
}