1. 程式人生 > >【LCA模板】

【LCA模板】

描述

給定一棵 n 個點的樹,Q個詢問,每次詢問點 x 到點 y兩點之間的距離。

輸入

第一行一個正整數 n ,表示這棵樹有 n個節點;

接下來 n−1行,每行兩個整數 x,y表示 x,y之間有一條連邊;

然後一個整數 Q,表示有Q個詢問;

接下來Q 行每行兩個整數x,y 表示詢問 x 到 y 的距離。

輸出

輸出 Q 行,每行表示每個詢問的答案。

樣例輸入[複製]

6
1 2
1 3
2 4
2 5
3 6
2
2 6
5 6

樣例輸出[複製]

3
4

這裡說三個方法:

①RMQ求LCA是O(n logn) 預處理然後O(1)求lca,空間也要帶一個log。雖然查詢O(1)很爽,但是如果n太大的話有可能預處理就炸了,也有可能空間不夠。

主要流程是:

先dfs整個樹,得到三個序列——尤拉序列【t】,深度序列【dep】 和 結點第一次出現的時間序列【pos】

【尤拉序列】:dfs到某個點時記一次,回溯到這個點時也記一次。那麼顯然長度就是2n【n為點數】

對這個尤拉序列做一個RMQ,記錄區間深度最小值【dp陣列存的是尤拉序列的下標,dep陣列存的是尤拉序列對應點的深度】

然後查詢u和v的lca時,就找到在dep[pos[u]]~dep[pos[v]]【如果pos[u]>pos[v]就把u和v交換一下】中深度最小的那個點就行了。

程式碼:

#include<bits/stdc++.h>
using namespace std;
const int maxn=101000;
const int maxm=202000;
int cnt,tot,N,u,v,Q,x,y;
int Head[maxn],Next[maxn],V[maxn];
int dist[maxn],dis[maxn],pos[maxn],vis[maxn];
int dp[maxm][20],dep[maxm],t[maxm];

void read(int &x){
	x=0;char ch=getchar();
	while(ch>'9'||ch<'0') ch=getchar();
	while(ch>='0'&&ch<='9') x=(x<<3)+(x<<1)+ch-'0',ch=getchar();
}

void Add(int u,int v){
	++cnt;
	Next[cnt]=Head[u];
	V[cnt]=v;
	Head[u]=cnt;
}

void dfs(int u,int de){
	int i,v,w;
	if(!vis[u]) vis[u]=1,pos[u]=tot;
	dep[tot]=de,t[tot]=u,tot++;
	for(i=Head[u];i!=-1;i=Next[i]){
		v=V[i];
		if(vis[v]) continue;
		dfs(v,de+1);
		dep[tot]=de,t[tot]=u,tot++;
	}
}

void RMQ(){
	for(int j=0;(1<<j)<tot;j++){
		for(int i=0;i+(1<<j) <tot;i++){
			if(j==0) dp[i][j]=i;
			else{
				if(dep[dp[i][j-1]]<dep[dp[i+(1<<(j-1))][j-1]])
					dp[i][j]=dp[i][j-1];
				else
					dp[i][j]=dp[i+(1<<(j-1))][j-1];
			}
		}
	}
}
int Query(int p1,int p2){
	int k=log2(p2-p1+1);
	if(dep[dp[p1][k]]<dep[dp[p2-(1<<k)+1][k]])
		return t[dp[p1][k]];
	return t[dp[p2-(1<<k)+1][k]];
}

int lca(int v1,int v2){
	if(pos[v1]>pos[v2]) return Query(pos[v2],pos[v1]);
	return Query(pos[v1],pos[v2]);
}
void init(){
	cnt=tot=0;
	memset(Head,-1,sizeof(Head));
	memset(Next,-1,sizeof(Next));
}
int main(){
	init();
	read(N);
	for(int i=1;i<N;++i) read(u),read(v),Add(u,v);
	read(Q);
	dfs(1,0),RMQ();
	for(int op=1;op<=Q;++op){
		read(x),read(y);
		printf("%d\n",dep[pos[x]]+dep[pos[y]]-2*dep[pos[lca(x,y)]]);
	}
}

