1. 程式人生 > >【BZOJ】2870: 最長道路tree-邊分治

【BZOJ】2870: 最長道路tree-邊分治

傳送門:bzoj2870


題解

邊分裸題
邊分治學習筆記-litble


程式碼

#include<bits/stdc++.h>
#define pb push_back
using namespace std;
typedef long long ll;
const int N=2e5+10,inf=0x7f7f7f7f;

int n,m,v[N],S,MN,rt,sz[N];
int head[N],to[N<<1],nxt[N<<1],w[N<<1],tot=1;
vector<int>g[N];
bool vs[N];ll ans; char cp; inline void rd(int &x) { cp=getchar();x=0;int f=0; for(;!isdigit(cp);cp=getchar()) if(cp=='-') f=1; for(;isdigit(cp);cp=getchar()) x=x*10+(cp^48); if(f) x=-x; } struct G{ int num; struct P{ int len,v; bool operator <(const P&ky)const{return v>ky.v;}
}a[N]; void dfs(int x,int fr,int dis,int mn) { int i,j;a[++num]=(P){dis,mn}; for(i=head[x];i;i=nxt[i]){ j=to[i];if(vs[i>>1] || j==fr) continue; dfs(j,x,dis+w[i],min(mn,v[j])); } } }A,B; inline void lk(int u,int v,int vv) { to[++tot]=v;nxt[tot]=head[u];head[u]=tot;w[tot]
=vv; to[++tot]=u;nxt[tot]=head[v];head[v]=tot;w[tot]=vv; } void dfs(int x,int fr) { int i,j; for(i=head[x];i;i=nxt[i]){ j=to[i];if(j==fr) continue; g[x].pb(j);dfs(j,x); } } void reb() { int x,i,j,sz,a,b; for(x=1;x<=n;++x){ sz=g[x].size(); if(sz<3){ for(i=0;i<sz;++i){j=g[x][i];lk(x,j,(j<=m));} }else{ a=++n;b=++n;v[a]=v[b]=v[x];lk(x,a,0);lk(x,b,0); for(i=0;i<sz;++i){j=g[x][i];(i&1)?g[b].pb(j):g[a].pb(j);} } } } void fdrt(int x,int fr) { sz[x]=1;int i,j,k; for(i=head[x];i;i=nxt[i]){ j=to[i];if(vs[i>>1] || j==fr) continue; fdrt(j,x);sz[x]+=sz[j]; k=max(sz[j],S-sz[j]); if(k<MN){MN=k;rt=i;} } } void sol(int le) { int i,j,la,lb,mx=0,rv,du=to[le^1],dv=to[le]; vs[le>>1]=true; A.num=B.num=0;A.dfs(du,0,0,v[du]);B.dfs(dv,0,0,v[dv]); la=A.num;lb=B.num; sort(A.a+1,A.a+la+1);sort(B.a+1,B.a+lb+1); for(i=j=1;i<=la;++i){ rv=A.a[i].v; for(;j<=lb && B.a[j].v>=rv;++j) mx=max(mx,B.a[j].len); if(j>1) ans=max(ans,(ll)rv*(mx+A.a[i].len+w[le]+1)); } mx=0; for(i=j=1;i<=lb;++i){ rv=B.a[i].v; for(;j<=la && A.a[j].v>=rv;++j) mx=max(mx,A.a[j].len); if(j>1) ans=max(ans,(ll)rv*(mx+B.a[i].len+w[le]+1)); } la=sz[dv];lb=S-sz[dv]; S=la;MN=N;fdrt(dv,0);if(MN<N) sol(rt); S=lb;MN=N;fdrt(du,0);if(MN<N) sol(rt); } int main(){ int i,x,y;rd(n);m=n; for(i=1;i<=n;++i) rd(v[i]); for(i=1;i<n;++i){rd(x);rd(y);lk(x,y,1);} dfs(1,0);memset(head,0,(n+1)<<2);tot=1;reb(); S=n;MN=N;fdrt(1,0);if(MN<N) sol(rt); printf("%lld",ans); return 0; }