最近公共祖先 | st 表求 lca 别用欧拉序了!!!
2023.7.17 update
以前写的有点乱,现在简单重写一下核心思想。
第一个思想,我们只要求出 dfn 相邻节点 lca,然后两个点 lca 肯定在他们 dfn 这个区间里相邻节点 lca 中。
然后有个简化代码的发现,dfn 相邻两个节点的 lca,其实是后者的父亲。
然后可以看代码,注意要特判 lca 两个节点相等。
原文
钱菜鸡水平不行,只能写写最近公共祖先了。
目前 OI 所流行的 \(O(nlogn) - O(1)\) 的 LCA 算法是 欧拉序 + RMQ,显然,欧拉序没有这么好写,而且常数不小(序列长度两倍),所以导致了很多情况下更多人选择了倍增等算法。
欧拉序+RMQ 算法中,我们要实现一个 \(2n\) 长度的序列的 RMQ,但是我们发现,我们询问的端点数量却最多只有 \(n\) 个,这意味这我们可以将某些连续的段给缩起来。
缩起来后就成了一个序列 \(A\),我们记 \(idfn_i\) 为满足 \(dfn_{idfn_i} = i\) 的一个数,容易发现 \(A_i = lca(idfn_{i}, idfn_{i + 1})\),显然,我们求 \(x, y (dfn_x <= dfn_y)\) 的 lca 只需要求出 \(A_i (dfn_x \leq i \lt dfn_y)\) 中 \(dfn\) 最小的节点 (注意,需要特判 \(x = y\))。
下面我们的问题是如何求 \(A_i\)。可以证明的是 \(A_i = lca(idfn_{i}, idfn_{i + 1}) = fa_{idfn_{i + 1}}\),因此可以在一次 \(dfs\) 内简单的求出 \(A\) 数组。于是我们只用做一次长度为 \(n\) 的 RMQ 预处理,在 \(O(n) - O(1)\) lca 中也有不错的优化效果。
\(O(nlogn) - O(1)\)
#include<bits/stdc++.h>
const int maxn = 500500;
int n, m, s;
struct T{ int to, nxt; } way[maxn << 1];
int h[maxn], num;
int st[20][maxn], dfn[maxn], tot;
inline int min(int x,int y){ return dfn[x] < dfn[y] ? x : y; }
inline void link(int x,int y) {
way[++num] = {y, h[x]}, h[x] = num;
way[++num] = {x, h[y]}, h[y] = num;
}
inline void dfs(int x,int fa = 0) {
st[0][tot] = fa, dfn[x] = ++tot;
for(int i = h[x];i;i = way[i].nxt) if(way[i].to != fa)
dfs(way[i].to, x);
}
inline int lca(int x,int y) {
if(dfn[x] > dfn[y]) std::swap(x, y);
const int lg = std::__lg(dfn[y] - dfn[x]);
return x != y ? min(st[lg][dfn[x]], st[lg][dfn[y] - (1 << lg)]) : x;
}
int main() {
std::ios::sync_with_stdio(false), std::cin.tie(0);
std::cin >> n >> m >> s;
for(int i = 1,x,y;i < n;++i)
std::cin >> x >> y, link(x,y);
dfs(s);
for(int i = 1;i < 20;++i) for(int j = 1;j + (1 << i) - 1 < n;++j)
st[i][j] = min(st[i - 1][j], st[i - 1][j + (1 << i - 1)]);
for(int i = 1,x,y;i <= m;++i) {
std::cin >> x >> y;
std::cout << lca(x,y) << '\n';
}
}
\(O(n) - O(1)\):
#include<bits/stdc++.h>
const int maxn = 1000001;
typedef unsigned u32;
struct istream {
static const int size = 1 << 25;
static const u32 b = 0x30303030;
short map[1 << 16];
char buf[size], *vin;
inline istream() {
for(int i = 0;i < 1 << 16;++i) map[i] = (i >> 12) + (i >> 8 & 15) * 100 + (i >> 4 & 15) * 10 + (i & 15) * 1000;
fread(buf,1,size,stdin);
vin = buf - 1;
}
inline istream& operator >> (int & x) {
x = *++vin & 15, ++ vin;
u32*& idx = (u32*&) vin;
for(;(*idx & b) == b;++idx) x = x * 10000 + map[(*idx ^ *idx >> 12 ^ 13107) & 65535];
for(;isdigit(*vin);++vin) x = x * 10 + (*vin & 15);
return * this;
}
} cin;
struct ostream
{
static const int size = 1 << 23;
char buf[size], *vout;
unsigned map[10000];
inline ostream()
{
for(int i = 0;i < 10000;++i) {
int p = i;
map[i] = p % 10 + 48, p /= 10;
map[i] = map[i] << 8 | p % 10 + 48, p /= 10;
map[i] = map[i] << 8 | p % 10 + 48, p /= 10;
map[i] = map[i] << 8 | p % 10 + 48, p /= 10;
}
vout = buf + size;
}
inline ~ ostream()
{ fwrite(vout,1,buf + size - vout,stdout); }
inline ostream& operator << (int x)
{
for(;x >= 10000;x /= 10000) *--(unsigned*&)vout = map[x % 10000];
do *--vout = x % 10 + 48; while(x /= 10);
return * this;
}
inline ostream& operator << (char x)
{
*--vout = x;
return * this;
}
} cout;
int n,q;
int a[maxn],dfn[maxn],tot;
namespace Rmq{
int st[15][maxn / 32];
int pre[maxn],p[maxn],w[maxn];
inline int min(int x,int y){ return dfn[x] < dfn[y] ? x : y; }
inline void down(int & x,int y){ if(dfn[x] > dfn[y]) x = y; }
inline int qry(int l,int r){
const int lg = std::__lg(r - l);
return l >= r ? 0 : min(st[lg][l],st[lg][r - (1 << lg)]);
}
inline int rmq(int l,int r){
if(l >> 5 == r >> 5) return p[l + __builtin_ctz(w[r] >> l)];
else return min(qry((l >> 5) + 1,r >> 5),min(a[l],pre[r]));
}
inline void build(int n){
++ (n |= 31);
memcpy(p,a,n <<2);
for(int i=0;i<n;i+=32){
static int st[33];
pre[i] = a[i];
int * top = st + 1,s = 1; w[*top = i] = s;
for(int j=i+1;j<i+32;++j){
for(;top != st && dfn[a[j]] < dfn[a[*top]];--top) s ^= 1 << *top;
w[j] = s |= 1 << j; *++top = j; pre[j] = a[st[1]];
}
for(int j=i + 30;j >= i;--j) down(a[j],a[j+1]);
Rmq::st[0][i >> 5] = a[i];
}
for(int i = 1;i < 15;++i)
for(int j = 0;j + (1 << i) - 1 <= n / 32;++j)
st[i][j] = min(st[i - 1][j],st[i - 1][j + (1 << i - 1)]);
}
}
struct T{ int to,nxt; } way[maxn << 1];
int h[maxn],num;
inline void adde(int x,int y){
way[++num] = {y,h[x]}, h[x]=num;
way[++num] = {x,h[y]}, h[y]=num;
}
inline void dfs(int x,int f){
a[tot] = f; dfn[x] = ++tot;
for(int i=h[x];i;i=way[i].nxt) if(way[i].to != f)
dfs(way[i].to,x);
}
inline int lca(int x,int y){
if(dfn[x] > dfn[y]) std::swap(x,y);
return x == y ? x : Rmq::rmq(dfn[x],dfn[y]-1);
}
int ans[maxn];
int main(){
cin >> n >> q;
for(int i = 1,x,y;i < n;++i) cin >> x >> y, adde(x,y);
dfs(1,0); *dfn = 1e9; Rmq::build(n-1);
for(int i = 1,x,y;i <= q;++i) cin >> x >> y, ans[i] = lca(x,y);
for(int i = q;i >= 1;--i) cout << '\n' << ans[i];
}