1. 程式人生 > >樹鏈剖分(2)樹剖的較高階應用(P3384 【模板】樹鏈剖分)

樹鏈剖分(2)樹剖的較高階應用(P3384 【模板】樹鏈剖分)

參照洛谷模板 P3384 【模板】樹鏈剖分 題意:

給你一棵包含n個結點的樹,現要求你支援以下操作:

1.x到y結點最短路徑上所有節點的值都加上z;

2.求樹從x到y結點最短路徑上所有節點的值之和;

3.將以x為根節點的子樹內所有節點值都加上z;

4.求以x為根節點的子樹內所有節點值之和。

正確的思路是樹鏈剖分套線段樹。

我們可以求一個dfs序(dfn[]),這個dfs序與之前求top[x]是同步的。所以這樣就保證:

每一條鏈上的點的dfs序是連續的,每一個點的所有子節點的編號分佈在[dfn[x],dfn[x]+siz[x]-1]之間。

這樣我們就可以以每個點的dfn值作為其新下標,開一棵線段樹維護即可。

對於操作1:在樹剖找LCA(x,y)的過程中不斷對零散區間修改即可。

void LCAu(int x,int y,long long k)
{
    while(top[x]!=top[y])
    {
        if(deep[top[x]]<deep[top[y]])	swap(x,y);
        update(1,dfn[top[x]],dfn[x],k);//顯然在一條鏈上的點的dfn值是連續的 
        x=fa[top[x]];
    }
    if(deep[x]<deep[y])	swap(x,y);
    update(1,dfn[y],dfn[x],k);
}

對於操作2:在樹剖找LCA(x,y)的過程中不斷累加零散區間的值即可。

long long LCAq(int x,int y)
{
    long long ans=0;
    while(top[x]!=top[y])
    {
        if(deep[top[x]]<deep[top[y]])	swap(x,y);
        ans=(ans+query(1,dfn[top[x]],dfn[x])+MOD)%MOD;
        x=fa[top[x]];
    }
    if(deep[x]<deep[y])	 swap(x,y);
    ans=(ans+query(1,dfn[y],dfn[x])+MOD)%MOD;
    return ans;
}

對於操作3:修改dfn值屬於[dfn[x],dfn[x]+siz[x]-1]的結點的權值即可。

對於操作4:查詢dfn值屬於[dfn[x],dfn[x]+siz[x]-1]的結點的權值的和即可。

完整程式碼:

#include<cstdio>
#include<iostream>
#define ri register int
using namespace std;

const int MAXN=200020;
int n,m,s,q,u[MAXN],v[MAXN],fst[MAXN],nxt[MAXN],key,xl,xto;
long long MOD,w[MAXN],a[MAXN],num;
int fa[MAXN],deep[MAXN],siz[MAXN],cmax[MAXN],son[MAXN],top[MAXN],cur,dfn[MAXN];
int l[MAXN<<2],r[MAXN<<2];
long long sum[MAXN<<2],len[MAXN<<2],tag[MAXN<<2];

void dfs1(int x,int father,int dep)
{
    fa[x]=father,deep[x]=dep,siz[x]=1;
    for(ri k=fst[x];k>0;k=nxt[k])
        if(v[k]!=father)
        {
            dfs1(v[k],x,dep+1);
            if(siz[v[k]]>cmax[x])	cmax[x]=siz[v[k]],son[x]=v[k];
            siz[x]+=siz[v[k]];
        }
}

void dfs2(int x,int anc)
{
    top[x]=anc,dfn[x]=++cur,a[cur]=w[x];
    if(son[x])	   dfs2(son[x],anc);
    for(ri k=fst[x];k>0;k=nxt[k])
        if(v[k]!=fa[x]&&v[k]!=son[x])	dfs2(v[k],v[k]);
}

void pushup(int p)
{
    sum[p]=(sum[p <<1]+sum[p <<1|1]+MOD)%MOD;
}

