COT2 Count on a tree II【樹上莫隊】
Time limit 1207 ms Memory limit 1572864 kB
You are given a tree with N nodes. The tree nodes are numbered from 1 to N. Each node has an integer weight.
We will ask you to perform the following operation:
u v : ask how many distinct integers appear among the weights of the nodes on the path from u to v.
Input
In the first line there are two integers N and M. (N <= 40000, M <= 100000)
In the second line there are N integers. The i-th integer denotes the weight of the i-th node.
In the next N-1 lines, each line contains two integers u v, which describes an edge (u, v).
In the next M lines, each line contains two integers u v, which denotes an operation asking how many distinct integers appear among the weights of the nodes on the path from u to v.
Output
For each operation, print its result.
題目分析
#include<iostream>
#include<cmath>
#include<algorithm>
#include<queue>
#include<cstring>
#include<cstdio>
using namespace std;
typedef long long lt;
// Reads the next decimal integer from stdin.
// Every non-digit character before the number is skipped; if any of the
// skipped characters is '-', the result is negated.
// Returns 0 when the end of input is reached before any digit is seen,
// instead of looping forever on EOF.
int read()
{
    int sign=1,value=0;
    // Keep getchar()'s result as int so EOF stays distinguishable from data.
    int ch=getchar();
    while(ch!=EOF&&(ch<'0'||ch>'9'))
    {
        if(ch=='-') sign=-1;
        ch=getchar();
    }
    // EOF is outside '0'..'9', so this loop also terminates at end of input.
    while(ch>='0'&&ch<='9')
    {
        value=value*10+(ch-'0');
        ch=getchar();
    }
    return sign*value;
}
// Capacity shared by every array below.
const int maxn=100010;

int n;  // number of tree nodes
int m;  // number of queries
int t;  // block width used by cmp to bucket query left endpoints

// One record of the linked adjacency lists.
struct edge
{
    int v;    // node this record points to
    int nxt;  // index of the next record of the same source node, 0 = end
};
edge E[maxn<<1];   // two records per undirected edge
int head[maxn];    // head[u]: index of u's first record, 0 = no edges
int tot;           // number of records allocated so far

int a[maxn];    // node weights; main replaces them by their compressed rank
int b[maxn];    // copy of the weights that main sorts for compression
int pot[maxn];  // pot[k]: node recorded at tour position k (written by dfs1)
int num;        // last tour position handed out by dfs1

int pos[maxn];  // pos[1..pos[0]]: sorted distinct weights, pos[0] is their count
int cnt[maxn];  // cnt[x]: number of currently counted nodes with compressed weight x
int L=1;        // left end of the current tour window
int R;          // right end of the current tour window (window starts empty: [1,0])

int fa[maxn];   // parent of each node (0 for the root)
int son[maxn];  // child with the largest subtree, 0 if none
// NOTE: together with `using namespace std;` the unqualified name `size`
// also matches std::size when compiling as C++17.
int size[maxn]; // subtree sizes
int dep[maxn];  // depth of each node
int top[maxn];  // topmost node of the heavy chain containing the node
int st[maxn];   // tour position assigned when dfs1 enters the node
int ed[maxn];   // tour position assigned when dfs1 leaves the node

// One offline query expressed as a window over the tour.
struct node
{
    int ll;   // left end of the window
    int rr;   // right end of the window
    int lca;  // extra node to toggle in for the answer, 0 = none
    int num;  // position of the query in the input, used to index ans
};
node q[maxn];

// Ordering used to sort the queries: first by the block (width t) that holds
// the left endpoint, then by right endpoint inside the same block.
bool cmp(node a,node b)
{
    const int blockA=a.ll/t;
    const int blockB=b.ll/t;
    if(blockA!=blockB) return blockA<blockB;
    return a.rr<b.rr;
}