②倍增求LCA是O((n+q)logn),查詢是logn的。寫起來比較簡單,但是常數比較大,可能會莫名其妙地掛掉。

大概就是先dfs一下求一下深度然後瞎搞。

#include<bits/stdc++.h>
using namespace std;
const int maxn=2e5+10;
int n,m,u,v,cnt=0;
int Next[maxn<<1],V[maxn<<1],Head[maxn],dep[maxn],f[maxn][20];
void read(int &x){
    x=0;char ch=getchar();
    while(ch>'9'||ch<'0') ch=getchar();
    while(ch>='0'&&ch<='9') x=(x<<3)+(x<<1)+ch-'0',ch=getchar();
}
void add(int u,int v){
    ++cnt;
    Next[cnt]=Head[u];
    V[cnt]=v;
    Head[u]=cnt;
}
void dfs(int x,int fa){
    f[x][0]=fa;
    for(int i=1;i<19;++i)
        f[x][i]=f[f[x][i-1]][i-1];
    dep[x]=dep[fa]+1;
    for(int i=Head[x];i;i=Next[i])
        if(V[i]!=fa)
            dfs(V[i],x);
}
int lca(int x,int y){
    if(dep[x]<dep[y]) swap(x,y);
    for(int i=18;i>=0;i--)
        if(dep[f[x][i]]>=dep[y])
            x=f[x][i];
    if(x==y) return x;
    for(int i=18;i>=0;i--)
        if(f[x][i]!=f[y][i])
            x=f[x][i],y=f[y][i];
    return f[x][0];
}
int main(){
    read(n);
    for(int i=1;i<n;++i){
        read(u),read(v);
        add(u,v),add(v,u);
    }
    dfs(1,0),read(m);
    for(int i=1;i<=m;++i){
        read(u),read(v);
        int g=lca(u,v);
        printf("%d",dep[u]+dep[v]-2*dep[g]);
        putchar('\n');
    }
}

③樹鏈剖分先dfs兩次,也就是O(n)預處理,查詢也是logn【可以證明重鏈只有logn條】的,比倍增要快一些。而且空間是O(n)的,比倍增空間複雜度低。求lca的時候就是:只要x和y不在同一條重鏈上,就把更深的點往上跳。

程式碼:

#include<bits/stdc++.h>
using namespace std;
const int maxn=100010;
int Head[maxn],Next[maxn<<1],V[maxn<<1];
int depth[maxn],son[maxn],fa[maxn],siz[maxn],top[maxn];
int n,q,x,y,cnt=0;
void add(int u,int v){
	++cnt;
	Next[cnt]=Head[u];
	V[cnt]=v;
	Head[u]=cnt;
}
void addedge(int u,int v){add(u,v),add(v,u);}
void dfs1(int u,int f){
	siz[u]=1,son[u]=0,fa[u]=f;
	for(int i=Head[u];i;i=Next[i]){
		int v=V[i];
		if(v==f) continue;
		depth[v]=depth[u]+1;
		dfs1(v,u);
		siz[u]+=siz[v];
		if(siz[son[u]]<siz[v])
			son[u]=v;
	}
}
void dfs2(int u,int fa){
	top[u]=(u==son[fa])?top[fa]:u;
	for(int i=Head[u];i;i=Next[i]){
		int v=V[i];
		if(v==fa) continue;
		dfs2(v,u);
	}
}

int lca(int x,int y){
	for(;top[x]!=top[y];depth[top[x]]>depth[top[y]]?x=fa[top[x]]:y=fa[top[y]]);
	return depth[x]<depth[y]?x:y;
}

int main(){
	scanf("%d",&n);
	for(int i=1;i<n;++i)
		scanf("%d%d",&x,&y),addedge(x,y);
	dfs1(1,0),dfs2(1,0);
	scanf("%d",&q);
	while(q--){
		scanf("%d%d",&x,&y);
		printf("%d\n",depth[x]+depth[y]-2*depth[lca(x,y)]);
	}
}