void pushdown(int p)
{
    sum[p <<1]=(sum[p <<1]+(len[p <<1]*tag[p])%MOD+MOD)%MOD;
    tag[p <<1]=(tag[p <<1]+tag[p]+MOD)%MOD;
    sum[p <<1|1]=(sum[p <<1|1]+(len[p <<1|1]*tag[p])%MOD+MOD)%MOD;
    tag[p <<1|1]=(tag[p <<1|1]+tag[p]+MOD)%MOD;
    tag[p]=0;
}

void build(int p,int lft,int rit)
{
    l[p]=lft,r[p]=rit;
    if(lft==rit)
    {
        sum[p]=a[lft],len[p]=1;
        return;
    }
    int mid=(lft+rit)>>1;
    build(p <<1,lft,mid);
    build(p <<1|1,mid+1,rit);
    pushup(p);
    len[p]=len[p <<1]+len[p <<1|1];
}

void update(int p,int lft,int rit,long long k)
{
    if(lft<=l[p]&&r[p]<=rit)
    {
        sum[p]=(sum[p]+(len[p]*k)%MOD+MOD)%MOD,tag[p]=(tag[p]+k+MOD)%MOD;
        return;
    }
    pushdown(p);
    if(lft<=r[p <<1])    update(p <<1,lft,rit,k);
    if(l[p <<1|1]<=rit)	   update(p <<1|1,lft,rit,k);
    pushup(p);
}

long long query(int p,int lft,int rit)
{
    if(lft<=l[p]&&r[p]<=rit)	return sum[p];
    long long ans=0;
    pushdown(p);
    if(lft<=r[p <<1])	ans=query(p <<1,lft,rit);
    if(l[p <<1|1]<=rit)	   ans=(ans+query(p <<1|1,lft,rit)+MOD)%MOD;
    return ans;
}

void LCAu(int x,int y,long long k)
{
    while(top[x]!=top[y])
    {
        if(deep[top[x]]<deep[top[y]])	swap(x,y);
        update(1,dfn[top[x]],dfn[x],k);
        x=fa[top[x]];
    }
    if(deep[x]<deep[y])	swap(x,y);
    update(1,dfn[y],dfn[x],k);
}

long long LCAq(int x,int y)
{
    long long ans=0;
    while(top[x]!=top[y])
    {
        if(deep[top[x]]<deep[top[y]])	swap(x,y);
        ans=(ans+query(1,dfn[top[x]],dfn[x])+MOD)%MOD;
        x=fa[top[x]];
    }
    if(deep[x]<deep[y])	 swap(x,y);
    ans=(ans+query(1,dfn[y],dfn[x])+MOD)%MOD;
    return ans;
}

int main()
{
    scanf("%d%d%d%lld",&n,&q,&s,&MOD);
    for(ri i=1;i<=n;i++)	scanf("%lld",&w[i]);
    m=(n-1)<<1;
    for(ri i=1;i<=m;i+=2)
    {
        scanf("%d%d",&u[i],&v[i]);
        nxt[i]=fst[u[i]],fst[u[i]]=i;
        u[i+1]=v[i],v[i+1]=u[i];
        nxt[i+1]=fst[u[i+1]],fst[u[i+1]]=i+1;
    }
    dfs1(s,s,0);
    dfs2(s,s);
    build(1,1,n);
    for(ri i=1;i<=q;i++)
    {
        scanf("%d%d",&key,&xl);
        if(key==1)
        {
            scanf("%d%lld",&xto,&num);
            LCAu(xl,xto,num);
        }	
        if(key==2) 
        {
            scanf("%d",&xto);
            cout<<LCAq(xl,xto)<<'\n';
        }
        if(key==3)
        {
            scanf("%lld",&num);
          	update(1,dfn[xl],dfn[xl]+siz[xl]-1,num);
        }
        if(key==4)	cout<<query(1,dfn[xl],dfn[xl]+siz[xl]-1)<<'\n';
    }
    return 0;
}