CF161D Distance in Tree (點分治)
阿新 • • 發佈:2018-12-09
這是一道點分治的模板題：求樹上距離恰好為 k 的無序點對數量。
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;

// Codeforces 161D "Distance in Tree": count unordered pairs of vertices whose
// distance is exactly k, via centroid decomposition.
//
// For every centroid c:
//   + pairs (x, y) in c's component with dep(x) + dep(y) == k   (dep from c)
//   - pairs lying inside the same child subtree of c, counted with the child
//     at depth 1, because their real path does not pass through c.

const int MAXN = 50010;

// Forward-star adjacency list: edge e goes to v[e]; fir[x] is the first edge
// leaving x, nxt[e] the next edge with the same source (0 terminates).
int v[MAXN << 1], fir[MAXN], nxt[MAXN << 1], cnt;
// Centroid search: sz[x] = subtree size in the current traversal, f[x] = size
// of the largest piece left if x were removed, root = best candidate so far,
// Siz = number of vertices in the component being searched.
int root, sz[MAXN], f[MAXN], Siz;
// Depths collected by getdis() for the component currently being processed.
int middis[MAXN], midcnt;
int n, k, vis[MAXN];
long long ans = 0;

// Append the directed edge ui -> vi to the adjacency list.
void addedge(int ui, int vi) {
    ++cnt;
    v[cnt] = vi;
    nxt[cnt] = fir[ui];
    fir[ui] = cnt;
}

// Walk the unvisited component containing x (entered from fa), filling sz/f
// and updating `root` with the vertex minimising its largest remaining piece.
// Requires Siz to be the exact size of that component and root/f[root] to be
// initialised by the caller. Recursion depth equals the component's depth.
void getroot(int x, int fa) {
    sz[x] = 1;
    f[x] = 0;
    for (int i = fir[x]; i; i = nxt[i]) {
        int y = v[i];
        if (y == fa || vis[y]) continue;
        getroot(y, x);
        sz[x] += sz[y];
        f[x] = max(f[x], sz[y]);
    }
    // The piece "above" x is everything in the component outside x's subtree.
    f[x] = max(f[x], Siz - sz[x]);
    if (f[x] < f[root]) root = x;
}

// Append to middis[] the depth of every unvisited vertex reachable from x,
// where x itself has depth d and fa is the vertex we came from.
void getdis(int x, int d, int fa) {
    middis[++midcnt] = d;
    for (int i = fir[x]; i; i = nxt[i]) {
        int y = v[i];
        if (vis[y] || y == fa) continue;
        getdis(y, d + 1, x);
    }
}

// Sort middis[1..midcnt] in place and return the number of index pairs
// i < j with middis[i] + middis[j] == k.
long long solve() {
    sort(middis + 1, middis + midcnt + 1);
    long long pairs = 0;
    // After sorting, the smaller member of a valid pair satisfies 2*d <= k,
    // so the outer scan can stop once that no longer holds.
    for (int i = 1; i < midcnt && 2 * middis[i] <= k; ++i) {
        pair<int*, int*> match =
            equal_range(middis + i + 1, middis + midcnt + 1, k - middis[i]);
        pairs += match.second - match.first;
    }
    return pairs;
}

// Process centroid c: count paths through c, then recurse into each child
// component with its own centroid.
void divide(int c) {
    vis[c] = 1;
    midcnt = 0;
    getdis(c, 0, 0);
    ans += solve();
    for (int i = fir[c]; i; i = nxt[i]) {
        int y = v[i];
        if (vis[y]) continue;
        // Remove pairs that both lie in y's subtree (their path skips c).
        midcnt = 0;
        getdis(y, 1, 0);
        ans -= solve();
        // getdis just visited exactly y's component, so midcnt is its true
        // size; sz[y] from the earlier getroot may be relative to another root.
        Siz = midcnt;
        root = 0;
        getroot(y, c);
        divide(root);
    }
}

int main() {
    scanf("%d %d", &n, &k);
    for (int i = 1; i <= n - 1; ++i) {
        int a, b;
        scanf("%d %d", &a, &b);
        addedge(a, b);
        addedge(b, a);
    }
    Siz = n;
    f[0] = 0x3f3f3f3f;  // sentinel so any real vertex beats root == 0
    root = 0;
    getroot(1, 0);
    divide(root);
    printf("%lld", ans);
    return 0;
}