// 树上问题 (Tree problem)



// luogu-judger-enable-o2
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<vector>
#include<cstdlib>
#include<cmath>
using namespace std;
// Capacity of every per-node array (and the base for edge / segment-tree
// arrays): presumably the problem's maximum n (3e5) plus some slack.
constexpr int N = 300000 + 1000;
// Reads the next non-negative decimal integer from stdin. Any non-digit
// characters before it are skipped.
// Returns 0 if end of input is reached before a digit is found. The
// character is kept in an int so EOF stays distinguishable from data;
// storing it in a char made the skip loop spin forever at end of input.
int read(){
    int q=0;
    int ch=getchar();
    while(ch!=EOF&&(ch<'0'||ch>'9')) ch=getchar();
    while(ch>='0'&&ch<='9'){
        q=q*10+(ch-'0');
        ch=getchar();
    }
    return q;
}
// Linked adjacency list: e[i].to is the head of edge i and e[i].next is the
// previously added edge with the same tail (0 ends the list).
struct Edge {
    int next;
    int to;
};
Edge e[N << 1];          // each undirected tree edge is stored twice

int n;                   // number of nodes
int k;                   // label pairs (i, j) with 0 < j - i <= k are processed in main
int last[N];             // last[v]: most recently added edge leaving v (0 = none)
int edge_number;         // number of edges stored in e[] so far

// Heavy-light decomposition data, filled by dfs1 / dfs2.
// NOTE: with `using namespace std;` in C++17, an unqualified `size` is
// ambiguous with std::size; refer to this array as ::size.
int size[N];             // subtree size
int fa[N];               // parent (0 for the root)
int id[N];               // heavy-first DFS position of a node
int low[N];              // last DFS position inside the node's subtree
int q[N];                // inverse of id: q[id[v]] == v
int tot;                 // DFS position counter used by dfs2
int d[N];                // depth (root has depth 1)
int top[N];              // head of the heavy chain containing the node
int kkk[N];              // heavy child (0 when none)

// Sweep events: v1[r] holds column ranges that start at row r, v2[r] those
// that stop applying from row r on.
std::vector<std::pair<int, int> > v1[N], v2[N];

// Segment tree over columns: lazy additive tag and (max value, count of max).
int cov[N << 2];
std::pair<int, int> sum[N << 2];
// Appends the directed edge a -> b by pushing it onto the front of a's edge
// list, so the most recently added edge is visited first.
void add(int a,int b){
    ++edge_number;
    e[edge_number].to=b;
    e[edge_number].next=last[a];
    last[a]=edge_number;
}
void dfs1(int x,int f){
    size[x]=1;fa[x]=f;d[x]=d[f]+1;
    for(int i=last[x];i;i=e[i].next){
        if(e[i].to==f) continue;
        dfs1(e[i].to,x);
        size[x]+=size[e[i].to];
        if(size[e[i].to]>size[kkk[x]]) kkk[x]=e[i].to;
    }
}
void dfs2(int x,int topf){
    id[x]=++tot,q[tot]=x;
    top[x]=topf;
    if(kkk[x]) dfs2(kkk[x],topf);
    for(int i=last[x];i;i=e[i].next){
        if(e[i].to==fa[x]||e[i].to==kkk[x]) continue;
        dfs2(e[i].to,e[i].to);
    }
    low[x]=tot;
}
// Lowest common ancestor of a and b using the heavy-light chains. While the
// nodes are on different chains, jump the one whose chain head is deeper to
// the parent of that head. Once they share a chain, the shallower one is the
// answer.
int lca(int a,int b){
    for(;top[a]!=top[b];a=fa[top[a]]){
        if(d[top[a]]<d[top[b]]) std::swap(a,b);
    }
    if(d[a]>d[b]) return b;
    return a;
}
// Registers the rectangle rows [a, b] x columns [c, d] as two sweep events.
// The column range is added at row a (v1) and removed from row b + 1 (v2).
// Inside this function the parameter d shadows the global depth array.
void insert(int a,int b,int c,int d){
    const std::pair<int,int> cols(c,d);
    v1[a].push_back(cols);
    v2[b+1].push_back(cols);
}
// Returns the ancestor of x located at depth ddd + 1, one level below depth
// ddd. Assumes d[x] >= ddd + 1. It climbs whole heavy chains until the
// target depth lies on x's chain, then indexes into that chain's contiguous
// DFS positions.
int getanc(int x,int ddd){
    const int want=ddd+1;
    while(d[top[x]]>want) x=fa[top[x]];
    const int head=top[x];
    return q[id[head]+(want-d[head])];
}
void deal(int a,int b){
    int z=lca(a,b);
    if(z!=a&&z!=b){
        insert(id[a],low[a],id[b],low[b]);
        insert(id[b],low[b],id[a],low[a]);
    }
    else{
        if(b==z) std::swap(a,b);
        a=getanc(b,d[a]);
        if(id[a]>1){
            insert(1,id[a]-1,id[b],low[b]);
            insert(id[b],low[b],1,id[a]-1); 
        }
        if(low[a]<n){
            insert(low[a]+1,n,id[b],low[b]);
            insert(id[b],low[b],low[a]+1,n);
        }
    }
}
// Recomputes node cur from its two children. sum[].first is the maximum value
// in the node's range and sum[].second is how many positions attain it.
void pushup(int cur){
    const std::pair<int,int>& lhs=sum[cur<<1];
    const std::pair<int,int>& rhs=sum[cur<<1|1];
    if(lhs.first<rhs.first) sum[cur]=rhs;
    else if(rhs.first<lhs.first) sum[cur]=lhs;
    else sum[cur]=std::make_pair(lhs.first,lhs.second+rhs.second);
}
// Builds the segment tree over positions [l, r]. Every leaf starts with value
// 1 and count 1; internal nodes are merged with pushup.
void build(int cur,int l,int r){
    if(l!=r){
        const int mid=(l+r)>>1;
        build(cur<<1,l,mid);
        build(cur<<1|1,mid+1,r);
        pushup(cur);
        return;
    }
    sum[cur]=std::make_pair(1,1);
}
// Adds v to the whole range of node cur. The maximum shifts by v (its count
// is unchanged) and the lazy tag accumulates v for the children.
void update(int cur,int v){
    sum[cur].first+=v;
    cov[cur]+=v;
}
// Hands node cur's pending addition down to both children, then clears it.
void pushdown(int cur){
    const int tag=cov[cur];
    if(tag==0) return;
    update(cur<<1,tag);
    update(cur<<1|1,tag);
    cov[cur]=0;
}

