LCA

倍增法求LCA

讲解见此:https://www.luogu.com.cn/blog/morslin/solution-p3379

P3379

#include<bits/stdc++.h>
using namespace std;
#define f(i, a, b) for(int i = (a); i <= (b); i++)
#define cl(i, n) i.clear(),i.resize(n);
#define endl '\n'
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int, int> pii;
const int inf = 1e9;
int n, m, s;
vector<int> g[500010];
int lg2[500010];
int dep[500010], anc[500010][25];
void dfs(int now, int fa) {
    dep[now] = dep[fa] + 1; anc[now][0] = fa;
    f(i, 1, lg2[dep[now]]) {
        anc[now][i] = anc[anc[now][i - 1]][i - 1];
    }
    f(i, 0, g[now].size() -1 ) {
        if(g[now][i] != fa) dfs(g[now][i], now);
    }
}
int lca(int qx, int qy) {
    if(dep[qx] < dep[qy]) swap(qx, qy);
    while(dep[qx] > dep[qy]) {
        qx = anc[qx][lg2[dep[qx]-dep[qy]]];
    }
    if(qx == qy) return qx;
    for(int k = lg2[dep[qx]]; k >= 0; k--) {
        if(anc[qx][k] != anc[qy][k]) {
            qx = anc[qx][k]; qy = anc[qy][k];
        }
    }
    return anc[qx][0];
}
int main() {
    ios::sync_with_stdio(0);
    cin.tie(NULL);
    cout.tie(NULL);
    cin >> n >> m >> s;
    f(i, 1, n) lg2[i] = lg2[i - 1] + (1 << lg2[i - 1] == i);
    f(i, 1, n) lg2[i]--;
    f(i, 1, n - 1) {
        int x, y; cin >> x >> y;
        g[x].push_back(y); g[y].push_back(x);
    }
    dfs(s, 0);
    f(i, 1, m) {
        int qx, qy; cin >> qx >> qy;
        cout << lca(qx, qy) << endl;
    }
    return 0;
}

CF1702G

LCA 好题。
先理解题意:求给定的一个点集是否包含在一条链中。

考虑放在以 \(1\) 为根的树中判断。分析性质:在给定的序列中如果能组成一条链,那么一定不存在一个节点的度为 \(3\) 以上。

考虑放在有根树上。为了方便,下文说的“树”表示询问序列涉及到的树。

我们先求出这棵树的根,方法是先按照 DFS 序排序(这样的目的是同一个子树下的节点总是放在一起),再对每对相邻的两个节点求 LCA,对这些 LCA 取深度最大值就是这个链的根节点。
为什么这个是正确的呢?对于根节点,它的相邻子树间一定存在 DFS 序相邻的两个点,那么一定能找到这个根,不会算漏。

分析不能成链的性质:根有三个以上子树,或者不是根的节点有两个以上子树。我们还是用 DFS 序排序过的节点,对于相邻的两个求 LCA。结果有如下两种:

  1. LCA 是这两个节点中的一个。说明一个节点是另一个节点的祖先。
  2. LCA 不是这两个节点中的一个。说明这两个节点分别在 LCA 的两个分叉上。

第二种很重要,说明找到了一个分叉。判断如果这个 LCA 就是根,那么给它两次机会(因为根可以有两个分叉)。否则直接判 NO。

时间复杂度 \(O(q \log q)\)

