程式人生 > Educational Codeforces Round 2 E - Lomsat gelral

Educational Codeforces Round 2 E - Lomsat gelral

題意:給定一棵以 1 為根的樹,每個節點有一個值(顏色);對每個節點,求其子樹中出現次數最多的所有值(眾數)之和。
題解:可用線段樹合併。對每個節點建一棵以值為下標的線段樹,維護每個值的出現次數、區間內的最大出現次數,以及出現次數等於該最大值的所有值之和;後序遍歷時把子節點的線段樹合併進父節點,合併後根節點的和即為該節點的答案。

//#pragma GCC optimize(2)
//#pragma GCC optimize(3)
//#pragma GCC optimize(4)
//#pragma GCC optimize("unroll-loops")
//#pragma comment(linker, "/stack:200000000")
//#pragma GCC optimize("Ofast,no-stack-protector")
//#pragma GCC target("sse,sse2,sse3,ssse3,sse4,popcnt,abm,mmx,avx,tune=native")
#include<bits/stdc++.h>
// ---- Contest-template shorthand macros (many are unused in this program) ----
#define fi first
#define se second
#define db double
#define mp make_pair
#define pb push_back
#define pi acos(-1.0)
#define ll long long
#define vi vector<int>
#define mod 1000000007
#define ld long double
//#define C 0.5772156649
//#define ls l,m,rt<<1
//#define rs m+1,r,rt<<1|1
#define pll pair<ll,ll>
#define pil pair<int,ll>
#define pli pair<ll,int>
#define pii pair<int,int>
//#define cd complex<double>
#define ull unsigned long long
//#define base 1000000000000000000
// Max/Min evaluate their arguments twice; prefer MAX/MIN below for side effects.
#define Max(a,b) ((a)>(b)?(a):(b))
#define Min(a,b) ((a)<(b)?(a):(b))
// NOTE(review): fout redirects stdout to the same a.txt that fin reads from;
// presumably a separate output file was intended — confirm before using.
#define fin freopen("a.txt","r",stdin)
#define fout freopen("a.txt","w",stdout)
#define fio ios::sync_with_stdio(false);cin.tie(0)

// ---- Small numeric helpers ----

// Greatest common divisor by Euclid's algorithm; gcd(a, 0) == a.
inline ll gcd(ll a,ll b){return b?gcd(b,a%b):a;}
// In-place modular subtraction a = (a - b) mod `mod`.
// Assumes a and b are already in [0, mod).
inline void sub(ll &a,ll b){a-=b;if(a<0)a+=mod;}
// In-place modular addition; same precondition as sub().
inline void add(ll &a,ll b){a+=b;if(a>=mod)a-=mod;}
// Larger / smaller of two values (each argument evaluated once).
template<typename T>inline T const& MAX(T const &a,T const &b){return a>b?a:b;}
template<typename T>inline T const& MIN(T const &a,T const &b){return a<b?a:b;}
// Binary exponentiation: returns a^b mod c in [0, c).
// Preconditions: b >= 0, c >= 1. The base is first reduced into [0, c) so
// a*a cannot overflow for large inputs and negative bases give a
// non-negative result. Products are formed in long long, so c must stay
// below about 3.03e9.
inline ll qp(ll a,ll b,ll c)
{
    a%=c;
    if(a<0)a+=c;
    ll ans=1%c;            // 1%c makes x^0 mod 1 == 0
    while(b)
    {
        if(b&1)ans=ans*a%c;
        a=a*a%c;
        b>>=1;
    }
    return ans;
}
// a^b modulo 1e9+7 (see the three-argument overload for preconditions).
inline ll qp(ll a,ll b){return qp(a,b,mod);}

using namespace std;

const double eps=1e-8;
const ll INF=0x3f3f3f3f3f3f3f3f;
const int N=100000+10,maxn=50000+10,inf=0x3f3f3f3f;