void add(int cur,int l,int r,int L,int R,int v){
    if(L<=l&&r<=R){
        sum[cur].first+=v;
        cov[cur]+=v;
        return ;
    }
    pushdown(cur);
    int mid=(l+r)>>1;
    if(L<=mid) add(cur<<1,l,mid,L,R,v);
    if(R>mid) add(cur<<1|1,mid+1,r,L,R,v);
    pushup(cur);
}
// Reads n, k and the n - 1 tree edges, then builds the heavy-light
// decomposition.
// For every label pair (i, j) with 0 < j - i <= k, deal() adds rectangles
// marking the endpoint pairs whose path contains both i and j.
// Rows (first endpoint position) are then swept. The segment tree holds, per
// column (second endpoint position), 1 minus the number of active rectangles
// covering it, so columns still at 1 are uncovered.
int main(){
    n=read(),k=read();
    for(int i=1;i<n;++i){
        int a=read(),b=read();
        add(a,b),add(b,a);
    }
    dfs1(1,0);
    dfs2(1,1);
    // Bound written as `j - i <= k` rather than `j <= i + k`: i + k can
    // overflow int (undefined behaviour) when k is very large.
    for(int i=1;i<=n;++i){
        for(int j=i+1;j<=n&&j-i<=k;++j){
            deal(i,j);
        }
    }
    build(1,1,n);
    long long ans=0;
    for(int i=1;i<=n;++i){
        int S=v1[i].size();
        for(int j=0;j<S;++j){
            add(1,1,n,v1[i][j].first,v1[i][j].second,-1);   // rectangle starts at row i
        }
        S=v2[i].size();
        for(int j=0;j<S;++j){
            add(1,1,n,v2[i][j].first,v2[i][j].second,1);    // rectangle ended at row i - 1
        }
        // Cells start at 1 and active rectangles only subtract, so a maximum
        // of 1 means sum[1].second columns of this row are uncovered.
        if(sum[1].first==1){
            ans+=sum[1].second;
        }
    }
    // deal() never covers a diagonal cell (x, x): each rectangle pairs disjoint
    // position ranges. So ans = n + 2 * (uncovered unordered pairs x != y),
    // and (ans + n) / 2 = n + (uncovered unordered pairs).
    printf("%lld",(ans+n)/2);
}
// Posted 2018-10-15 09:57 by 风浔凌 (views: 198, comments: 0)