最近公共祖先 | st 表求 lca 别用欧拉序了!!!

2023.7.17 update

以前写的有点乱,现在简单重写一下核心思想。

第一个思想,我们只要求出 dfn 相邻节点 lca,然后两个点 lca 肯定在他们 dfn 这个区间里相邻节点 lca 中。

然后有个简化代码的发现,dfn 相邻两个节点的 lca,其实是后者的父亲。

然后可以看代码,注意要特判 lca 两个节点相等。

原文

钱菜鸡水平不行,只能写写最近公共祖先了。

目前 OI 所流行的 \(O(nlogn) - O(1)\) 的 LCA 算法是 欧拉序 + RMQ,显然,欧拉序没有这么好写,而且常数不小(序列长度两倍),所以导致了很多情况下更多人选择了倍增等算法。

欧拉序+RMQ 算法中,我们要实现一个 \(2n\) 长度的序列的 RMQ,但是我们发现,我们询问的端点数量却最多只有 \(n\) 个,这意味这我们可以将某些连续的段给缩起来。

缩起来后就成了一个序列 \(A\),我们记 \(idfn_i\) 为满足 \(dfn_{idfn_i} = i\) 的一个数,容易发现 \(A_i = lca(idfn_{i}, idfn_{i + 1})\),显然,我们求 \(x, y (dfn_x <= dfn_y)\) 的 lca 只需要求出 \(A_i (dfn_x \leq i \lt dfn_y)\)\(dfn\) 最小的节点 (注意,需要特判 \(x = y\))。

下面我们的问题是如何求 \(A_i\)。可以证明的是 \(A_i = lca(idfn_{i}, idfn_{i + 1}) = fa_{idfn_{i + 1}}\),因此可以在一次 \(dfs\) 内简单的求出 \(A\) 数组。于是我们只用做一次长度为 \(n\) 的 RMQ 预处理,在 \(O(n) - O(1)\) lca 中也有不错的优化效果。

\(O(nlogn) - O(1)\)

#include<bits/stdc++.h>
const int maxn = 500500;
int n, m, s;
struct T{ int to, nxt; } way[maxn << 1];
int h[maxn], num;
int st[20][maxn], dfn[maxn], tot;
inline int min(int x,int y){ return dfn[x] < dfn[y] ? x : y; }
inline void link(int x,int y) {
	way[++num] = {y, h[x]}, h[x] = num;
	way[++num] = {x, h[y]}, h[y] = num;
}
inline void dfs(int x,int fa = 0) {
	st[0][tot] = fa, dfn[x] = ++tot;
	for(int i = h[x];i;i = way[i].nxt) if(way[i].to != fa)
		dfs(way[i].to, x);
}
inline int lca(int x,int y) {
	if(dfn[x] > dfn[y]) std::swap(x, y);
	const int lg = std::__lg(dfn[y] - dfn[x]);
	return x != y ? min(st[lg][dfn[x]], st[lg][dfn[y] - (1 << lg)]) : x;
}
int main() {
	std::ios::sync_with_stdio(false), std::cin.tie(0);
	std::cin >> n >> m >> s;
	for(int i = 1,x,y;i < n;++i)
		std::cin >> x >> y, link(x,y);
	dfs(s);
	for(int i = 1;i < 20;++i) for(int j = 1;j + (1 << i) - 1 < n;++j)
		st[i][j] = min(st[i - 1][j], st[i - 1][j + (1 << i - 1)]);
	for(int i = 1,x,y;i <= m;++i) {
		std::cin >> x >> y;
		std::cout << lca(x,y) << '\n';
	}
}

\(O(n) - O(1)\):