// Adjacency lists of the tree.
vi v[N];
// root[u]: root node of vertex u's value-indexed segment tree. It is only
// indexed by vertices 1..n, so N slots suffice (N*22 wasted ~8 MB).
int root[N];
// Node pool shared by all segment trees (tot = last used node). Each
// build() allocates at most about log2(n)+1 nodes and Merge() allocates
// none, so N*22 nodes cover n up to ~1e5. Node 0 acts as the empty tree.
int ls[N*22],rs[N*22],tot,n;
// ans[u]: answer for vertex u. Per tree node: num = occurrence count (at a
// leaf), ma = highest count in the node's value range, sum = sum of the
// values whose count equals ma.
ll ans[N],sum[N*22],num[N*22],ma[N*22];
void pushup(int o)
{
    ma[o]=max(ma[ls[o]],ma[rs[o]]);
    if(ma[ls[o]]==ma[rs[o]])sum[o]=sum[ls[o]]+sum[rs[o]];
    else if(ma[ls[o]]>ma[rs[o]])sum[o]=sum[ls[o]];
    else sum[o]=sum[rs[o]];
}
// Merge the segment tree rooted at y into the one rooted at x (both cover
// the value range [l, r]) and return the root of the merged tree.
// Node 0 is the empty tree. The merge is destructive: x's nodes are updated
// in place and y's subtrees may become shared with the result, so y's tree
// should be treated as consumed afterwards.
inline int Merge(int x,int y,int l,int r)
{
    // An empty side means the other tree already is the result. Testing this
    // before the leaf case ensures the empty sentinel node 0 is never written
    // (a leaf merge of two empty trees must not touch sum[0]).
    if(!x||!y)return x+y;
    if(l==r)
    {
        // Value l occurs in both trees: add the occurrence counts.
        num[x]+=num[y];
        ma[x]=num[x];
        sum[x]=l;
        return x;
    }
    int m=(l+r)>>1;
    ls[x]=Merge(ls[x],ls[y],l,m);
    rs[x]=Merge(rs[x],rs[y],m+1,r);
    pushup(x);
    return x;
}
// Walk from node o (allocating it from the pool if it is still 0) down to
// the leaf for value pos inside [l, r], set that leaf's count to 1, and
// refresh the aggregates on the way back up.
void build(int &o,int pos,int l,int r)
{
    if(o==0)o=++tot;
    if(l<r)
    {
        const int mid=(l+r)>>1;
        if(pos>mid)build(rs[o],pos,mid+1,r);
        else build(ls[o],pos,l,mid);
        pushup(o);
        return ;
    }
    // Leaf: exactly one occurrence of value l.
    num[o]=1;
    ma[o]=1;
    sum[o]=l;
}
// Debug helper: print every node of the tree rooted at o (covering [l, r])
// in preorder, one line per node: "node+++sum num ma l r".
void debug(int o,int l,int r)
{
    // sum/num/ma are long long, so they need %lld; passing them for %d is
    // undefined behaviour and prints garbage on common ABIs.
    printf("%d+++%lld %lld %lld %d %d\n",o,sum[o],num[o],ma[o],l,r);
    if(l==r)return ;
    int m=(l+r)>>1;
    if(ls[o])debug(ls[o],l,m);
    if(rs[o])debug(rs[o],m+1,r);
}
void dfs(int u,int f)
{
    for(int i=0;i<v[u].size();i++)
    {
        int x=v[u][i];
        if(x!=f)dfs(x,u);
    }
    for(int i=0;i<v[u].size();i++)
    {
        int x=v[u][i];
        if(x!=f)root[u]=Merge(root[u],root[x],1,n);
    }
    ans[u]=sum[root[u]];
}
// Read n and each vertex's value, give every vertex a one-leaf tree, read the
// n-1 undirected edges, solve from root 1 and print all answers on one line.
int main()
{
    scanf("%d",&n);
    int i=1;
    while(i<=n)
    {
        int val;
        scanf("%d",&val);
        build(root[i],val,1,n);
        ++i;
    }
    for(int e=1;e<n;++e)
    {
        int a,b;
        scanf("%d%d",&a,&b);
        v[a].pb(b);
        v[b].pb(a);
    }
    dfs(1,-1);
    for(int u=1;u<=n;++u)
        printf("%lld ",ans[u]);
    puts("");
    return 0;
}
/********************
4
1 2 2 4
1 2
2 3
2 4
********************/

也可用 dsu on tree(樹上啟發式合併):先求出子樹大小以確定每個節點的重兒子。處理節點 u 時,先遞迴處理各輕兒子,統計完其答案後刪除其子樹資訊;再遞迴處理重兒子並保留其子樹資訊;最後把各輕兒子的子樹及 u 本身加回,即得 u 的答案。過程中維護每個值的出現次數、每個出現次數對應的值之和,以及目前的最大出現次數。

//#pragma GCC optimize(2)
//#pragma GCC optimize(3)
//#pragma GCC optimize(4)
//#pragma GCC optimize("unroll-loops")
//#pragma comment(linker, "/stack:200000000")
//#pragma GCC optimize("Ofast,no-stack-protector")
//#pragma GCC target("sse,sse2,sse3,ssse3,sse4,popcnt,abm,mmx,avx,tune=native")
#include<bits/stdc++.h>
// ---- Contest-template shorthand macros (many are unused in this program) ----
#define fi first
#define se second
#define db double
#define mp make_pair
#define pb push_back
#define pi acos(-1.0)
#define ll long long
#define vi vector<int>
#define mod 1000000007
#define ld long double
//#define C 0.5772156649
//#define ls l,m,rt<<1
//#define rs m+1,r,rt<<1|1
#define pll pair<ll,ll>
#define pil pair<int,ll>
#define pli pair<ll,int>
#define pii pair<int,int>
//#define cd complex<double>
#define ull unsigned long long
//#define base 1000000000000000000
// Max/Min evaluate their arguments twice; prefer MAX/MIN below for side effects.
#define Max(a,b) ((a)>(b)?(a):(b))
#define Min(a,b) ((a)<(b)?(a):(b))
// NOTE(review): fout redirects stdout to the same a.txt that fin reads from;
// presumably a separate output file was intended — confirm before using.
#define fin freopen("a.txt","r",stdin)
#define fout freopen("a.txt","w",stdout)
#define fio ios::sync_with_stdio(false);cin.tie(0)