#include<bits/stdc++.h>
using namespace std;
#define f(i, a, b) for(int i = (a); i <= (b); i++)
#define cl(i, n) i.clear(),i.resize(n);
#define endl '\n'
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int, int> pii;
const int inf = 1e9;
vector<int> g[200010];
int lg2[200010], dfn[200010];
int dep[200010], anc[200010][25];
int cnt = 0;
void dfs(int now, int fa) {
    dfn[now] = ++cnt;
    dep[now] = dep[fa] + 1; anc[now][0] = fa;
    f(i, 1, lg2[dep[now]]) {
        anc[now][i] = anc[anc[now][i - 1]][i - 1];
    }
    f(i, 0, (int)g[now].size() -1 ) {
        if(g[now][i] != fa) dfs(g[now][i], now);
    }
}
int lca(int qx, int qy) {
    if(dep[qx] < dep[qy]) swap(qx, qy);
    while(dep[qx] > dep[qy]) {
        qx = anc[qx][lg2[dep[qx]-dep[qy]]];
    }
    if(qx == qy) return qx;
    for(int k = lg2[dep[qx]]; k >= 0; k--) {
        if(anc[qx][k] != anc[qy][k]) {
            qx = anc[qx][k]; qy = anc[qy][k];
        }
    }
    return anc[qx][0];
}
int ask[200010];
bool cmp(int x, int y) {return dfn[x] < dfn[y];}
int main() {
    ios::sync_with_stdio(0);
    cin.tie(NULL);
    cout.tie(NULL);
    int n; cin >> n;
    f(i, 1, n) lg2[i] = lg2[i - 1] + (1 << lg2[i - 1] == i);
    f(i, 1, n) lg2[i]--;
    f(i, 1, n - 1) {
        int x, y; cin >> x >> y;
        g[x].push_back(y); g[y].push_back(x);
    }
    dfs(1, 0);
    int q; cin >> q;
    f(i, 1, q) {
        int nn; cin >> nn;
        f(j, 1, nn) cin >> ask[j];
        if(nn == 1) {cout << "YES\n"; continue;}
        sort(ask + 1, ask + nn + 1, cmp);
        int root = lca(ask[1], ask[2]);
        f(j, 1, nn - 1) {
            int k = lca(ask[j], ask[j + 1]);
            if(dep[k] < dep[root]) root = k;
        }
        int mmm = 0;
        f(j, 1, nn - 1) {
            int k = lca(ask[j], ask[j + 1]);
            if(k != ask[j] && k != ask[j + 1]) {
                if(k != root) mmm += 2;
                else mmm++;
            }
        }
        if(mmm >= 2) {cout << "NO\n"; continue;}
        else cout << "YES\n";
    }
    return 0;
}

\(O(n \log n) -O(1)\) LCA

在欧拉序(记录每次返回的时候的状态)上做 RMQ。

预处理常数是 \(2\),但是查询少了个 \(\log\)

#include<bits/stdc++.h>
using namespace std;
#define int long long
#define f(i, a, b) for(int i = (a); i <= (b); i++)
#define cl(i, n) i.clear(),i.resize(n);
#define endl '\n'
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int, int> pii;
const int inf = 1e9;
//#define cerr if(false)cerr
//#define freopen if(false)freopen
#define watch(x) cerr  << (#x) << ' '<<'i'<<'s'<<' ' << x << endl
void pofe(int number, int bitnum) {
    string s; f(i, 0, bitnum) {s += char(number & 1) + '0'; number >>= 1; } 
    reverse(s.begin(), s.end()); cerr << s << endl; 
    return;
}
void cmax(int &x, int y) {if(x < y) x = y;}
void cmin(int &x, int y) {if(x > y) x = y;}
//调不出来给我对拍!
vector<int> t[500010];
struct node {int x, d; bool operator< (const node op) const{return d<op.d;} }e[1000010]; //欧拉序,长度为 2 m
struct st_table {
    node a[1000010][25];
    int n;
    void build() { 
        f(i,1,n)f(j,0,24)a[i][j]={inf, inf};
        f(i,1,n) a[i][0]=e[i];
        for(int i=n;i>=1;i--){
            int k=log2(n-i+1); 
            f(j,1,k)a[i][j]=min(a[i][j-1],a[i+(1<<(j-1))][j-1]);
        }
    }
    node query(int l,int r){
        int k=log2(r-l+1);
        return min(a[l][k], a[r-(1<<k)+1][k]);
    }
}st;
int ecnt = 0; int fir[500010], lst[500010], dep[500010];
void dfs(int now,int fa) {
    dep[now]=dep[fa]+1;
    e[++ecnt]={now,dep[now]}; cmax(lst[now],ecnt);cmin(fir[now],ecnt);
    for(int i:t[now]){
        if(i==fa)continue;
        dfs(i,now);
        e[++ecnt]={now,dep[now]}; cmax(lst[now],ecnt);cmin(fir[now],ecnt);
    }
}
int lca(int l,int r){
    if(l==r)return l;
    if(fir[l] > fir[r]) swap(l, r);
    if(lst[l] > lst[r]) return l;
    return st.query(fir[l], lst[r]).x;
}
signed main() {
    ios::sync_with_stdio(0);
    cin.tie(NULL);
    cout.tie(NULL);
    //freopen();
    //freopen();
    //time_t start = clock();
    //think twice,code once.
    //think once,debug forever.
    int n,m,s;cin>>n>>m>>s; f(i,1,n)fir[i]=inf; 
    f(i,1,n-1){int u,v;cin>>u>>v;t[u].push_back(v);t[v].push_back(u);}
    dfs(s, 0); st.n = ecnt;
    st.build();f(i,1,m){int u,v;cin>>u>>v; cout<<lca(u,v)<<endl;}
    //time_t finish = clock();
    //cout << "time used:" << (finish-start) * 1.0 / CLOCKS_PER_SEC <<"s"<< endl;
    return 0;
}
/*
2023/x/xx
start thinking at h:mm


start coding at h:mm
finish debugging at h:mm
*/