#include<bits/stdc++.h>
const int maxn = 1000001;
typedef unsigned u32;
struct istream {
    static const int size = 1 << 25;
    static const u32 b = 0x30303030;
    short map[1 << 16];
    char buf[size], *vin;
    inline istream() {
	for(int i = 0;i < 1 << 16;++i) map[i] = (i >> 12) + (i >> 8 & 15) * 100 + (i >> 4 & 15) * 10 + (i & 15) * 1000;
        fread(buf,1,size,stdin);
        vin = buf - 1;
    }
    inline istream& operator >> (int & x) {
    	x = *++vin & 15, ++ vin;
    	u32*& idx = (u32*&) vin;
	for(;(*idx & b) == b;++idx) x = x * 10000 + map[(*idx ^ *idx >> 12 ^ 13107) & 65535];
	for(;isdigit(*vin);++vin) x = x * 10 + (*vin & 15);
        return * this;
    }
} cin;
struct ostream
{
	static const int size = 1 << 23;
	char buf[size], *vout;
	unsigned map[10000];
	inline ostream()
	{
		for(int i = 0;i < 10000;++i) {
			int p = i;
			map[i] = p % 10 + 48, p /= 10;
			map[i] = map[i] << 8 | p % 10 + 48, p /= 10;
			map[i] = map[i] << 8 | p % 10 + 48, p /= 10;
			map[i] = map[i] << 8 | p % 10 + 48, p /= 10;
		}
		vout = buf + size;
	}
	inline ~ ostream()
	{ fwrite(vout,1,buf + size - vout,stdout); }
	inline ostream& operator << (int x)
	{
		for(;x >= 10000;x /= 10000) *--(unsigned*&)vout = map[x % 10000];
		do *--vout = x % 10 + 48; while(x /= 10);
		return * this;
	}
	inline ostream& operator << (char x)
	{
		*--vout = x;
		return * this;
	}
} cout;
int n,q;
int a[maxn],dfn[maxn],tot;
namespace Rmq{
	int st[15][maxn / 32];
	int pre[maxn],p[maxn],w[maxn];
	inline int min(int x,int y){ return dfn[x] < dfn[y] ? x : y; }
	inline void down(int & x,int y){ if(dfn[x] > dfn[y]) x = y; }
	inline int qry(int l,int r){
		const int lg = std::__lg(r - l);
		return l >= r ? 0 : min(st[lg][l],st[lg][r - (1 << lg)]);
	}
	inline int rmq(int l,int r){
		if(l >> 5 == r >> 5) return p[l + __builtin_ctz(w[r] >> l)];
		else return min(qry((l >> 5) + 1,r >> 5),min(a[l],pre[r]));
	}
	inline void build(int n){
		++ (n |= 31);
		memcpy(p,a,n <<2);
		for(int i=0;i<n;i+=32){
			static int st[33];
			pre[i] = a[i];
			int * top = st + 1,s = 1; w[*top = i] = s;
			for(int j=i+1;j<i+32;++j){
				for(;top != st && dfn[a[j]] < dfn[a[*top]];--top) s ^= 1 << *top;
				w[j] = s |= 1 << j; *++top = j; pre[j] = a[st[1]];
			}
			for(int j=i + 30;j >= i;--j) down(a[j],a[j+1]);
			Rmq::st[0][i >> 5] = a[i];
		}
		for(int i = 1;i < 15;++i)
			for(int j = 0;j + (1 << i) - 1 <= n / 32;++j)
				st[i][j] = min(st[i - 1][j],st[i - 1][j + (1 << i - 1)]);
	}
}
struct T{ int to,nxt; } way[maxn << 1];
int h[maxn],num;
inline void adde(int x,int y){
	way[++num] = {y,h[x]}, h[x]=num;
	way[++num] = {x,h[y]}, h[y]=num;
}
inline void dfs(int x,int f){
	a[tot] = f; dfn[x] = ++tot;
	for(int i=h[x];i;i=way[i].nxt) if(way[i].to != f)
		dfs(way[i].to,x);
}
inline int lca(int x,int y){
	if(dfn[x] > dfn[y]) std::swap(x,y);
	return x == y ? x : Rmq::rmq(dfn[x],dfn[y]-1);
}
int ans[maxn];

int main(){
	cin >> n >> q;
	for(int i = 1,x,y;i < n;++i) cin >> x >> y, adde(x,y);
	dfs(1,0); *dfn = 1e9; Rmq::build(n-1);
	for(int i = 1,x,y;i <= q;++i) cin >> x >> y, ans[i] = lca(x,y);
	for(int i = q;i >= 1;--i) cout << '\n' << ans[i];
}

posted @ 2020-01-29 13:55  skip2004  阅读(5605)  评论(4编辑  收藏  举报