1. 程式人生 > >【bzoj1036】樹的統計Count【樹鏈剖分】【ZKW大法好】【卡常大法好】

【bzoj1036】樹的統計Count【樹鏈剖分】【ZKW大法好】【卡常大法好】

關於這個樹上路徑端點會重合的問題,我們只要不判斷x==y就行了。詳見被註釋呵呵的地方。

#include<cstdio>
#include<cstring>
#include<iostream>
using namespace std;
const int maxn=30001;
typedef int arr[maxn];
typedef int arr1[maxn<<1];
arr fa,top,dep,size,son,idx;
arr list;
arr1 next,to;
struct node{int sum,max;}t[65538
]; int M,z,n,tot; const int inf=30001; node U(const node &a,const node &b){ return (node){ a.sum+b.sum, max(a.max,b.max) }; } void change(int s,int w){ for(t[s].max=t[s].sum=w,s>>=1;s;s>>=1) t[s]=U(t[s<<1],t[s<<1|1]); } int Max(int l,int r){ int
lans=-inf,rans=-inf; for(l+=M-1,r+=M+1;l^r^1;l>>=1,r>>=1){ if(~l&1) lans=max(lans,t[l^1].max); if( r&1) rans=max(rans,t[r^1].max); } return max(lans,rans); } int Sum(int l,int r){ int ans=0; for(l+=M-1,r+=M+1;l^r^1;l>>=1,r>>=1){ if
(~l&1) ans+=t[l^1].sum; if( r&1) ans+=t[r^1].sum; } return ans; } void dfs1(int x){ size[x]=1;son[x]=0; for(int k=list[x];k;k=next[k]) if(to[k]!=fa[x]){ fa[to[k]]=x; dep[to[k]]=dep[x]+1; dfs1(to[k]); size[x]+=size[to[k]]; if(size[to[k]]>size[son[x]]) son[x]=to[k]; } } void dfs2(int x,int tp){ top[x]=tp; idx[x]=++z; if(son[x]) dfs2(son[x],tp); for(int k=list[x];k;k=next[k]) if(to[k]!=fa[x]&&to[k]!=son[x]) dfs2(to[k],to[k]); } int findSum(int x,int y){ int ans=0,tpx=top[x],tpy=top[y]; while(tpx!=tpy){ if(dep[tpx]<dep[tpy]) swap(tpx,tpy),swap(x,y); ans+=Sum(idx[tpx],idx[x]); x=fa[tpx]; tpx=top[x]; } if(dep[x]<dep[y]) swap(x,y);//呵呵 return ans+Sum(idx[y],idx[x]); } int findMax(int x,int y){ int ans=-inf,tpx=top[x],tpy=top[y]; while(tpx!=tpy){ if(dep[tpx]<dep[tpy]) swap(tpx,tpy),swap(x,y); ans=max(ans,Max(idx[tpx],idx[x])); x=fa[tpx]; tpx=top[x]; } if(dep[x]<dep[y]) swap(x,y);//呵呵 return max(ans,Max(idx[y],idx[x])); } inline int read(){ int x=0; char ch=getchar(); bool f=0; while(!isdigit(ch)){if(ch=='-') f=1;ch=getchar();} while(isdigit(ch)) x=x*10+ch-48,ch=getchar(); return f?-x:x; } inline void add(int a,int b){ next[++tot]=list[a]; list[a]=tot; to[tot]=b; } void init(){ n=read(); M=1;while(M<=n) M<<=1; int x,y; for(int i=1;i<n;++i){ x=read();y=read(); add(x,y);add(y,x); } dfs1(1);dfs2(1,1); for(int i=1;i<=n;++i){ x=read(); t[idx[i]+M]=(node){x,x}; } for(int i=M-1;i;--i) t[i]=U(t[i<<1],t[i<<1|1]); } int a[10]; inline void print(int x){ if(x<0) putchar('-'),x=-x; if(x) a[0]=0; else a[1]=0,a[0]=1; while(x) a[++a[0]]=x%10,x/=10; for(;a[0];--a[0]) putchar(a[a[0]]+48); putchar('\n'); } void work(){ char cmd[10]; int q=read(),u,v; while(q--){ scanf("%s",cmd); if(cmd[1]=='M'){ u=read();v=read(); print(findMax(u,v)); } else if(cmd[1]=='S'){ u=read();v=read(); print(findSum(u,v)); } else{ u=read();v=read(); change(idx[u]+M,v); } } } int main(){ init(); work(); return 0; }