树的直径

找树的直径:三次 dfs 法。第一次随便以一个点为根找到和它距离最远的一个,这个点一定是直径的一个端点。然后以这个点为根找到直径的另一个端点。然后顺藤摸瓜找到整条链。

void dfs(int now, int fa) {
    dep[now] = dep[fa] + 1;
    f(i, 0, (int)g[now].size()-1){
        if(g[now][i]!=fa)dfs(g[now][i],now);
    }
}
bool fl(int now, int fa){
    if(now == tail) {
        dg.push_back(now);
        return 1;
    }
    bool isd = 0;
    f(i, 0, (int)g[now].size() - 1) 
        if(g[now][i]!=fa)
            if(fl(g[now][i],now)) isd = 1;
    if(isd) dg.push_back(now);
    return isd;
}
signed main() {
    dfs(1, 0);
    int mx=0;f(i,1 ,n){if(dep[i]>mx){mx=dep[i];head=i;}}
    dfs(head,0);
    mx=0;f(i,1 ,n){if(dep[i]>mx){mx=dep[i];tail=i;}}
    fl(head, 0);
}

这个时候 \(dg\) 数组就是从 \(tail\)\(head\) 存储直径上的点了。

树的直径有一些性质,先来看一道经典题目:

ABC267F

对于一棵树有 \(q\) 次询问,每次询问任意输出一个和 \(u_i\) 距离为 \(k_i\) 的点。

分析:直径的性质是:图上和 \(u_i\) 最远的点是直径的一个端点。那么只要在该点和直径上离它较远的那个点的连线上找点即可。

怎么找呢?把这个连线分成两端,一段上的点非直径,另一段上的点是直径。非直径的可以用倍增做到 \(O(\log k)\) 求第 \(k\) 级祖先,直径的只需要在 \(dg\) 数组中找到对应序号的即可。

注意维护倍增祖先表的时候要从直径向叶子维护,而不是从 \(1\)\(n\)

#include<bits/stdc++.h>
using namespace std;
#define int long long
#define f(i, a, b) for(int i = (a); i <= (b); i++)
#define cl(i, n) i.clear(),i.resize(n);
#define endl '\n'
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int, int> pii;
const int inf = 1e9;
int n;
int dep[200010];
int dis[200010];
vector<int> g[200010];
int head, tail;
vector<int> dg;
int sx[200010];
int anc[200010][30];  //叶子节点 i 的第 2^j 级祖先
int num[200010];//归属于直径上的哪个点
void dfs(int now, int fa) {
    dep[now] = dep[fa] + 1;
    f(i, 0, (int)g[now].size()-1){
        if(g[now][i]!=fa)dfs(g[now][i],now);
    }
}
bool fl(int now, int fa){
    if(now == tail) {
        dg.push_back(now);
        return 1;
    }
    bool isd = 0;
    f(i, 0, (int)g[now].size() - 1) {
        if(g[now][i]!=fa){
            if(fl(g[now][i],now)) isd = 1;
            
        }
    }
    if(isd) dg.push_back(now);
    return isd;
}
void bz(int now) {
    int k = dis[now];
    if(k == 0) return;
    f(i, 1, log2(k)+1) {
        anc[now][i] = anc[anc[now][i - 1]][i - 1];
    }
}
void zhi(int now, int fa, int lzy) {
    num[now] = lzy;
    if(sx[now]) dis[now] = 0;
    else dis[now] = dis[fa] + 1;
    bz(now);
    f(i, 0, (int)g[now].size()- 1) {
        if(sx[g[now][i]] || g[now][i] == fa) continue;
        anc[g[now][i]][0] = now;
        zhi(g[now][i], now, lzy);
        
    }
    return;
}
void solve(int x, int d) {
    int mx = dis[x] + max(abs(1 - sx[num[x]]), abs((int)dg.size() - sx[num[x]]));
    if(d > mx) cout << -1 << endl;
    else if(d <= dis[x]) {
        //求x 的第 d 级祖先
        int ntf = 0, noi = x;
        while(d) {
            if(d & 1) noi = anc[noi][ntf];
            ntf++;
            d >>= 1;
        }
        cout << noi << endl;
    }
    else {
        int dd = d - dis[x];
        if(abs(1 - sx[num[x]]) >= dd) cout << dg[sx[num[x]] - dd - 1] << endl;
        else cout << dg[sx[num[x]] + dd - 1] << endl;
    }
    return;
}
signed main() {
    ios::sync_with_stdio(0);
    cin.tie(NULL);
    cout.tie(NULL);
    time_t start = clock();
    //think twice,code once.
    //think once,debug forever.
    cin >> n;
    f(i, 1, n - 1) {
        int a, b; cin >> a >> b;
        g[a].push_back(b),g[b].push_back(a);
    }
    dfs(1, 0);
    int mx=0;f(i,1 ,n){if(dep[i]>mx){mx=dep[i];head=i;}}
    dfs(head,0);
    mx=0;f(i,1 ,n){if(dep[i]>mx){mx=dep[i];tail=i;}}
    fl(head, 0);
    int cnt = 0;for(int i : dg) sx[i] = ++cnt;
	for(int i : dg) zhi(i, 0, i);
	//f(i, 1, n) bz(i);  //这是错误的
	cout << endl;
    int q; cin >> q;
    f(i,1 , q) {
        int u, k; cin >> u >> k;
        solve(u, k);
    }
    time_t finish = clock();
    //cout << "time used:" << (finish-start) * 1.0 / CLOCKS_PER_SEC <<"s"<< endl;
    return 0;
} 
posted @ 2022-10-05 16:37  OIer某罗  阅读(40)  评论(0编辑  收藏  举报