int rem[maxn];  // rem[u]: 1 while node u is currently counted, else 0
int ans[maxn];  // answers indexed by original query number
int res;        // number of distinct compressed weights currently counted
// Prepends a directed edge u -> v to the adjacency list of u.
void add(int u,int v)
{
    const int id=++tot;      // allocate the next edge record (records start at 1)
    E[id].v=v;
    E[id].nxt=head[u];       // old first record becomes the successor
    head[u]=id;
}
void dfs1(int u,int pa)
{
size[u]=1; st[u]=++num; pot[num]=u;
for(int i=head[u];i;i=E[i].nxt)
{
int v=E[i].v;
if(v==pa) continue;
fa[v]=u; dep[v]=dep[u]+1;
dfs1(v,u);
size[u]+=size[v];
if(size[v]>size[son[u]]) son[u]=v;
}
ed[u]=++num; pot[num]=u;
}
// Second traversal: assigns top[u], the topmost node of the chain u lies on.
//   u  - node being visited
//   tp - top of the chain that u belongs to
// The heavy child (son[u]) continues the current chain; every other child
// starts a new chain with itself as top.
void dfs2(int u,int tp)
{
    top[u]=tp;
    const int heavy=son[u];
    if(heavy!=0) dfs2(heavy,tp);
    for(int e=head[u];e!=0;e=E[e].nxt)
    {
        const int child=E[e].v;
        // Skip the edge back to the parent and the already handled heavy child.
        if(child!=fa[u]&&child!=heavy) dfs2(child,child);
    }
}
// Returns the lowest common ancestor of u and v using the chain tops
// computed by dfs2: while the two nodes sit on different chains, the one
// whose chain top is deeper jumps to the parent of that top.
int LCA(int u,int v)
{
    for(;;)
    {
        const int tu=top[u];
        const int tv=top[v];
        if(tu==tv) break;
        if(dep[tu]>dep[tv]) u=fa[tu];
        else v=fa[tv];
    }
    // Same chain: the shallower node is the ancestor.
    if(dep[u]<dep[v]) return u;
    return v;
}
// Counts one more occurrence of compressed weight x; a weight seen for the
// first time increases the distinct total res.
void add(int x)
{
    cnt[x]+=1;
    if(cnt[x]==1) res+=1;
}
// Removes one occurrence of compressed weight x; when none is left the
// distinct total res decreases.
void del(int x)
{
    cnt[x]-=1;
    if(cnt[x]==0) res-=1;
}
// Toggles node u: a node that is currently counted is removed, an uncounted
// one is added. rem[u] is flipped on every call, whichever branch runs.
void cal(int u)
{
    const bool counted=rem[u]!=0;
    rem[u]^=1;
    if(counted) del(a[u]);
    else add(a[u]);
}
int main()
{
n=read();m=read();
for(int i=1;i<=n;++i)
a[i]=b[i]=read();
for(int i=1;i<n;++i)
{
int u=read(),v=read();
add(u,v); add(v,u);//cout<<"hh"<<endl;
}
sort(b+1,b+1+n);
for(int i=1;i<=n;++i)
if(i==1||b[i]!=b[i-1])
pos[++pos[0]]=b[i];
for(int i=1;i<=n;++i)
a[i]=lower_bound(pos+1,pos+1+pos[0],a[i])-pos;
dep[1]=1;
dfs1(1,0); dfs2(1,1);
for(int i=1;i<=m;++i)
{
int u=read(),v=read();
if(st[u]>st[v]) swap(u,v);
int lca=LCA(u,v); q[i].num=i;
if(lca==u){ q[i].ll=st[u]; q[i].rr=st[v]; q[i].lca=0;}
else{ q[i].ll=ed[u]; q[i].rr=st[v]; q[i].lca=lca;}
}
t=sqrt(n*2);
sort(q+1,q+1+m,cmp);
for(int i=1;i<=m;++i)
{
while(R<q[i].rr) cal(pot[++R]);
while(R>q[i].rr) cal(pot[R--]);
while(L<q[i].ll) cal(pot[L++]);
while(L>q[i].ll) cal(pot[--L]);
if(q[i].lca!=0) cal(q[i].lca);
ans[q[i].num]=res;
if(q[i].lca!=0) cal(q[i].lca);
}
for(int i=1;i<=m;++i)
printf("%d\n",ans[i]);
return 0;
}