1. 程式人生 > >樹上統計treecnt(dsu on tree 並查集 正難則反)

樹上統計treecnt(dsu on tree 並查集 正難則反)

problem freopen space script type 題目 每次 分割 pos

題目鏈接

\(Description\)

給定一棵\(n(n\leq 10^5)\)個點的樹。
定義\(Tree[L,R]\)表示為了使得\(L\sim R\)號點兩兩連通,最少需要選擇的邊的數量。
\[\sum_{l=1}^n\sum_{r=l}^nTree[l,r]\]

\(Solution\)

枚舉每條邊,計算它的貢獻。
那麽我們要判斷有多少連續區間的點跨過這條邊,並不好算,反過來去求在這條邊的兩側分別有多少個連續區間。
那麽顯然有\(O(n^2)\)的做法,即對每條邊DFS它的兩側,枚舉一下每一側的連續區間。
我們還可以DFS這棵樹,對於每個點我們需要計算它子樹內和子樹外的連續區間數。

對於子樹內的點,並查集顯然是可以維護的(每次合並相鄰點成為一個連續區間時新產生的連續區間數可算,就是兩個集合\(size\)的乘積)。
對於子樹外的點怎麽算呢。我們可以先假設子樹外為\(1,2,...,n\),且有這些區間。我們每次在子樹中加入點時,就把它從子樹外的集合的點中刪掉,並計算子樹外少的區間數。
我們發現不需要\(n\)個元素的集合,也不需要這樣刪。在一個初始只有\(0\)\(n+1\)的空集合裏,每次加入子樹內的點,然後計算,也是等價的。
於是我們需要對每個子樹合並並查集及處理set。
可以用dsu on tree,每次先處理完輕兒子的子樹,每次處理完都清空那棵子樹的貢獻/狀態(並查集、set)。最後處理重兒子所在子樹,並保留其狀態。
處理完這整棵子樹後(也就是算完該邊答案後),再把子樹內其它未加入的輕子樹給加入並查集、set。
dsu的復雜度是\(O(n\log n)\)的,再套上別的,復雜度為\(O(n\log^2n)\)

#include <set>
#include <cstdio>
#include <cctype>
#include <algorithm>
//#define gc() getchar()
#define MAXIN 150000
#define gc() (SS==TT&&(TT=(SS=IN)+fread(IN,1,MAXIN,stdin),SS==TT)?EOF:*SS++)
#define Calc(x) (1ll*(x)*(x-1)>>1ll)//區間個數 
typedef long long LL;
const int N=1e5+5;

int n,Enum,H[N],nxt[N<<1],to[N<<1],fa[N],sz[N],son[N],Fa[N],size[N];
bool vis[N];//計算子樹內的連續區間(是否子樹內已存在相鄰的)
LL Ans,sum1,sum2;
std::set<int> st;
char IN[MAXIN],*SS=IN,*TT=IN;

inline int read()
{
    int now=0;register char c=gc();
    for(;!isdigit(c);c=gc());
    for(;isdigit(c);now=now*10+c-'0',c=gc());
    return now;
}
inline void AE(int u,int v)
{
    to[++Enum]=v, nxt[Enum]=H[u], H[u]=Enum;
    to[++Enum]=u, nxt[Enum]=H[v], H[v]=Enum;
}
void DFS1(int x)
{
    int mx=0; sz[x]=1;
    for(int i=H[x],v; i; i=nxt[i])
        if((v=to[i])!=fa[x])
        {
            fa[v]=x, DFS1(v), sz[x]+=sz[v];
            if(sz[v]>mx) mx=sz[v], son[x]=v;
        }
}
int Find(int x)
{
    return x==Fa[x]?x:Fa[x]=Find(Fa[x]);
}
void Upd(int x)
{
    st.insert(x);//我怎麽記得有返回值(iterator)來...
    std::set<int>::iterator it=st.find(x),pre=it,nxt=++it;
    --pre;
    sum2-=Calc(*nxt-*pre-1);//子樹外以前的連續區間被分割 
    sum2+=Calc(x-*pre-1)+Calc(*nxt-x-1);

    vis[x]=1;
    if(vis[x-1])
    {
        int r1=Find(x-1),r2=Find(x);//它們之前顯然不會在一個集合啊(它們之間只會算一次,要麽是加x-1,以前有x;要麽是加x,以前有x-1)。
        sum1+=1ll*size[r1]*size[r2];
        Fa[r1]=r2, size[r2]+=size[r1];
    }
    if(vis[x+1])
    {
        int r1=Find(x+1),r2=Find(x);
        sum1+=1ll*size[r1]*size[r2];
        Fa[r1]=r2, size[r2]+=size[r1];
    }
}
void Clear(int x)
{
    Fa[x]=x, size[x]=1, vis[x]=0;
    for(int i=H[x]; i; i=nxt[i])
        if(to[i]!=fa[x]) Clear(to[i]);
}
void Update(int x)
{
    Upd(x);
    for(int i=H[x]; i; i=nxt[i])
        if(to[i]!=fa[x]) Update(to[i]);
}
void DFS2(int x)
{
    for(int i=H[x],v; i; i=nxt[i])
        if((v=to[i])!=fa[x]&&v!=son[x])
        {
            DFS2(v), Clear(v);
            st.clear(), st.insert(0), st.insert(n+1);
            sum1=0, sum2=Calc(n);//還是在DFS完子樹後就初始化吧 
        }
    if(son[x]) DFS2(son[x]);

    for(int i=H[x],v; i; i=nxt[i])
        if((v=to[i])!=fa[x]&&v!=son[x]) Update(v);

    Upd(x);
    Ans+=Calc(n)-sum1-sum2;
}

