CF324E Xenia and Tree

题目传送门


题意

有一颗树,初始\(1\)号点为红色,其余为蓝色,有两种操作:1.修改一个蓝点为红点。2.查询每个点最近红点的距离


题解

应该是较板题
最近根号写多了,首先想到的是: 每根号个询问重构一次,不超过根号就暴力枚举这个点和那根号个点的距离。
发现可能是对的,dp可以在\(O(n)\)复杂度处理出答案,每\(O(\sqrt{n})\) dp一次即可。
考虑需要\(O(1)\)LCA,于是学了一手。
如何dp?对于每个点,先dp它的儿子,然后他的答案就等于它儿子或父亲中答案最小的加1.
注意,要调用两遍这个dp。
你可以这样理解,这是一个混合dp,它实际上分两步进行:
第一步: 处理出每个节点向下最近的距离,这个可以从下往上dp一遍,
第二步:处理每个点向上经过它父亲的最近距离,由于根节点不能向上,所以答案不变,从上往下dp是对的。
我比较懒,写在一起, 调用两遍。

实现

#include <iostream>
#include <cstdio>
#include <vector>
#include <cmath>
using namespace std;

int read(){
    int num=0, flag=1; char c=getchar();
    while(!isdigit(c) && c!='-') c=getchar();
    if(c == '-') c=getchar(), flag=-1;
    while(isdigit(c)) num=num*10+c-'0', c=getchar();
    return num*flag;
}

int min(int a, int b){return a<b?a:b;}

const int N = 2e6+100;
const int M = 21;
const int inf = 0x3f3f3f3f;
int n, m, sqr, fa[N], dep[N], lg[N], dfn[N], f[N][M], tot=0;
int col[N], ans[N];
vector<int> p[N];
vector<int> op;

void dfs(int x){
    dfn[x] = ++tot, dep[x] = dep[fa[x]] + 1;
    f[tot][0] = x;
    for(auto i : p[x]){
        if(i == fa[x]) continue;
        fa[i] = x;
        dfs(i);
        f[++tot][0] = x;
    }
}

int mindep(int x, int y){
    return dep[x]<dep[y]?x:y;
}

void pre(){
    for(int i=2; i<N; i++) lg[i] = lg[i>>1] + 1;
    for(int i=1; i<M; i++){
        for(int j=1; j<=tot; j++){
            f[j][i] = mindep(f[j][i-1], f[j+(1<<(i-1))][i-1]);
        }
    }
}

int rmq(int l, int r){
	if(l > r) swap(l, r); 
    return mindep(f[l][lg[r-l+1]], f[r-(1<<lg[r-l+1])+1][lg[r-l+1]]);
}

int lca(int x, int y){
    return rmq(dfn[x], dfn[y]);
}

int getDist(int x, int y){
    return dep[x]+dep[y] - 2*dep[lca(x, y)];
}

void solve(int x){
    if(col[x]) ans[x]=0;
    ans[x] = min(ans[x], ans[fa[x]]+1);
    for(auto i : p[x]){
        if(i == fa[x]) continue;
        solve(i);
        ans[x] = min(ans[x], ans[i]+1);
    }
}

int main(){
    n=read(), m=read(), sqr=sqrt(m); 
    for(int i=1; i<=n; i++) ans[i]=inf;
    for(int i=1; i<n; i++){
        int u=read(), v=read();
        p[u].push_back(v), p[v].push_back(u);
    }
    dfs(1); pre();

    op.push_back(1);
    col[1] = 1;
    while(m--){
        int type=read(), x=read();
        if(type == 1){
            op.push_back(x);
            col[x] = 1;
        }else{
            if(op.size() > sqr){
                solve(1);
                solve(1); 
                op.clear();
            }

            for(auto i : op){
                ans[x] = min(ans[x], getDist(x, i));
            }

            printf("%d\n", ans[x]);
        }
    }
    return 0;
}
posted @ 2022-02-27 09:19  ltdJcoder  阅读(21)  评论(0编辑  收藏  举报