// ---- Small numeric helpers ----

// Greatest common divisor by Euclid's algorithm; gcd(a, 0) == a.
inline ll gcd(ll a,ll b){return b?gcd(b,a%b):a;}
// In-place modular subtraction a = (a - b) mod `mod`.
// Assumes a and b are already in [0, mod).
inline void sub(ll &a,ll b){a-=b;if(a<0)a+=mod;}
// In-place modular addition; same precondition as sub().
inline void add(ll &a,ll b){a+=b;if(a>=mod)a-=mod;}
// Larger / smaller of two values (each argument evaluated once).
template<typename T>inline T const& MAX(T const &a,T const &b){return a>b?a:b;}
template<typename T>inline T const& MIN(T const &a,T const &b){return a<b?a:b;}
// Binary exponentiation: returns a^b mod c in [0, c).
// Preconditions: b >= 0, c >= 1. The base is first reduced into [0, c) so
// a*a cannot overflow for large inputs and negative bases give a
// non-negative result. Products are formed in long long, so c must stay
// below about 3.03e9.
inline ll qp(ll a,ll b,ll c)
{
    a%=c;
    if(a<0)a+=c;
    ll ans=1%c;            // 1%c makes x^0 mod 1 == 0
    while(b)
    {
        if(b&1)ans=ans*a%c;
        a=a*a%c;
        b>>=1;
    }
    return ans;
}
// a^b modulo 1e9+7 (see the three-argument overload for preconditions).
inline ll qp(ll a,ll b){return qp(a,b,mod);}

using namespace std;

const double eps=1e-8;
const ll INF=0x3f3f3f3f3f3f3f3f;
// N bounds the vertex count; the vertex values are used as indices into
// num[], so they are presumably also within [1, N) — confirm against input limits.
const int N=100000+10,maxn=50000+10,inf=0x3f3f3f3f;

// Adjacency lists of the tree.
vi v[N];
// c: value of each vertex; sz: subtree size; l/r: first/last preorder index
// of the subtree; id: preorder index -> vertex; cnt: preorder counter.
int c[N],sz[N],l[N],r[N],id[N],cnt,n;
// num[x]: occurrences of value x in the current set;
// sum[k]: sum of the values that currently occur exactly k times;
// maxx: current highest occurrence count; ans[u]: answer for vertex u.
ll num[N],sum[N],maxx,ans[N];
// Preprocessing pass from u (f = parent): assign preorder indices so that the
// subtree of u is exactly id[l[u]..r[u]], and compute subtree sizes sz[].
void dfs(int u,int f)
{
    l[u]=++cnt;
    id[cnt]=u;
    sz[u]=1;
    for(size_t k=0;k<v[u].size();++k)
    {
        const int child=v[u][k];
        if(child==f)continue;
        dfs(child,u);
        sz[u]+=sz[child];
    }
    r[u]=cnt;
}
void solve(int u,int f,bool keep)
{
    int ma=-1,son=-1;
    for(int i=0;i<v[u].size();i++)
    {
        int x=v[u][i];
        if(x!=f&&ma<sz[x])ma=sz[x],son=x;
    }
    for(int i=0;i<v[u].size();i++)
    {
        int x=v[u][i];
        if(x!=f&&x!=son)solve(x,u,0);
    }
    if(son!=-1)solve(son,u,1);
    for(int i=0;i<v[u].size();i++)
    {
        int x=v[u][i];
        if(x!=f&&x!=son)for(int j=l[x];j<=r[x];j++)
        {
            sum[num[c[id[j]]]]-=c[id[j]];
            num[c[id[j]]]++;
            sum[num[c[id[j]]]]+=c[id[j]];
            maxx=max(maxx,num[c[id[j]]]);
        }
    }
    sum[num[c[u]]]-=c[u];
    num[c[u]]++;
    sum[num[c[u]]]+=c[u];
    maxx=max(maxx,num[c[u]]);
    ans[u]=sum[maxx];
    if(!keep)for(int i=l[u];i<=r[u];i++)
    {
        sum[num[c[id[i]]]]-=c[id[i]];
        num[c[id[i]]]--;
        sum[num[c[id[i]]]]+=c[id[i]];
        if(sum[maxx]==0)maxx--;
    }
}
// Read the vertex values and the n-1 undirected edges, preprocess preorder
// ranges and subtree sizes, run DSU on tree from root 1, and print every
// answer on one line.
int main()
{
    scanf("%d",&n);
    for(int i=1;i<=n;++i)scanf("%d",&c[i]);
    // Every value 1..n starts in the "occurs 0 times" bucket.
    for(int i=1;i<=n;++i)sum[0]+=i;
    for(int e=1;e<n;++e)
    {
        int a,b;
        scanf("%d%d",&a,&b);
        v[a].pb(b);
        v[b].pb(a);
    }
    dfs(1,-1);
    solve(1,-1,0);
    for(int u=1;u<=n;++u)
        printf("%lld ",ans[u]);
    puts("");
    return 0;
}
/********************
3
1 2 3
1 2
1 3
********************/