CF161D Distance in Tree(点分治)

点分治是一种处理树上路径统计问题的优秀暴力

这是一道板子题：统计树上距离恰好为 k 的无序点对个数。

#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
// Array capacity: node ids index fir/sz/f/vis directly, so this bounds n
// (sized for n up to about 50000 — presumably the problem limit).
const int MAXN=50010;

// Forward-star adjacency list. Edge e runs u[e] -> v[e]; nxt[e] links to the
// next edge with the same source; fir[x] is the newest edge leaving x; cnt is
// the number of stored edges. Each undirected edge is stored twice, hence <<1.
int u[MAXN<<1],v[MAXN<<1],nxt[MAXN<<1],fir[MAXN],cnt;

// Centroid search state used by getroot: sz = subtree size in the current DFS,
// f = size of the largest piece left when a node is removed, root = best
// candidate so far, Siz = size of the component being searched.
int sz[MAXN],f[MAXN],root,Siz;

// Distances collected by getdis, stored 1-indexed in middis[1..midcnt].
int middis[MAXN],midcnt;

// Input values, plus the "already used as a centroid" flags.
int n,k,vis[MAXN];

// Running answer accumulated by divide.
long long ans=0;
// Stores the directed edge ui -> vi as edge number ++cnt and pushes it onto
// the front of ui's forward-star list (fir/nxt).
void addedge(int ui,int vi){
    nxt[++cnt]=fir[ui];
    fir[ui]=cnt;
    u[cnt]=ui;
    v[cnt]=vi;
}
// Post-order DFS over the component containing u, skipping fa and any node
// whose vis flag is set.
// For every node x it records:
//   sz[x] - size of x's subtree in this DFS;
//   f[x]  - the largest of: 1, each child subtree, and the "upward" part
//           Siz - sz[x].
// root ends up as the node with the smallest f seen. Callers set Siz to the
// component size and reset root to 0 beforehand; main stores a large sentinel
// in f[0].
void getroot(int u,int fa){
    sz[u]=f[u]=1;
    for(int e=fir[u];e;e=nxt[e]){
        const int to=v[e];
        if(to!=fa&&!vis[to]){
            getroot(to,u);
            sz[u]+=sz[to];
            if(sz[to]>f[u])
                f[u]=sz[to];
        }
    }
    // The part of the component "above" u in this DFS is also a piece.
    if(Siz-sz[u]>f[u])
        f[u]=Siz-sz[u];
    if(f[u]<f[root])
        root=u;
}
// Pre-order DFS from u. Appends d (u's distance) and then the distances of
// all nodes reachable without crossing fa or a vis-marked node to
// middis[++midcnt]. Each step away from u adds 1.
void getdis(int u,int d,int fa){
    middis[++midcnt]=d;
    for(int e=fir[u];e;e=nxt[e]){
        const int to=v[e];
        if(to!=fa&&!vis[to])
            getdis(to,d+1,u);
    }
}
// Returns the smallest index p in [l, midcnt] with middis[p] >= k, or 0 when
// there is none (including an empty range).
// Expects middis[l..midcnt] to be sorted ascending; solve sorts before
// searching.
int look1(int l,int k){
    if(l>midcnt)
        return 0;
    int *const last=middis+midcnt+1;
    int *const hit=std::lower_bound(middis+l,last,k);
    return hit==last?0:static_cast<int>(hit-middis);
}
// Returns the largest index p in [l, midcnt] with middis[p] <= k, or 0 when
// there is none (including an empty range).
// Expects middis[l..midcnt] to be sorted ascending; solve sorts before
// searching.
int look2(int l,int k){
    if(l>midcnt)
        return 0;
    int *const first=middis+l;
    // [first, past) holds exactly the elements <= k.
    int *const past=std::upper_bound(first,middis+midcnt+1,k);
    return past==first?0:static_cast<int>(past-middis)-1;
}
// Counts unordered pairs of positions i < j in middis[1..midcnt] whose values
// sum to exactly k (the global k).
// Side effect: sorts middis[1..midcnt] ascending. midcnt is left unchanged,
// which divide relies on.
int solve(void){
    std::sort(middis+1,middis+midcnt+1);
    int *const last=middis+midcnt+1;
    int pairs=0;
    // Once middis[i] > k - middis[i] (past k/2), every later partner is at
    // least as large, so no later sum can equal k and the scan can stop.
    for(int i=1;i<midcnt&&k-middis[i]>=middis[i];i++){
        // Partners are taken only from positions after i, so a pair is never
        // counted twice or paired with itself. The sorted order makes all
        // partners with value k - middis[i] one contiguous run.
        const std::pair<int*,int*> run=std::equal_range(middis+i+1,last,k-middis[i]);
        pairs+=static_cast<int>(run.second-run.first);
    }
    return pairs;
}
// Centroid-decomposition step; u is the root (centroid) chosen for its
// current component.
// Adds to ans the pairs counted through u, then recurses into each remaining
// sub-component.
void divide(int u){
    // Mark u as removed so later traversals stop at it.
    vis[u]=true;

    // Distances from u to every node of the component (u itself at 0). Pairs
    // whose distances sum to k are candidates for paths through u.
    midcnt=0;
    getdis(u,0,0);
    ans+=solve();

    for(int i=fir[u];i;i=nxt[i]){
        const int son=v[i];
        if(vis[son])
            continue;
        // Inclusion-exclusion: two nodes in the same sub-component were
        // counted above even though their real path does not pass through u.
        // Subtract those pairs, using the same distances measured to u
        // (starting at 1 on the edge u-son).
        midcnt=0;
        getdis(son,1,u);
        ans-=solve();

        // getdis just recorded one entry per node of son's sub-component, so
        // midcnt is its exact size. sz[son] comes from the previous getroot
        // DFS and is inflated when son lay on the parent side of that DFS,
        // which would skew the centroid choice.
        Siz=midcnt;
        root=0;
        getroot(son,u);
        divide(root);
    }
}
// Reads n and k followed by n-1 undirected edges "a b". Starts the
// decomposition at the centroid of the whole tree and prints ans: the number
// of vertex pairs whose distance is exactly k.
int main(){
    scanf("%d %d",&n,&k);
    for(int left=n-1;left>0;--left){
        int a,b;
        scanf("%d %d",&a,&b);
        addedge(a,b);
        addedge(b,a);
    }
    // root starts at 0; this sentinel makes any real node beat it in getroot.
    f[0]=0x3f3f3f3f;
    Siz=n;
    getroot(1,0);
    divide(root);
    printf("%lld",ans);
    return 0;
}
posted @ 2018-12-09 19:20  dreagonm  阅读(1010)  评论(0)  编辑  收藏  举报