洛谷 P2590 [ZJOI2008]樹的統計(樹鏈剖分+線段樹)
阿新 • • 發佈:2018-12-10
題目描述
一棵樹上有n個節點,編號分別為1到n,每個節點都有一個權值w。
我們將以下面的形式來要求你對這棵樹完成一些操作:
I. CHANGE u t : 把結點u的權值改為t
II. QMAX u v: 詢問從點u到點v的路徑上的節點的最大權值
III. QSUM u v: 詢問從點u到點v的路徑上的節點的權值和
注意:從點u到點v的路徑上的節點包括u和v本身
輸入輸出格式
輸入格式:
輸入檔案的第一行為一個整數n,表示節點的個數。
接下來n – 1行,每行2個整數a和b,表示節點a和節點b之間有一條邊相連。
接下來一行n個整數,第i個整數wi表示節點i的權值。
接下來1行,為一個整數q,表示操作的總數。
接下來q行,每行一個操作,以“CHANGE u t”或者“QMAX u v”或者“QSUM u v”的形式給出。
輸出格式:
對於每個“QMAX”或者“QSUM”的操作,每行輸出一個整數表示要求輸出的結果。
題解:樹鏈剖分的模板題,第一遍dfs確認輕重孩子,第二遍dfs拉出輕重鏈,每一條輕重鏈便成了一條條的序列,然後用線段樹維護區間最大值和區間和即可。
#include<bits/stdc++.h> using namespace std; const int maxn=1e6+7; #define lson l,m,rt<<1 #define rson m+1,r,rt<<1|1 int _dfs[maxn],son[maxn],far[maxn],siz[maxn],sum[maxn]; int dep[maxn],tot,top[maxn],n,a[maxn],id[maxn],tree[maxn]; vector<int>G[maxn]; void dfs1(int u,int fa,int depth) { far[u]=fa; siz[u]=1; dep[u]=depth; int sz=G[u].size(); for(int i=0;i<sz;i++) { int v=G[u][i]; if(v==fa) continue; dfs1(v,u,depth+1); siz[u]+=siz[v]; if(siz[v]>siz[son[u]]) son[u]=v; } } void dfs2(int x,int t) { _dfs[x]=++tot; top[x]=t; id[tot]=x; if(son[x]) dfs2(son[x],t); int sz=G[x].size(); for(int i=0;i<sz;i++) { int v=G[x][i]; if(v!=far[x] && v!=son[x]) dfs2(v,v); } } void build(int l=1,int r=n,int rt=1) { if(l==r) { tree[rt]=a[id[l]]; sum[rt]=a[id[l]]; return; } int m=(l+r)>>1; build(lson);build(rson); sum[rt]=sum[rt<<1]+sum[rt<<1|1]; tree[rt]=max(tree[rt<<1],tree[rt<<1|1]); } int querys(int L,int R,int l=1,int r=n,int rt=1) { if(l>=L && r<=R) return sum[rt]; int m=(l+r)>>1,ans=0; if(L<=m) ans+=querys(L,R,lson); if(R>m) ans+=querys(L,R,rson); return ans; } int querym(int L,int R,int l=1,int r=n,int rt=1) { if(l>=L && r<=R) return tree[rt]; int m=(l+r)>>1,ans=-0x3f3f3f3f; if(L<=m) ans=max(ans,querym(L,R,lson)); if(R>m) ans=max(ans,querym(L,R,rson)); return ans; } void update(int o,int v,int l=1,int r=n,int rt=1) { if(l==r) { tree[rt]=sum[rt]=v; return; } int m=(l+r)>>1; if(o<=m) update(o,v,lson); else update(o,v,rson); sum[rt]=sum[rt<<1]+sum[rt<<1|1]; tree[rt]=max(tree[rt<<1],tree[rt<<1|1]); } int calm(int u,int v) { int ans=-0x3f3f3f3f; while(top[u]!=top[v]) { if(dep[top[u]]<dep[top[v]]) swap(u,v); ans=max(ans,querym(_dfs[top[u]],_dfs[u])); u=far[top[u]]; } if(dep[u]>dep[v]) swap(u,v); ans=max(ans,querym(_dfs[u],_dfs[v])); return ans; } int cals(int u,int v) { int ans=0; while(top[u]!=top[v]) { if(dep[top[u]]<dep[top[v]]) swap(u,v); ans+=querys(_dfs[top[u]],_dfs[u]); u=far[top[u]]; } if(dep[u]>dep[v]) swap(u,v); ans+=querys(_dfs[u],_dfs[v]); return ans; } int main() { scanf("%d",&n); for(int i=0;i<n-1;i++) { int u,v; scanf("%d%d",&u,&v); G[u].push_back(v); G[v].push_back(u); } for(int i=1;i<=n;i++) scanf("%d",&a[i]); dfs1(1,1,0);dfs2(1,1); build(); int q;scanf("%d",&q); while(q--) { char s[15]={0}; scanf("%s",s); int u,v; scanf("%d%d",&u,&v); if(s[0]=='C') update(_dfs[u],v); else if(s[1]=='M') { int ans=calm(u,v); printf("%d\n",ans); } else { int ans=cals(u,v); printf("%d\n",ans); } } return 0; }