1. 程式人生 > >倍增求LCA模板

倍增求LCA模板

1.引入

2.思路

這道題目是倍增求LCA的模板題。

首先,大家都知道LCA的定義吧?(兩個節點的公共父節點)如果我們求兩個點的LCA的使用暴力求解(DFS找出要求點的深度,一個一個往上跳,一次一次查詢),在卡時間的競賽中是肯定會炸掉的。那麼,我們就使用另一種方法,樹上倍增法:

我們設father[x,k] 表示 x 的 2^k 倍祖先,那麼很容易知道,father[x,0]就是當前節點的父親(記住,當前節點可以代表當前深度的所有節點因為它是一顆樹!),father[x,1]就是當前節點的父親的父親也就是fatehr[father[x,0],0] ,  father[x,2]就是當前節點的父親的父親的父親,也就是father[father[father[x,0],0],0]也等於father[father[x,1],1]...(以此類推).那麼,我們就可以計算出當前節點到它所有的父親要走的路(因為它是一棵樹)。片段是這樣的:

inline void dfs(int now,int fath)
{
	depht[now]=depht[fath]+1;
	father[now][0]=fath;
	for(register int i=1;(1<<i)<=depht[now];++i)
	  father[now][i]=father[father[now][i-1]][i-1];//求出當前節點到各個祖先節點的距離。
	for(register int i=head[now];i;i=e[i].nex)
	{
		if(e[i].t!=fath)//要求的這一條邊不能通往父親節點
        dfs(e[i].t,now);//求出指向當前節點的子節點到各個祖先節點的距離(有點繞)
	}	
}

基於father陣列我們可以計算LCA了。

我們先設 depth[x] 和 depth[y] 為當前節點的深度,那麼,基於二進位制拆分的思想,把x,y調到同一深度。

之後,我們又運用二進位制拆分的思想,讓他們一起走到同一個點。(嘗試走2^(log(depth[x]-depth[y])(向下取整))步,2^(log(depth[x]-depth[y]-1)(向下取整))步....1步)

不說了,上程式碼:

#include<bits/stdc++.h>
using namespace std;
struct node{
	int t,nex;
}e[500001<<1];
int depht[500001],father[500001][22],lg[500001],head[500001];
int tot;
inline void add(int x,int y)
{
	e[++tot].t=y;
	e[tot].nex=head[x];
	head[x]=tot;
}
inline void dfs(int now,int fath)
{
	depht[now]=depht[fath]+1;
	father[now][0]=fath;
	for(register int i=1;(1<<i)<=depht[now];++i)
	  father[now][i]=father[father[now][i-1]][i-1];
	for(register int i=head[now];i;i=e[i].nex)
	{
		if(e[i].t!=fath)dfs(e[i].t,now);	
	}	
}
inline int lca(int x,int y)
{
	if(depht[x]<depht[y])
	  swap(x,y);
	while(depht[x]>depht[y])
	  x=father[x][lg[depht[x]-depht[y]]-1];
	if(x==y)
	  return x;
	for(register int k=lg[depht[x]];k>=0;--k)
	  if(father[x][k]!=father[y][k])
	    x=father[x][k],y=father[y][k];
	return father[x][0];
}
int n,m,s;
int main()
{
	//freopen("1.txt","r",stdin);
	scanf("%d%d%d",&n,&m,&s);
	for(register int i=1;i<=n-1;++i)
	{
		int x,y;scanf("%d%d",&x,&y);
		add(x,y);add(y,x);
	}
	dfs(s,0);
	for(register int i=1;i<=n;++i)
	  lg[i]=lg[i-1]+(1<<lg[i-1]==i);
	for(register int i=1;i<=m;++i)
	{
		int x,y;scanf("%d%d",&x,&y);
		printf("%d\n",lca(x,y));
	}
	return 0;
}