LCA
倍增法求LCA
讲解见此:https://www.luogu.com.cn/blog/morslin/solution-p3379
P3379
#include<bits/stdc++.h>
using namespace std;
#define f(i, a, b) for(int i = (a); i <= (b); i++)
#define cl(i, n) i.clear(),i.resize(n);
#define endl '\n'
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int, int> pii;
const int inf = 1e9;
int n, m, s;
vector<int> g[500010];
int lg2[500010];
int dep[500010], anc[500010][25];
void dfs(int now, int fa) {
dep[now] = dep[fa] + 1; anc[now][0] = fa;
f(i, 1, lg2[dep[now]]) {
anc[now][i] = anc[anc[now][i - 1]][i - 1];
}
f(i, 0, g[now].size() -1 ) {
if(g[now][i] != fa) dfs(g[now][i], now);
}
}
int lca(int qx, int qy) {
if(dep[qx] < dep[qy]) swap(qx, qy);
while(dep[qx] > dep[qy]) {
qx = anc[qx][lg2[dep[qx]-dep[qy]]];
}
if(qx == qy) return qx;
for(int k = lg2[dep[qx]]; k >= 0; k--) {
if(anc[qx][k] != anc[qy][k]) {
qx = anc[qx][k]; qy = anc[qy][k];
}
}
return anc[qx][0];
}
int main() {
ios::sync_with_stdio(0);
cin.tie(NULL);
cout.tie(NULL);
cin >> n >> m >> s;
f(i, 1, n) lg2[i] = lg2[i - 1] + (1 << lg2[i - 1] == i);
f(i, 1, n) lg2[i]--;
f(i, 1, n - 1) {
int x, y; cin >> x >> y;
g[x].push_back(y); g[y].push_back(x);
}
dfs(s, 0);
f(i, 1, m) {
int qx, qy; cin >> qx >> qy;
cout << lca(qx, qy) << endl;
}
return 0;
}
CF1702G
LCA 好题。
先理解题意:求给定的一个点集是否包含在一条链中。
考虑放在以 \(1\) 为根的树中判断。分析性质:在给定的序列中如果能组成一条链,那么一定不存在一个节点的度为 \(3\) 以上。
考虑放在有根树上。为了方便,下文说的“树”表示询问序列涉及到的树。
我们先求出这棵树的根,方法是先按照 DFS 序排序(这样的目的是同一个子树下的节点总是放在一起),再对每对相邻的两个节点求 LCA,对这些 LCA 取深度最大值就是这个链的根节点。
为什么这个是正确的呢?对于根节点,它的相邻子树间一定存在 DFS 序相邻的两个点,那么一定能找到这个根,不会算漏。
分析不能成链的性质:根有三个以上子树,或者不是根的节点有两个以上子树。我们还是用 DFS 序排序过的节点,对于相邻的两个求 LCA。结果有如下两种:
- LCA 是这两个节点中的一个。说明一个节点是另一个节点的祖先。
- LCA 不是这两个节点中的一个。说明这两个节点分别在 LCA 的两个分叉上。
第二种很重要,说明找到了一个分叉。判断如果这个 LCA 就是根,那么给它两次机会(因为根可以有两个分叉)。否则直接判 NO。
时间复杂度 \(O(q \log q)\)。
#include<bits/stdc++.h>
using namespace std;
#define f(i, a, b) for(int i = (a); i <= (b); i++)
#define cl(i, n) i.clear(),i.resize(n);
#define endl '\n'
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int, int> pii;
const int inf = 1e9;
vector<int> g[200010];
int lg2[200010], dfn[200010];
int dep[200010], anc[200010][25];
int cnt = 0;
void dfs(int now, int fa) {
dfn[now] = ++cnt;
dep[now] = dep[fa] + 1; anc[now][0] = fa;
f(i, 1, lg2[dep[now]]) {
anc[now][i] = anc[anc[now][i - 1]][i - 1];
}
f(i, 0, (int)g[now].size() -1 ) {
if(g[now][i] != fa) dfs(g[now][i], now);
}
}
int lca(int qx, int qy) {
if(dep[qx] < dep[qy]) swap(qx, qy);
while(dep[qx] > dep[qy]) {
qx = anc[qx][lg2[dep[qx]-dep[qy]]];
}
if(qx == qy) return qx;
for(int k = lg2[dep[qx]]; k >= 0; k--) {
if(anc[qx][k] != anc[qy][k]) {
qx = anc[qx][k]; qy = anc[qy][k];
}
}
return anc[qx][0];
}
int ask[200010];
bool cmp(int x, int y) {return dfn[x] < dfn[y];}
int main() {
ios::sync_with_stdio(0);
cin.tie(NULL);
cout.tie(NULL);
int n; cin >> n;
f(i, 1, n) lg2[i] = lg2[i - 1] + (1 << lg2[i - 1] == i);
f(i, 1, n) lg2[i]--;
f(i, 1, n - 1) {
int x, y; cin >> x >> y;
g[x].push_back(y); g[y].push_back(x);
}
dfs(1, 0);
int q; cin >> q;
f(i, 1, q) {
int nn; cin >> nn;
f(j, 1, nn) cin >> ask[j];
if(nn == 1) {cout << "YES\n"; continue;}
sort(ask + 1, ask + nn + 1, cmp);
int root = lca(ask[1], ask[2]);
f(j, 1, nn - 1) {
int k = lca(ask[j], ask[j + 1]);
if(dep[k] < dep[root]) root = k;
}
int mmm = 0;
f(j, 1, nn - 1) {
int k = lca(ask[j], ask[j + 1]);
if(k != ask[j] && k != ask[j + 1]) {
if(k != root) mmm += 2;
else mmm++;
}
}
if(mmm >= 2) {cout << "NO\n"; continue;}
else cout << "YES\n";
}
return 0;
}
\(O(n \log n) -O(1)\) LCA
在欧拉序(记录每次返回的时候的状态)上做 RMQ。
预处理常数是 \(2\),但是查询少了个 \(\log\)。
#include<bits/stdc++.h>
using namespace std;
#define int long long
#define f(i, a, b) for(int i = (a); i <= (b); i++)
#define cl(i, n) i.clear(),i.resize(n);
#define endl '\n'
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int, int> pii;
const int inf = 1e9;
//#define cerr if(false)cerr
//#define freopen if(false)freopen
#define watch(x) cerr << (#x) << ' '<<'i'<<'s'<<' ' << x << endl
void pofe(int number, int bitnum) {
string s; f(i, 0, bitnum) {s += char(number & 1) + '0'; number >>= 1; }
reverse(s.begin(), s.end()); cerr << s << endl;
return;
}
void cmax(int &x, int y) {if(x < y) x = y;}
void cmin(int &x, int y) {if(x > y) x = y;}
//调不出来给我对拍!
vector<int> t[500010];
struct node {int x, d; bool operator< (const node op) const{return d<op.d;} }e[1000010]; //欧拉序,长度为 2 m
struct st_table {
node a[1000010][25];
int n;
void build() {
f(i,1,n)f(j,0,24)a[i][j]={inf, inf};
f(i,1,n) a[i][0]=e[i];
for(int i=n;i>=1;i--){
int k=log2(n-i+1);
f(j,1,k)a[i][j]=min(a[i][j-1],a[i+(1<<(j-1))][j-1]);
}
}
node query(int l,int r){
int k=log2(r-l+1);
return min(a[l][k], a[r-(1<<k)+1][k]);
}
}st;
int ecnt = 0; int fir[500010], lst[500010], dep[500010];
void dfs(int now,int fa) {
dep[now]=dep[fa]+1;
e[++ecnt]={now,dep[now]}; cmax(lst[now],ecnt);cmin(fir[now],ecnt);
for(int i:t[now]){
if(i==fa)continue;
dfs(i,now);
e[++ecnt]={now,dep[now]}; cmax(lst[now],ecnt);cmin(fir[now],ecnt);
}
}
int lca(int l,int r){
if(l==r)return l;
if(fir[l] > fir[r]) swap(l, r);
if(lst[l] > lst[r]) return l;
return st.query(fir[l], lst[r]).x;
}
signed main() {
ios::sync_with_stdio(0);
cin.tie(NULL);
cout.tie(NULL);
//freopen();
//freopen();
//time_t start = clock();
//think twice,code once.
//think once,debug forever.
int n,m,s;cin>>n>>m>>s; f(i,1,n)fir[i]=inf;
f(i,1,n-1){int u,v;cin>>u>>v;t[u].push_back(v);t[v].push_back(u);}
dfs(s, 0); st.n = ecnt;
st.build();f(i,1,m){int u,v;cin>>u>>v; cout<<lca(u,v)<<endl;}
//time_t finish = clock();
//cout << "time used:" << (finish-start) * 1.0 / CLOCKS_PER_SEC <<"s"<< endl;
return 0;
}
/*
2023/x/xx
start thinking at h:mm
start coding at h:mm
finish debugging at h:mm
*/
DFS 序求 lca
考虑 \(x, y\) 互相不是祖先关系的时候,他们的 lca 不在他们的 dfn 之间。但是 \(y\) 的 lsa 是在的。所以考虑 \(rnk_x, rnk_y\) 区间范围内的 \(dep\) 最小点的父亲。
考虑 \(x\) 是 \(y\) 祖先的时候。这时候 \(rnk_x + 1, rnk_y\) 这个区间仍然满足条件。
所以我们可以对 dfn 序下手,考虑 \(rnk_x+1, rnk_y\) 这个区间内求 rmq,找 \(dep\) 最小的那个数。然后唯一一个需要特判的地方是 \(x=y\)。
#include<bits/stdc++.h>
using namespace std;
#define int long long
//use ll instead of int.
#define f(i, a, b) for(int i = (a); i <= (b); i++)
#define cl(i, n) i.clear(),i.resize(n);
#define endl '\n'
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int, int> pii;
typedef pair<ll, ll> pll;
const int inf = 1e9;
//#define cerr if(false)cerr
//#define freopen if(false)freopen
#define watch(x) cerr << (#x) << ' '<<'i'<<'s'<<' ' << x << endl
void pofe(int number, int bitnum) {
string s; f(i, 0, bitnum) {s += char(number & 1) + '0'; number >>= 1; }
reverse(s.begin(), s.end()); cerr << s << endl;
return;
}
template <typename TYP> void cmax(TYP &x, TYP y) {if(x < y) x = y;}
template <typename TYP> void cmin(TYP &x, TYP y) {if(x > y) x = y;}
//调不出来给我对拍!
//use std::array.
vector<int> g[500200];
int dep[500200], dfn[500200], rnk[500200], st[500200][25];
int dcnt=0;int lg2[500200]; int fa[500200];
void dfs(int x,int k) {
dep[x]=dep[k]+1; fa[x]=k; rnk[x]=++dcnt; dfn[dcnt] = x;
for(int i:g[x])if(i!=k)dfs(i,x);
}
int query(int l,int r){
int k=lg2[r-l+1];
if(dep[st[l][k]] <= dep[st[r-(1<<k)+1][k]])return st[l][k];
else return st[r-(1<<k)+1][k];
}
signed main() {
ios::sync_with_stdio(0);
cin.tie(NULL);
cout.tie(NULL);
//freopen();
//freopen();
//time_t start = clock();
//think twice,code once.
//think once,debug forever.
int n,m,s;cin>>n>>m>>s;
f(i,2,n)lg2[i]=lg2[i>>1]+1;
// f(i,1,n)cout<<lg2[i]<<" ";
// cout<<endl;
f(i,1,n-1){
int x,y;cin>>x>>y;
g[x].push_back(y);g[y].push_back(x);
}
dfs(s,0);
// f(i,1,n)cout<<rnk[i]<<" ";
// cout<<endl;
for(int i=n;i>=1;i--){
st[i][0]=dfn[i];
// cout << st[i][0] << " ";
for(int k=1;k<=lg2[n-i+1];k++) {
if(dep[st[i][k-1]]<=dep[st[i+(1<<(k-1))][k-1]])st[i][k]=st[i][k-1];
else st[i][k]=st[i+(1<<(k-1))][k-1];
// cout << st[i][k] << " ";
}
// cout<<endl;
}
f(i,1,m){
int a,b;cin>>a>>b;if(a==b)cout<<a<<endl;
else {
if(rnk[a] > rnk[b]) swap(a,b);
cout<<fa[query(rnk[a]+1, rnk[b])]<<endl;
}
}
//time_t finish = clock();
//cout << "time used:" << (finish-start) * 1.0 / CLOCKS_PER_SEC <<"s"<< endl;
return 0;
}
/*
2023/x/xx
start thinking at h:mm
start coding at h:mm
finish debugging at h:mm
*/
上次 u 群听说 log2
很慢,改一下,改成预处理,然后记住 \(\log_2 1 = 0\),所以是这么写的:
f(i,2,n)lg2[i]=lg2[i>>1]+1;