Loading

AtCoder Beginner Contest 267 F Exactly K Steps

Exactly K Steps

树的直径 + 离线

考虑离每个点最远的点是哪个点,根据树的直径相关内容,距离该点最远的点一定是直径两个端点之一,理解上的话可以考虑画一个图,直径上的所有点在一条直线上,然后再花分支,如果其中一个点的最远距离的点,不是直径上两个端点的话,显然会有新的更长的直径产生

因为知道了最远距离的那个端点,因此如果有答案,那么答案一定在该点到最远的那个点的路径上

因此就先用 \(bfs\) 找到直径的两个端点

然后离线询问,以两个端点分别做一次 \(dfs\),并维护下 dfs 栈,然后对每个点的询问在 \(dfs\) 栈上找

#include <iostream>
#include <cstdio>
#include <vector>
#include <queue>
#include <algorithm>
#include <array>
using namespace std;
const int maxn = 2e5 + 10;
vector<int>gra[maxn], ans(maxn, -1);
vector<array<int, 2> >Q[maxn];
int n, q;

int bfs(int now)
{
    queue<int>q;
    vector<int>vis(n + 1, 0);
    q.push(now);
    int last = 0;
    while(q.size())
    {
        int now = q.front();
        q.pop();
        if(vis[now]) continue;
        vis[now] = 1;
        last = now;
        for(int nex : gra[now])
        {
            if(vis[nex]) continue;
            q.push(nex);
        }
    }
    return last;
}

int st[maxn], tp = 0;
void dfs(int now, int pre)
{
    st[++tp] = now;
    for(auto [k, id] : Q[now])
    {
        if(k >= tp) continue;
        ans[id] = st[tp - k];
    }
    for(int nex : gra[now])
    {
        if(nex == pre) continue;
        dfs(nex, now);
    }
    st[tp--] = 0;
}

int main()
{
    ios::sync_with_stdio(false);
    cin.tie(0);
    cout.tie(0);
    cin >> n;
    for(int i=1; i<n; i++)
    {
        int a, b;
        cin >> a >> b;
        gra[a].push_back(b);
        gra[b].push_back(a);
    }
    cin >> q;
    for(int i=0; i<q; i++)
    {
        int u, k;
        cin >> u >> k;
        Q[u].push_back({k, i});
    }
    int L = bfs(1);
    int R = bfs(L);
    dfs(L, -1);
    dfs(R, -1);
    for(int i=0; i<q; i++)
        cout << ans[i] << "\n";
    return 0;
}

赛中我写了个非常麻烦的做法,不建议理解

随便产生一个树,对于一个点,他的最远距离点无非就是祖先,或者某个祖先往下的最深,以及本身的最深

因此用 \(dfs\),维护每个点往下的最深和次深以及点的深度(用户处理答案在祖先的情况),然后再进行一次 \(dfs\),维护答案在某个祖先作为 \(LCA\),然后再往下的情况

这样就能处理出最远的点,接着用倍增求第 \(k\) 级祖先的方式,维护距离的差(最远的点的距离 和 询问的距离)

时间复杂度: \(O(nlogn)\)

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <vector>
#include <string>
#include <queue>
#include <functional>
#include <map>
#include <set>
#include <cmath>
#include <cstring>
#include <deque>
#include <stack>
#include <ctime>
#include <cstdlib>
using namespace std;
typedef long long ll;
#define pii pair<int, int>
const ll maxn = 2e5 + 10;
const ll inf = 1e17 + 10;
vector<int>gra[maxn];
int dep[maxn], fa[maxn][25];
pii sa[maxn], sb[maxn];
pii last[maxn];

void dfs(int now, int pre, int d)
{
    dep[now] = d;
    fa[now][0] = pre;
    sa[now] = {now, 0};
    for(int nex : gra[now])
    {
        if(nex == pre) continue;
        dfs(nex, now, d + 1);
        pii temp = sa[nex];
        temp.second += 1;
        if(temp.second > sa[now].second)
        {
            sb[now] = sa[now];
            sa[now] = temp;
        }
        else if(sa[nex].second > sb[now].second)
            sb[now] = temp;
    }
}

void dfs2(int now, int pre, pii maxx)
{
    last[now] = sa[now];
    maxx.second++;
    if(maxx.second > sa[now].second)
        last[now] = maxx;
    for(int nex : gra[now])
    {
        if(nex == pre) continue;
        pii temp = maxx;
        if(sa[nex].first != sa[now].first)
        {
            if(temp.second < sa[now].second)
                temp = sa[now];
            dfs2(nex, now, temp);
        }
        else
        {
            if(temp.second < sb[now].second)
                temp = sb[now];
            dfs2(nex, now, temp);
        }
    }
}

int kth(int now, int k)
{
    for(int i=20; i>=0; i--)
    {
        if(k >= (1 << i))
        {
            k -= 1 << i;
            now = fa[now][i];
        }
    }
    return now;
}

int main()
{
    ios::sync_with_stdio(false);
    cin.tie(0);
    cout.tie(0);
    int n;
    cin >> n;
    for(int i=1; i<n; i++)
    {
        int a, b;
        cin >> a >> b;
        gra[a].push_back(b);
        gra[b].push_back(a);
    }
    dfs(1, 1, 0);
    for(int i=1; i<=20; i++)
        for(int j=1; j<=n; j++)
            fa[j][i] = fa[fa[j][i-1]][i-1];
    dfs2(1, 1, {1, -1});
    int q;
    cin >> q;
    while(q--)
    {
        int a, b;
        cin >> a >> b;
        int ans = -1;
        if(b <= dep[a])
            ans = kth(a, b);
        else if(b <= last[a].second)
            ans = kth(last[a].first, last[a].second - b);
        cout << ans << "\n";
    }
    return 0;
}
posted @ 2022-09-05 18:01  dgsvygd  阅读(44)  评论(0编辑  收藏  举报