倍增算法
倍增算法的主要思想是通过引入一个二维数组f[v][i],表示节点v的第个祖先。
该段为核心代码
f[u][0] = fa;
dep[u] = dep[fa] + 1;
for (int k = 1; (1 << k) <= dep[u] - 1; k++) {
f[u][k] = f[f[u][k - 1]][k - 1];
}
其中根节点的深度dep[root]=1。
f[u][0]是顶点u的父亲节点fa,对于顶点u来说,其第个祖先的深度dep为dep[u]-(1<<i),因此终止条件为(1 << i) <= dep[u] - 1,但其他人的代码里有的是用dep[u]代替dep[u] - 1,一样可以过。使用
dep[v]
或许更符合深度的定义,并且通常不会出现需要减去 1 的情况。深度 dep[u]
本身是从根节点到 v
的路径长度,而 步的祖先实际上并不一定要严格从根节点计算。因此,直接用
dep[u]
作为循环条件可以避免很多冗余的减法,我是这么理解的。
LCA算法
int lca(int x, int y) {
if (dep[x] > dep[y]) swap(x, y);
for (int i = 21; i >= 0; i--) {
if (dep[x] <= dep[y] - (1 << i)) {
y = f[y][i];
}
}
if (x == y) return x;
for (int i = 21; i >= 0; i--) {
if (f[x][i] == f[y][i]) continue;
else {
x = f[x][i]; y = f[y][i];
}
}
return f[y][0];
}
让y作为深层顶点,依次由大到小查找x的祖先节点,其过程就像将一个十进制数由大到小转化为二进制数一样 。
第一轮循环后,x与y在同一深度,之后进行特判x == y ;
第二轮循环是查找x与y的公共祖先,从大到小,假设x,y的公共祖先差的深度为H,则该过程就像将H用二进制表示一样,如果f[x][i] == f[y][i],则该点置为0,否则置为1(就是跳到处的祖先)。
另一个问题是邻接表的建立
void add(int u, int v) {
e[cnt].to = v;
e[cnt].next = head[u];
head[u] = cnt++;
}
head[u]表示顶点u的首条边,初始化为-1,根据e[head[u]],to表示u的邻居顶点,next表示下一条边,形成一个邻接链表。或者用一个vector构建邻接矩阵也行。
以下就是本题的详细代码
#include <iostream>
#include <cstring>
#include <algorithm>
using namespace std;
const int maxsize = 5e5 + 10;
struct edge {
int to, next;
}e[maxsize * 2];
int head[maxsize];
int f[maxsize][22];
int dep[maxsize];
int cnt = 0;
int n, m, s;
void add(int u, int v) {
e[cnt].to = v;
e[cnt].next = head[u];
head[u] = cnt++;
}
void dfs(int u, int fa) {
f[u][0] = fa;
dep[u] = dep[fa] + 1;
for (int k = 1; (1 << k) <= dep[u] - 1; k++) {
f[u][k] = f[f[u][k - 1]][k - 1];
}
for (int i = head[u]; i != -1; i = e[i].next) {
int w = e[i].to;
if (w != fa) dfs(w, u);
}
}
int lca(int x, int y) {
if (dep[x] > dep[y]) swap(x, y);
for (int i = 21; i >= 0; i--) {
if (dep[x] <= dep[y] - (1 << i)) {
y = f[y][i];
}
}
if (x == y) return x;
for (int i = 21; i >= 0; i--) {
if (f[x][i] == f[y][i]) continue;
else {
x = f[x][i]; y = f[y][i];
}
}
return f[y][0];
}
int main()
{
scanf("%d %d %d", &n, &m, &s);
memset(head, -1, sizeof(head));
memset(f, 0, sizeof(f));
int u, v;
for (int i = 1; i < n; i++) {
scanf("%d %d", &u, &v);
add(u, v); add(v, u);
}
dfs(s, 0);
while (m--) {
scanf("%d %d", &u, &v);
printf("%d\n", lca(u, v));
}
return 0;
}