int main()
{
    freopen("treecnt.in","r",stdin);
    freopen("treecnt.out","w",stdout);

    n=read();
    for(int i=1; i<n; ++i) AE(read(),read());
    for(int i=1; i<=n; ++i) Fa[i]=i,size[i]=1;//!
    st.insert(0), st.insert(n+1);//邊界.
    sum1=0, sum2=Calc(n), DFS1(1), DFS2(1), printf("%lld\n",Ans);

    return 0;
}

另一種\(O(n^2)\)做法:

#include <set>
#include <cstdio>
#include <cctype>
#include <cstring>
#include <algorithm>
//#define gc() getchar()
#define MAXIN 300000
#define gc() (SS==TT&&(TT=(SS=IN)+fread(IN,1,MAXIN,stdin),SS==TT)?EOF:*SS++)
typedef long long LL;
const int N=1e5+5,INF=0x3f3f3f3f;

int dgr[N],Enum,H[N],nxt[N<<1],to[N<<1],fa[N],dep[N],sz[N],son[N],top[N],dfn[N],ref[N];
char IN[MAXIN],*SS=IN,*TT=IN;

inline int read()
{
    int now=0;register char c=gc();
    for(;!isdigit(c);c=gc());
    for(;isdigit(c);now=now*10+c-'0',c=gc());
    return now;
}
inline void AE(int u,int v)
{
//  ++dgr[u], ++dgr[v];
    to[++Enum]=v, nxt[Enum]=H[u], H[u]=Enum;
    to[++Enum]=u, nxt[Enum]=H[v], H[v]=Enum;
}
inline int LCA(int u,int v)
{
//  printf("LCA(%d,%d)\n",u,v);
    while(top[u]!=top[v]) dep[top[u]]>dep[top[v]]?u=fa[top[u]]:v=fa[top[v]];
    return dep[u]>dep[v]?v:u;
}
inline int Dis(int u,int v)
{
//  printf("(%d,%d) Subd:dep[%d]=%d!\n",u,v,LCA(u,v),dep[LCA(u,v)]);
    return dep[u]+dep[v]-(dep[LCA(u,v)]<<1);
}
void DFS1(int x)
{
    int mx=0; sz[x]=1;
    for(int i=H[x],v; i; i=nxt[i])
        if((v=to[i])!=fa[x])
        {
            fa[v]=x, dep[v]=dep[x]+1, DFS1(v), sz[x]+=sz[v];
            if(sz[v]>mx) mx=sz[v], son[x]=v;
        }
}
void DFS2(int x,int tp)
{
    static int Index=0;

    top[x]=tp, dfn[x]=++Index;
    if(son[x])
    {
        DFS2(son[x],tp);
        for(int i=H[x]; i; i=nxt[i])
            if(to[i]!=fa[x]&&to[i]!=son[x]) DFS2(to[i],to[i]);
    }
}
inline bool cmp_dfn(int i,int j)
{
    return dfn[i]<dfn[j];
}
void Output(int *l,int *r,int n)
{
    printf("\nOutput the list:\n%d",0);
    for(int p=r[0]; p!=n+1; p=r[p]) printf("->%d(%d)",p,l[p]); puts("\n");
}
void Subtask0(int n)
{
    const int N=305;
    static int A[N];

    LL ans=0;
    for(int i=1; i<=n; ++i)
    {
        int t=1; A[1]=i;
        for(int j=i+1; j<=n; ++j)
        {
            A[++t]=j, std::sort(A+1,A+1+t,cmp_dfn);
            for(int k=1; k<t; ++k) ans+=Dis(A[k],A[k+1]);
            ans+=Dis(A[t],A[1]);
        }
    }
    printf("%I64d\n",ans>>1ll);
}
void Subtask1(int n)
{
    const int N=3005;
    static int A[N],SL[N],SR[N],L[N],R[N],id[N],tl[N],tr[N];

    LL ans=0;
    for(int i=1; i<=n; ++i) ans+=1ll*(1ll*(i-1)*(n-i+1)+n-i)*dep[i],id[i]=i;
//  printf("pre_ans=%I64d\n\n",ans);

    std::sort(id+1,id+1+n,cmp_dfn);
    SR[0]=id[1], SL[n+1]=id[n], id[n+1]=n+1;
    for(int i=1; i<=n; ++i) SL[id[i]]=id[i-1], SR[id[i]]=id[i+1];

//  for(int i=1; i<=n; ++i) printf("dfn[%d]=%d\n",i,dfn[i]); puts("");
//  Output(SL,SR,n);

    for(int i=1; i<n; ++i)
    {
        memset(tl,0,sizeof tl), memset(tr,0,sizeof tr);
        memcpy(L,SL,sizeof SL), memcpy(R,SR,sizeof SR);
//      Output(L,R,n);

        for(int j=n,T=1,l,r; j>i; --j,++T)
        {
            int a,b;
            if((l=L[j])==0) a=j, b=L[n+1];
            else a=j, b=l;
            int tmp=std::min(T-tr[b],T-tl[a]);
            ans-=1ll*tmp*dep[LCA(a,b)], tr[b]=T;

            if((r=R[j])==n+1) b=j, a=R[0];
            else b=j, a=r;
            tmp=std::min(T-tr[b],T-tl[a]);
            ans-=1ll*tmp*dep[LCA(a,b)], tl[a]=T;

            R[l]=r, L[r]=l;

//          if((l=L[j])==0) ans-=dep[LCA(j,L[n+1])], printf("Subd:dep[%d]=%d!\n",LCA(j,L[n+1]),dep[LCA(j,L[n+1])]);
//          else ans-=dep[LCA(j,l)], printf("Subd:dep[%d]=%d!\n",LCA(j,l),dep[LCA(j,l)]);
//          if((r=R[j])==n+1) ans-=dep[LCA(j,R[0])], printf("Subd:dep[%d]=%d!\n",LCA(j,R[0]),dep[LCA(j,R[0])]);
//          else ans-=dep[LCA(j,r)], printf("Subd:dep[%d]=%d!\n",LCA(j,r),dep[LCA(j,r)]);
//          R[l]=r, L[r]=l;
        }
        SR[SL[i]]=SR[i], SL[SR[i]]=SL[i];
    }
    printf("%I64d\n",ans);
}
namespace Subtask2
{
//  struct BIT
//  {
//      #define lb(x) (x&-x)
//      int n,mn[N],mx[N];
//  
//      void Init(int nn) {n=nn; memset(mn,0x3f,sizeof mn), memset(mx,0,sizeof mx);}
//      void Modify_Max(int p,int v)
//      {
//          for(; p<=n; p+=lb(p)) mx[p]=std::max(mx[p],v);
//      }
//      void Modify_Min(int p,int v)
//      {
//          for(; p<=n; p+=lb(p)) mn[p]=std::min(mn[p],v);
//      }
//      int Query_Max(int p)
//      {
//          int res=0;
//          for(; p; p^=lb(p)) res=std::max(res,mx[p]);
//          return res;
//      }
//      int Query_Min(int p)
//      {
//          int res=INF;
//          for(; p; p^=lb(p)) res=std::min(res,mn[p]);
//          return res;
//      }
//  }Tp,Ts;
//  int pos[N],ref[N];
//
//  void DFS3(int x,int f,int t)
//  {
////        pos[x]=t, ref[t]=x;
//      pos[t]=x;
//      for(int i=H[x]; i; i=nxt[i])
//          if(to[i]!=f) DFS3(to[i],x,t+1);
//  }
//  bool Main(int n)
//  {
//      for(int i=1; i<=n; ++i) if(dgr[i]>2) return 0;
//      for(int i=1; i<=n; ++i) if(dgr[i]==1) {DFS3(i,i,1); break;}
//
//      LL ans=0;
//      Tp.Init(n), Ts.Init(n);
//      for(int i=n; i; --i)
//      {
//          int p=pos[i],l=Tp.Query_Max(p)+1,r=Ts.Query_Min(n-p+1)-1;
////            if(l==0+1) l=1;
//          if(r==INF-1) r=n;
//          ans+=1ll*i*(p-l+1)*(r-p+1);
//          Tp.Modify_Max(p,p), Ts.Modify_Min(n-p+1,p);
//      }
//  
//      Tp.Init(n), Ts.Init(n);
//      for(int i=1; i<=n; ++i)
//      {
//          int p=pos[i],l=Tp.Query_Max(p)+1,r=Ts.Query_Min(n-p+1)-1;
////            if(l==0+1) l=1;
//          if(r==INF-1) r=n;
//          ans-=1ll*i*(p-l+1)*(r-p+1);
//          Tp.Modify_Max(p,p), Ts.Modify_Min(n-p+1,p);
//      }
//      printf("%I64d\n",ans);
//      return 1;
//  }
}

int main()
{
    freopen("treecnt.in","r",stdin);
    freopen("treecnt.out","w",stdout);

    int n=read();
    for(int i=1; i<n; ++i) AE(read(),read());

//  if(n>3000 && Subtask2::Main(n)) return 0;

    DFS1(1), DFS2(1,1);
//  Subtask0(n); puts("");
    if(n<=300) Subtask0(n);
    else if(n<=3000||1) Subtask1(n);

    return 0;
}

樹上統計treecnt(dsu on tree 並查集 正難則反)