1. 程式人生 > >bzoj 3862: Little Devil I (樹鏈剖分+線段樹)

bzoj 3862: Little Devil I (樹鏈剖分+線段樹)

題目描述

傳送門

題目大意:給出一棵n個點的樹,有三種操作。
操作1:把x,y路徑上所有邊反色
操作2:把x,y路徑上所有相鄰的邊反色,即一個點在路徑上
操作3:詢問x,y路徑上黑邊的個數。
注意剛開始的時候所有邊均為白色。

題解

操作1,3都是基本的操作,關鍵就是2.
對於每個點維護這個點的輕兒子是否要反色。每次修改的時候直接區間修改即可。兩條重鏈相連的輕邊需要特判。路徑的頂點到他父親之間的邊也需要特判。
每次計算答案的時候,鏈頂的顏色都要結合上父親的輕兒子是否反色來計算。

程式碼

#include<iostream>
#include<cstdio>
#include<cstring> #include<algorithm> #include<cmath> #define N 200003 using namespace std; int point[N],v[N],nxt[N],deep[N],fa[N],size[N],son[N],belong[N]; int pos[N],q[N],cnt,h[N]; int tot,n,m,T; struct data{ int w,b,delta,rev; }tr[N*4]; void add(int x,int y) { tot++; nxt[tot]=point[x
]; point[x]=tot; v[tot]=y; tot++; nxt[tot]=point[y]; point[y]=tot; v[tot]=x; } void dfs(int x,int f) { size[x]=1; son[x]=0; deep[x]=deep[f]+1; for (int i=point[x];i;i=nxt[i]){ if (v[i]==f) continue; fa[v[i]]=x; dfs(v[i],x); size[x]+=size[v[i]]; if
(size[v[i]]>size[son[x]]) son[x]=v[i]; } } void dfs1(int x,int chain) { pos[x]=++cnt; q[cnt]=x; belong[x]=chain; if (!son[x]) return; dfs1(son[x],chain); for (int i=point[x];i;i=nxt[i]) if (v[i]!=son[x]&&v[i]!=fa[x]) dfs1(v[i],v[i]); } void update(data &now,data l,data r) { now.w=l.w+r.w; now.b=l.b+r.b; } void clear(int now) { tr[now].b=tr[now].w=tr[now].rev=tr[now].delta=0; } void build(int now,int l,int r) { clear(now); if (l==r) { tr[now].b=0; if (l!=1) tr[now].w=1; return; } int mid=(l+r)/2; build(now<<1,l,mid); build(now<<1|1,mid+1,r); update(tr[now],tr[now<<1],tr[now<<1|1]); } void change(int now) { swap(tr[now].w,tr[now].b); tr[now].delta^=1; } void pushdown(int now) { if (tr[now].delta){ change(now<<1); change(now<<1|1); tr[now].delta=0; } if (tr[now].rev) { tr[now<<1].rev^=1; tr[now<<1|1].rev^=1; tr[now].rev=0; } } data query(int now,int l,int r,int ll,int rr) { if (ll<=l&&r<=rr) return tr[now]; pushdown(now); int mid=(l+r)/2; data ans; bool pd=false; if (ll<=mid) ans=query(now<<1,l,mid,ll,rr),pd=true; if (rr>mid) { if (pd) update(ans,ans,query(now<<1|1,mid+1,r,ll,rr)); else ans=query(now<<1|1,mid+1,r,ll,rr); } return ans; } void qjchange(int now,int l,int r,int ll,int rr) { if (ll<=l&&r<=rr) { change(now); return; } int mid=(l+r)/2; pushdown(now); if (ll<=mid) qjchange(now<<1,l,mid,ll,rr); if (rr>mid) qjchange(now<<1|1,mid+1,r,ll,rr); update(tr[now],tr[now<<1],tr[now<<1|1]); } void reverse(int now,int l,int r,int ll,int rr) { if (ll<=l&&r<=rr) { tr[now].rev^=1; return; } pushdown(now); int mid=(l+r)/2; if (ll<=mid) reverse(now<<1,l,mid,ll,rr); if (rr>mid) reverse(now<<1|1,mid+1,r,ll,rr); update(tr[now],tr[now<<1],tr[now<<1|1]); } void solve(int x,int y) { while (belong[x]!=belong[y]) { if (deep[belong[x]]<deep[belong[y]]) swap(x,y); qjchange(1,1,n,pos[belong[x]],pos[x]); x=fa[belong[x]]; } if (deep[x]>deep[y]) swap(x,y); if (x==y) return; qjchange(1,1,n,pos[x]+1,pos[y]); } void paint(int x,int y) { bool pd=false; int t=0; while (belong[x]!=belong[y]){ if (deep[belong[x]]<deep[belong[y]]) swap(x,y); reverse(1,1,n,pos[belong[x]],pos[x]); if (son[x]) qjchange(1,1,n,pos[son[x]],pos[son[x]]); qjchange(1,1,n,pos[belong[x]],pos[belong[x]]); t=belong[x]; x=fa[belong[x]]; h[x]=t; } if (deep[x]>deep[y]) swap(x,y); reverse(1,1,n,pos[x],pos[y]); qjchange(1,1,n,pos[x],pos[x]); if (son[y]) qjchange(1,1,n,pos[son[y]],pos[son[y]]); } int find(int now,int l,int r,int x) { if (x==0) return 0; if (l==r) return tr[now].rev; int mid=(l+r)/2; pushdown(now); if (x<=mid) return find(now<<1,l,mid,x); else return find(now<<1|1,mid+1,r,x); } int calc(int x,int y) { int ans=0; while (belong[x]!=belong[y]) { if (deep[belong[x]]<deep[belong[y]]) swap(x,y); data t=query(1,1,n,pos[belong[x]],pos[x]); data t1=query(1,1,n,pos[belong[x]],pos[belong[x]]); int k=find(1,1,n,pos[fa[belong[x]]]); if (t1.b==1&&k) t.b--; if (t1.w==1&&k) t.b++; ans+=t.b; x=fa[belong[x]]; } if (deep[x]>deep[y]) swap(x,y); if (x==y) return ans; data t=query(1,1,n,pos[x]+1,pos[y]); ans+=t.b; return ans; } int read() { char c; int x=0; c=getchar(); while (c>'9'||c<'0') c=getchar(); while (c<='9'&&c>='0') x=x*10+c-'0',c=getchar(); return x; } int main() { freopen("a.in","r",stdin); freopen("my.out","w",stdout); scanf("%d",&T); while (T--){ tot=0; cnt=0; memset(point,0,sizeof(point)); scanf("%d",&n); for (int i=1;i<n;i++) { int x,y; x=read(); y=read(); add(x,y); } dfs(1,0); dfs1(1,1); build(1,1,n); scanf("%d",&m); for (int i=1;i<=m;i++) { //cout<<i<<endl; int opt,x,y; opt=read(); x=read(); y=read(); if (opt==1) solve(x,y); if (opt==2) paint(x,y); if (opt==3) printf("%d\n",calc(x,y)); } } }