倍增 LCA && ST表
凡凡题解
ST表
\(f[i][j]\) 表示从 \(i\) 结点向上走到的第 \(2^j\) 个结点
\(f[i][0] = father[i]\),\(f[i][j] = f[f[i][j-1]][j-1]\) (先走 \(2^{j-1}\) 步,再走 \(2^{j-1}\) 步)
void dfs(int x, int fa, int d) //通过dfs预处理所有结点的父结点与深度
{
f[x][0] = fa; //x的父结点即为f[x][0]
depth[x] = d; //得到长度
for(unsigned int i = 0; i < G[x].size(); i++)
if(G[x][i] != fa)
dfs(G[x][i], x, d + 1);
}
dfs(root, 0, 0);
for(int j = 1; j < MAX_LOG_V; j++) //先循环步数
{
for(int i = 1; i <= n; i++) //再循环结点
if(f[i][j-1]) //如果f[i][j-1]存在
f[i][j] = f[f[i][j-1]][j-1];
}
注意循环时要先循环步数,再循环结点,保证第二维是 \(j-1\) 的已经全部求出,也就是把每个结点都先跳一步,再都跳一步,而第二步可能基于上一次跳跃后的某一步,所以要把上一次的全部处理出来
LCA:最近公共祖先
首先第一步还是让 \(u,v\) 走到同一深度,由于 st 表的第二维是指数级别的增长,因此相比于线性时间的一步步地上升,我们可以用对数时间迅速地上升至同一高度。
在上升到同一高度之后,原来的方法是两个结点同时一步步地上升直到到达同一个结点,而我们现在可以采取二分的方式。比如两个结点距离 LCA 的距离为 15, 若是同时上升 16,则会超过 LCA,于是不上升;同时上升 8,距离变成 7,再上升4,距离变成3,再上升2,距离变成1,最后一起上升1,到达 LCA。所以这里步数是从大到小进行遍历,这个处理技巧是非常关键的:
int lca(int u, int v)
{
if(depth[v] > depth[u]) swap(u, v); //让u的深度更大
for(int i = 0; i <= MAX_LOG_V; i++) //让u,v到达统一深度
if((depth[u] - depth[v]) >> i & 1)
u = f[u][i];
if(u == v) return u;
for(int k = MAX_LOG_V; k >= 0; k--) //利用二分(st表)来计算LCA
{
if(f[u][k] != f[v][k]) //若是不会超出LCA就上升
{
u = f[u][k];
v = f[v][k];
}
}
return f[u][0];
}
这样通过预处理 \(O(nlogn)\),每次查询 LCA 的时间就可以变成 \(O(logn)\) 了, 所以这里面倍增的思想很关键。
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
const int maxn = 1e4 + 7;
const int maxm = 3033;
const int mod = 1e9 + 7;
const int INF = 0x3f3f3f3f;
const int P = 131;
const double pi = acos(-1.0);
vector<int>v[maxn];
int fa[maxn], depth[maxn], f[maxn][15], n, root;
bool mark[maxn];
void dfs(int x, int fa, int d)
{
f[x][0] = fa;
depth[x] = d;
for(int i = 0; i < v[x].size(); ++i)
if(v[x][i] != fa) dfs(v[x][i], x, d + 1);
}
int lca(int u, int v)
{
if(depth[v] > depth[u]) swap(u, v);
for(int i = 0; i < 15; ++i)
if((depth[u] - depth[v]) >> i & 1)
u = f[u][i];
if(u == v) return u;
for(int k = 14; k >= 0; --k)
{
if(f[u][k] != f[v][k])
{
u = f[u][k];
v = f[v][k];
}
}
return f[u][0];
}
int main()
{
int t, x, y;
cin >> t;
while(t--)
{
for(int i = 1; i <= n; ++i) v[i].clear();
memset(mark, 0, sizeof(mark));
memset(f, 0, sizeof(f));
scanf("%d", &n);
for(int i = 1; i < n; ++i)
{
scanf("%d %d", &x, &y);
v[x].push_back(y);
//v[y].push_back(x); 不需要双向建边
mark[y] = 1;
}
scanf("%d %d", &x, &y);
for(int i = 1; i <= n; ++i)
{
if(!mark[i])
{
root = i;
break;
}
}
dfs(root, 0, 0);
for(int j = 1; j < 15; ++j)
{
for(int i = 1; i <= n; ++i)
if(f[i][j-1])
f[i][j] = f[f[i][j-1]][j-1];
}
printf("%d\n", lca(x, y));
}
}