樹鏈剖分(2)樹剖的較高階應用(P3384 【模板】樹鏈剖分)
阿新 • • 發佈:2018-12-16
參照洛谷模板 P3384 【模板】樹鏈剖分 題意:
給你一棵包含n個結點的樹,現要求你支援以下操作:
1.x到y結點最短路徑上所有節點的值都加上z;
2.求樹從x到y結點最短路徑上所有節點的值之和;
3.將以x為根節點的子樹內所有節點值都加上z;
4.求以x為根節點的子樹內所有節點值之和。
正確的思路是樹鏈剖分套線段樹。
我們可以求一個dfs序(dfn[]),這個dfs序與之前求top[x]是同步的。所以這樣就保證:
每一條鏈上的點的dfs序是連續的,每一個點的所有子節點的編號分佈在[dfn[x],dfn[x]+siz[x]-1]之間。
這樣我們就可以以每個點的dfn值作為其新下標,開一棵線段樹維護即可。
對於操作1:在樹剖找LCA(x,y)的過程中不斷對零散區間修改即可。
void LCAu(int x,int y,long long k) { while(top[x]!=top[y]) { if(deep[top[x]]<deep[top[y]]) swap(x,y); update(1,dfn[top[x]],dfn[x],k);//顯然在一條鏈上的點的dfn值是連續的 x=fa[top[x]]; } if(deep[x]<deep[y]) swap(x,y); update(1,dfn[y],dfn[x],k); }
對於操作2:在樹剖找LCA(x,y)的過程中不斷累加零散區間的值即可。
long long LCAq(int x,int y) { long long ans=0; while(top[x]!=top[y]) { if(deep[top[x]]<deep[top[y]]) swap(x,y); ans=(ans+query(1,dfn[top[x]],dfn[x])+MOD)%MOD; x=fa[top[x]]; } if(deep[x]<deep[y]) swap(x,y); ans=(ans+query(1,dfn[y],dfn[x])+MOD)%MOD; return ans; }
對於操作3:修改dfn值屬於[dfn[x],dfn[x]+siz[x]-1]的結點的權值即可。
對於操作4:查詢dfn值屬於[dfn[x],dfn[x]+siz[x]-1]的結點的權值的和即可。
完整程式碼:
#include<cstdio>
#include<iostream>
#define ri register int
using namespace std;
const int MAXN=200020;
int n,m,s,q,u[MAXN],v[MAXN],fst[MAXN],nxt[MAXN],key,xl,xto;
long long MOD,w[MAXN],a[MAXN],num;
int fa[MAXN],deep[MAXN],siz[MAXN],cmax[MAXN],son[MAXN],top[MAXN],cur,dfn[MAXN];
int l[MAXN<<2],r[MAXN<<2];
long long sum[MAXN<<2],len[MAXN<<2],tag[MAXN<<2];
void dfs1(int x,int father,int dep)
{
fa[x]=father,deep[x]=dep,siz[x]=1;
for(ri k=fst[x];k>0;k=nxt[k])
if(v[k]!=father)
{
dfs1(v[k],x,dep+1);
if(siz[v[k]]>cmax[x]) cmax[x]=siz[v[k]],son[x]=v[k];
siz[x]+=siz[v[k]];
}
}
void dfs2(int x,int anc)
{
top[x]=anc,dfn[x]=++cur,a[cur]=w[x];
if(son[x]) dfs2(son[x],anc);
for(ri k=fst[x];k>0;k=nxt[k])
if(v[k]!=fa[x]&&v[k]!=son[x]) dfs2(v[k],v[k]);
}
void pushup(int p)
{
sum[p]=(sum[p <<1]+sum[p <<1|1]+MOD)%MOD;
}
void pushdown(int p)
{
sum[p <<1]=(sum[p <<1]+(len[p <<1]*tag[p])%MOD+MOD)%MOD;
tag[p <<1]=(tag[p <<1]+tag[p]+MOD)%MOD;
sum[p <<1|1]=(sum[p <<1|1]+(len[p <<1|1]*tag[p])%MOD+MOD)%MOD;
tag[p <<1|1]=(tag[p <<1|1]+tag[p]+MOD)%MOD;
tag[p]=0;
}
void build(int p,int lft,int rit)
{
l[p]=lft,r[p]=rit;
if(lft==rit)
{
sum[p]=a[lft],len[p]=1;
return;
}
int mid=(lft+rit)>>1;
build(p <<1,lft,mid);
build(p <<1|1,mid+1,rit);
pushup(p);
len[p]=len[p <<1]+len[p <<1|1];
}
void update(int p,int lft,int rit,long long k)
{
if(lft<=l[p]&&r[p]<=rit)
{
sum[p]=(sum[p]+(len[p]*k)%MOD+MOD)%MOD,tag[p]=(tag[p]+k+MOD)%MOD;
return;
}
pushdown(p);
if(lft<=r[p <<1]) update(p <<1,lft,rit,k);
if(l[p <<1|1]<=rit) update(p <<1|1,lft,rit,k);
pushup(p);
}
long long query(int p,int lft,int rit)
{
if(lft<=l[p]&&r[p]<=rit) return sum[p];
long long ans=0;
pushdown(p);
if(lft<=r[p <<1]) ans=query(p <<1,lft,rit);
if(l[p <<1|1]<=rit) ans=(ans+query(p <<1|1,lft,rit)+MOD)%MOD;
return ans;
}
void LCAu(int x,int y,long long k)
{
while(top[x]!=top[y])
{
if(deep[top[x]]<deep[top[y]]) swap(x,y);
update(1,dfn[top[x]],dfn[x],k);
x=fa[top[x]];
}
if(deep[x]<deep[y]) swap(x,y);
update(1,dfn[y],dfn[x],k);
}
long long LCAq(int x,int y)
{
long long ans=0;
while(top[x]!=top[y])
{
if(deep[top[x]]<deep[top[y]]) swap(x,y);
ans=(ans+query(1,dfn[top[x]],dfn[x])+MOD)%MOD;
x=fa[top[x]];
}
if(deep[x]<deep[y]) swap(x,y);
ans=(ans+query(1,dfn[y],dfn[x])+MOD)%MOD;
return ans;
}
int main()
{
scanf("%d%d%d%lld",&n,&q,&s,&MOD);
for(ri i=1;i<=n;i++) scanf("%lld",&w[i]);
m=(n-1)<<1;
for(ri i=1;i<=m;i+=2)
{
scanf("%d%d",&u[i],&v[i]);
nxt[i]=fst[u[i]],fst[u[i]]=i;
u[i+1]=v[i],v[i+1]=u[i];
nxt[i+1]=fst[u[i+1]],fst[u[i+1]]=i+1;
}
dfs1(s,s,0);
dfs2(s,s);
build(1,1,n);
for(ri i=1;i<=q;i++)
{
scanf("%d%d",&key,&xl);
if(key==1)
{
scanf("%d%lld",&xto,&num);
LCAu(xl,xto,num);
}
if(key==2)
{
scanf("%d",&xto);
cout<<LCAq(xl,xto)<<'\n';
}
if(key==3)
{
scanf("%lld",&num);
update(1,dfn[xl],dfn[xl]+siz[xl]-1,num);
}
if(key==4) cout<<query(1,dfn[xl],dfn[xl]+siz[xl]-1)<<'\n';
}
return 0;
}