1. 程式人生 > >bzoj 4675 點對遊戲 - 長鏈剖分

bzoj 4675 點對遊戲 - 長鏈剖分

題解:根據期望的線性性,可知答案是 ( n 2 k

2 ) ( n k )
1 u < v n
d i s t ( u , v ) M \frac{\binom{n-2}{k-2}}{\binom nk}\sum_{1\le u<v\le n}\mathrm{dist}(u,v)\in M 。後半部分直接長鏈剖分即可。

#include<bits/stdc++.h>
#define rep(i,a,b) for(int i=a;i<=b;i++)
#define Rep(i,v) rep(i,0,(int)v.size()-1)
#define lint long long
#define ull unsigned lint
#define db long double
#define pb push_back
#define mp make_pair
#define fir first
#define sec second
#define gc getchar()
#define N 50010
#define debug(x) cerr<<#x<<"="<<x
#define sp <<" "
#define ln <<endl
using namespace std;
typedef pair<int,int> pii;
typedef set<int>::iterator sit;
inline int inn()
{
    int x,ch;while((ch=gc)<'0'||ch>'9');
    x=ch^'0';while((ch=gc)>='0'&&ch<='9')
        x=(x<<1)+(x<<3)+(ch^'0');return x;
}
struct edges{
    int to,pre;
}e[N<<1];int h[N],etop,l[N],son[N];
inline int add_edge(int u,int v) { return e[++etop].to=v,e[etop].pre=h[u],h[u]=etop; }
int getl(int x,int fa=0)
{
    l[x]=0,son[x]=0;
    for(int i=h[x],y;i;i=e[i].pre)
        if((y=e[i].to)^fa)
        {
            getl(y,x),l[x]=max(l[x],l[y]+1);
            if(!son[x]||l[y]>l[son[x]]) son[x]=y;
        }
//  debug(x)sp,debug(l[x])sp,debug(son[x])ln;
    return 0;
}
int m,v[20];lint ans;
int dfs(int x,int *f,int fa=0)
{
    if(son[x]) dfs(son[x],f+1,x);else return f[0]=1;
    f[0]++;rep(i,1,m) if(v[i]<=l[x]) ans+=f[v[i]];
    for(int i=h[x],y;i;i=e[i].pre)
        if((y=e[i].to)!=son[x]&&e[i].to!=fa)
        {
            int *fy=new int[l[y]+1];
            memset(fy,0,sizeof(int)*(l[y]+1)),dfs(y,fy,x);
            for(int j=0;j<=l[y];j++) for(int k=1;k<=m;k++)
                if(v[k]-j-1>=0&&v[k]-j-1<=l[x]) ans+=(lint)f[v[k]-j-1]*fy[j];
            for(int j=0;j<=l[y];j++) f[j+1]+=fy[j];
        }
//  debug(x)ln;rep(i,0,l[x]) debug(i)sp,debug(f[i])ln;cerr ln;
    return 0;
}
int main()
{
    int n=inn(),x,y;m=inn();rep(i,1,m) v[i]=inn();
    rep(i,1,n-1) x=inn(),y=inn(),add_edge(x,y),add_edge(y,x);
    int k=n/3,k1=k,k2=k,k3=k;if(n%3>=1) k1++;if(n%3==2) k2++;
    getl(1);int *f=new int[l[1]+1];
    memset(f,0,sizeof(int)*(l[1]+1)),dfs(1,f);
    printf("%.2lf\n",(double)(k1*(k1-(db)1)/n/(n-(db)1)*ans));
    printf("%.2lf\n",(double)(k2*(k2-(db)1)/n/(n-(db)1)*ans));
    printf("%.2lf\n",(double)(k3*(k3-(db)1)/n/(n-(db)1)*ans));
    return 0;
}