【NIOIP2016提高】天天愛跑步(LCA+樹上差分)
近幾年複賽最難的樹上問題了。
幾個月前做是參照題解的方法,用了可持久化線段樹在樹上無腦維護和統計。
當時的做法早已忘記,於是回過來自己做了做,其實遠沒有那麼難做,只要發現一些奇妙的性質。
對於一個玩家s->t,如圖。
對於圖中a點的觀察員存在這樣一個式子:w(a)=dep(a)-dep(lca)+dep(s)-dep(lca)。
對於圖中b點的觀察員存在這樣一個式子:w(b)=(dep(s)-dep(lca))-(dep(b)-dep(lca))。
轉化一下,對於a,有:w(a)-dep(a)=dep(s)-2*dep(lca)。也就是對於lca往下走的路徑上的點滿足這個性質。
對於b,有:w(b)+dep(b)=dep(s)。也就是對於s往上走的路徑上的點滿足這個性質。
於是看似很麻煩的統計問題被化簡了:等式左端只與點i本身有關,右端對於每個玩家是一個定值,“時間”的影響被去掉了。
於是統計的話,就是s->t這條路徑上所有滿足上式的a,b。
可以想見,統計答案就是走到一個點i,然後統計目前已有多少個“w(i)-dep(i)”,以及多少個“w(i)+dep(i)”。
利用差分的思想,對於從lca往下走的路徑,在t處將“dep(s)-2*dep(lca)”的計數加一,在lca處將其計數減一,那麼這一段路上遍
歷到每個節點時,滿足上述式子的加上對應的計數即可。
同理,對於從s往上走的路徑,在s處將“dep(s)”的計數加一,在lca處將其計數減一。
開三個變長陣列,一個名為work,用於存所有從v往lca走的“dep(s)-2*dep(lca)”以及lca,一個名為work2,用於存所有從s往上走
到lca的“dep(s)”以及lca。
最後一個名為del,用於存當前節點需要把哪一個值的計數減一。
以work2為例,我們後序遍歷整棵樹,假如現在遍歷到了i,我們先遍歷它的一個兒子j,然後當兒子j遞歸回來後,i的答案加
上“w(i)+dep(i)”和“w(i)-dep(i)”遍歷j前後計數之差。
然後將i的所有del進行操作,並把i的del清空。接下來遍歷下一個兒子。進行類似的操作。
當所有的兒子遍歷完,再來對i的work2操作,每次依然是加上對應計數前後之差,並且每操作一個work2,就在對應的lca的del加
入本次操作的值,等回到lca時就執行del操作。
work與此完全一致,之所以要分開統計,是為了避免“w(i)-dep(i)”與“w(i)+dep(i)”相同的情況。
特殊情況:若lca滿足”w(lca)+dep(lca)=dep(s)“和“w(lca)-dep(lca)=dep(s)-2*dep(lca)”中的一個,那麼必然兩個都滿足,所以在加入
work和work2時要特判減一。
注:此程式碼用了樹剖求LCA,理論上可以節約時間空間。
#include<vector>
#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
using namespace std;
const int MAXN=300005;
int N,M;
int np=0;
int w[MAXN];
int fa[MAXN];
int son[MAXN];
int top[MAXN];
int dep[MAXN];
int size[MAXN];
int last[MAXN];
int ans[MAXN];
int cnt[MAXN<<2];
struct edge{
int to,pre;
}E[MAXN<<1];
struct data{
int v,to;
};
vector<data>work[MAXN];
vector<data>work2[MAXN];
vector<int>del[MAXN];
char c,num[20];int ct;
void scan(int &x){
for(c=getchar();c<'0'||c>'9';c=getchar());
for(x=0;c>='0'&&c<='9';c=getchar())x=x*10+c-'0';
}
void print(int x){
ct=0;
if(!x)num[ct++]='0';
while(x){num[ct++]=x%10+'0',x/=10;}
while(ct--)putchar(num[ct]);
putchar(' ');
}
void add(int u,int v){
E[++np]=(edge){v,last[u]};
last[u]=np;
}
void dfs1(int x){
size[x]=1;
for(int p=last[x];p;p=E[p].pre){
int j=E[p].to;
if(j==fa[x])continue;
dep[j]=dep[x]+1; fa[j]=x;
dfs1(j); size[x]+=size[j];
if(size[j]>size[son[x]])son[x]=j;
}
}
void dfs2(int x,int tp){
top[x]=tp;
if(!son[x])return;
dfs2(son[x],tp);
for(int p=last[x];p;p=E[p].pre){
int j=E[p].to;
if(j==fa[x]||j==son[x])continue;
dfs2(j,j);
}
}
int LCA(int u,int v){
while(top[u]!=top[v]){
if(dep[top[u]]<dep[top[v]])swap(u,v);
u=fa[top[u]];
}
if(dep[u]>dep[v])swap(u,v);
return u;
}
void calc(int x,int f){
int sz;
int ct1=cnt[w[x]-dep[x]+MAXN];//ct1是對應計數上一次的數量
for(int p=last[x];p;p=E[p].pre){
int j=E[p].to;
if(j==f)continue;
calc(j,x);//先遍歷
ans[x]+=cnt[w[x]-dep[x]+MAXN]-ct1;//加上前後之差
sz=del[x].size();//執行del操作
for(int i=0;i<sz;i++)
cnt[del[x][i]+MAXN]--;
del[x].clear();//清空del陣列
ct1=cnt[w[x]-dep[x]+MAXN];//更新ct1
}
sz=work[x].size();//執行work操作,同時新增del操作
for(int i=0;i<sz;i++){
cnt[work[x][i].v+MAXN]++;
ans[x]+=cnt[w[x]-dep[x]+MAXN]-ct1;
del[work[x][i].to].push_back(work[x][i].v);//新增del操作
int sz2=del[x].size();//依然要執行del操作
for(int i=0;i<sz2;i++)
cnt[del[x][i]+MAXN]--;
del[x].clear();
ct1=cnt[w[x]-dep[x]+MAXN];
}
work[x].clear();
}void calc2(int x,int f){//完全相同,只是計算種類有區別
int sz;
int ct2=cnt[w[x]+dep[x]+MAXN];
for(int p=last[x];p;p=E[p].pre){
int j=E[p].to;
if(j==f)continue;
calc2(j,x);
ans[x]+=cnt[w[x]+dep[x]+MAXN]-ct2;
sz=del[x].size();
for(int i=0;i<sz;i++)
cnt[del[x][i]+MAXN]--;
del[x].clear();
ct2=cnt[w[x]+dep[x]+MAXN];
}
sz=work2[x].size();
for(int i=0;i<sz;i++){
cnt[work2[x][i].v+MAXN]++;
ans[x]+=cnt[w[x]+dep[x]+MAXN]-ct2;
del[work2[x][i].to].push_back(work2[x][i].v);
int sz2=del[x].size();
for(int i=0;i<sz2;i++)
cnt[del[x][i]+MAXN]--;
del[x].clear();
ct2=cnt[w[x]+dep[x]+MAXN];
}
work2[x].clear();
}
int main(){
scan(N);scan(M);
for(int u,v,i=1;i<N;i++){
scan(u);scan(v);
add(u,v);add(v,u);
}
for(int i=1;i<=N;i++)scan(w[i]);
dep[1]=1;dfs1(1);dfs2(1,1); //樹剖求LCA優化
for(int u,v,lca,i=1;i<=M;i++){
scan(u);scan(v);lca=LCA(u,v);
work[v].push_back((data){dep[u]-2*dep[lca],lca});//從v走向lca
work2[u].push_back((data){dep[u],lca});//從u走向lca
if(dep[u]==w[lca]+dep[lca])ans[lca]-=1;//特判
}
calc(1,0); calc2(1,0);//分別計數
for(int i=1;i<=N;i++)print(ans[i]);
return 0;
}