DFS 序求 lca

考虑 \(x, y\) 互相不是祖先关系的时候,他们的 lca 不在他们的 dfn 之间。但是 \(y\) 的 lsa 是在的。所以考虑 \(rnk_x, rnk_y\) 区间范围内的 \(dep\) 最小点的父亲。

考虑 \(x\)\(y\) 祖先的时候。这时候 \(rnk_x + 1, rnk_y\) 这个区间仍然满足条件。

所以我们可以对 dfn 序下手,考虑 \(rnk_x+1, rnk_y\) 这个区间内求 rmq,找 \(dep\) 最小的那个数。然后唯一一个需要特判的地方是 \(x=y\)

#include<bits/stdc++.h>
using namespace std;
#define int long long
//use ll instead of int.
#define f(i, a, b) for(int i = (a); i <= (b); i++)
#define cl(i, n) i.clear(),i.resize(n);
#define endl '\n'
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int, int> pii;
typedef pair<ll, ll> pll;
const int inf = 1e9;
//#define cerr if(false)cerr
//#define freopen if(false)freopen
#define watch(x) cerr  << (#x) << ' '<<'i'<<'s'<<' ' << x << endl
void pofe(int number, int bitnum) {
    string s; f(i, 0, bitnum) {s += char(number & 1) + '0'; number >>= 1; } 
    reverse(s.begin(), s.end()); cerr << s << endl; 
    return;
}
template <typename TYP> void cmax(TYP &x, TYP y) {if(x < y) x = y;}
template <typename TYP> void cmin(TYP &x, TYP y) {if(x > y) x = y;}
//调不出来给我对拍!
//use std::array.
vector<int> g[500200];
int dep[500200], dfn[500200], rnk[500200], st[500200][25];
int dcnt=0;int lg2[500200]; int fa[500200];
void dfs(int x,int k) {
    dep[x]=dep[k]+1; fa[x]=k; rnk[x]=++dcnt; dfn[dcnt] = x;
    for(int i:g[x])if(i!=k)dfs(i,x);
}
int query(int l,int r){
    int k=lg2[r-l+1];
    if(dep[st[l][k]] <= dep[st[r-(1<<k)+1][k]])return st[l][k];
    else return st[r-(1<<k)+1][k];
}
signed main() {
    ios::sync_with_stdio(0);
    cin.tie(NULL);
    cout.tie(NULL);
    //freopen();
    //freopen();
    //time_t start = clock();
    //think twice,code once.
    //think once,debug forever.
    int n,m,s;cin>>n>>m>>s;
    f(i,2,n)lg2[i]=lg2[i>>1]+1;
    // f(i,1,n)cout<<lg2[i]<<" ";
    // cout<<endl;
    f(i,1,n-1){
        int x,y;cin>>x>>y;
        g[x].push_back(y);g[y].push_back(x);
    }
    dfs(s,0);
    // f(i,1,n)cout<<rnk[i]<<" ";
    // cout<<endl;
    for(int i=n;i>=1;i--){
        st[i][0]=dfn[i];
        // cout << st[i][0] << " ";
        for(int k=1;k<=lg2[n-i+1];k++) {
            if(dep[st[i][k-1]]<=dep[st[i+(1<<(k-1))][k-1]])st[i][k]=st[i][k-1];
            else st[i][k]=st[i+(1<<(k-1))][k-1];
            // cout << st[i][k] << " ";
        }
        // cout<<endl;
    }
    f(i,1,m){
        int a,b;cin>>a>>b;if(a==b)cout<<a<<endl;
        else {
            if(rnk[a] > rnk[b]) swap(a,b);
            cout<<fa[query(rnk[a]+1, rnk[b])]<<endl;
        }
    }
    //time_t finish = clock();
    //cout << "time used:" << (finish-start) * 1.0 / CLOCKS_PER_SEC <<"s"<< endl;
    return 0;
}
/*
2023/x/xx
start thinking at h:mm


start coding at h:mm
finish debugging at h:mm
*/

上次 u 群听说 log2 很慢,改一下,改成预处理,然后记住 \(\log_2 1 = 0\),所以是这么写的:
f(i,2,n)lg2[i]=lg2[i>>1]+1;

posted @ 2022-07-11 16:59  OIer某罗  阅读(77)  评论(0编辑  收藏  举报