Tree

Tree

参考 xk 老哥的博客:POJ 1741 Tree 点分治

找重心:

void getrt(int fa,int u,int num)    //num指的是这个节点的子树中有多少个节点
{
    siz[u]=1;
    int maxnum=0;                   //记录最大子树的节点个数
    for(int i=head[u];~i;i=e[i].next)
    {
        int v=e[i].to;
        if(vis[v]||v==fa) continue;
        getrt(u,v,num);
        siz[u]+=siz[v];
        maxnum=max(maxnum,siz[v]);
    }
    maxnum=max(maxnum,num-siz[u]);  //num-siz[u]表示的是某节点的反向子树(反向指的是沿着递归方向反方向)
    if(maxnum<rtsiz) rtsiz=maxnum,rt=u;  //更新重心和重心的最大子树
}

找到重心之后 dfs 计算子树上每个点距离重心的距离:

int d[maxn],dcnt;
void dfs(int fa,int u,int w)
{
    d[++dcnt]=w;
    siz[u]=1;
    for(int i=head[u];~i;i=e[i].next)
    {
        int v=e[i].to;
        if(vis[v]||fa==v) continue;
        dfs(u,v,w+e[i].w);
        siz[u]+=siz[v];
    }
}

根据每个点到重心的距离进行排序,并计算有多少满足条件的点:

int cal()
{
    sort(d+1,d+1+dcnt);
    int l=1,r=dcnt,ret=0;
    while(l<r)
        if(d[l]+d[r]<=m) ret+=r-l,l++;
        else r--;
        return ret;
}

重点来了!分治

void solve(int u,int num)
{
    if(num<=1) return;
    rtsiz=inf;
    getrt(0,u,num);
    vis[rt]=1;
    dcnt=0;
    dfs(0,rt,0);                          
    ans+=cal();
    /*第一次cal的时候可能会把根节点的同一颗子树上的两个点d[l]+d[r]<=m记录进去,
    但实际上这样的两个点不应该在本次cal中统计进去,因为在后面的分治中会进行删去
    本次cal与下面for中的cal加起来统计的是经过该点u的满足条件的路径*/
    for(int i=head[rt];~i;i=e[i].next)
    {
        int v=e[i].to;
        if(vis[v]) continue;
        dcnt=0;
        dfs(0,v,e[i].w);
        ans-=cal();
        /*在这个地方把不经过点u但满足条件的路径进行删除*/
    }
    for(int i=head[rt];~i;i=e[i].next)
    {
        int v=e[i].to;
        if(vis[v]) continue;
        solve(v,siz[v]);
        /*进行分治,也就是向下进行查询*/
    }
}

代码:

// Created by CAD on 2019/8/14.
#include <iostream>
#include <cstdio>
#include <algorithm>
#define inf 0x3f3f3f3f
using namespace std;
const int maxn=1e5+100;
int siz[maxn],vis[maxn],head[maxn],tot;
int n,m,ans;
struct edge{
    int to,next,w;
}e[maxn<<1];
void add(int u,int v,int w)
{
    e[++tot].to=v;
    e[tot].w=w;
    e[tot].next=head[u];
    head[u]=tot;
}
int rtsiz,rt;
void getrt(int fa,int u,int num)
{
    siz[u]=1;
    int maxnum=0;
    for(int i=head[u];~i;i=e[i].next)
    {
        int v=e[i].to;
        if(vis[v]||v==fa) continue;
        getrt(u,v,num);
        siz[u]+=siz[v];
        maxnum=max(maxnum,siz[v]);
    }
    maxnum=max(maxnum,num-siz[u]);
    if(maxnum<rtsiz) rtsiz=maxnum,rt=u;
}
int d[maxn],dcnt;
void dfs(int fa,int u,int w)
{
    d[++dcnt]=w;
    siz[u]=1;
    for(int i=head[u];~i;i=e[i].next)
    {
        int v=e[i].to;
        if(vis[v]||fa==v)continue;
        dfs(u,v,w+e[i].w);
        siz[u]+=siz[v];
    }
}
int cal()
{
    sort(d+1,d+1+dcnt);
    int l=1,r=dcnt,ret=0;
    while(l<r)
        if(d[l]+d[r]<=m) ret+=r-l,l++;
        else r--;
        return ret;
}
void solve(int u,int num)
{
    if(num<=1) return;
    rtsiz=inf;
    getrt(0,u,num);
    vis[rt]=1;
    dcnt=0;
    dfs(0,rt,0);
    ans+=cal();
    for(int i=head[rt];~i;i=e[i].next)
    {
        int v=e[i].to;
        if(vis[v]) continue;
        dcnt=0;
        dfs(0,v,e[i].w);
        ans-=cal();
    }
    for(int i=head[rt];~i;i=e[i].next)
    {
        int v=e[i].to;
        if(vis[v]) continue;
        solve(v,siz[v]);
    }
}
int main()
{
    int u,v,w;
    while(~scanf("%d%d",&n,&m)&&n+m)
    {
        tot=ans=0;
        for(int i=1;i<=n;++i)
            vis[i]=0,head[i]=-1;
        for(int i=1;i<n;++i)
            scanf("%d%d%d",&u,&v,&w),add(u,v,w),add(v,u,w);
        solve(1,n);
        cout<<ans<<endl;
    }
    return 0;
}
posted @ 2019-08-14 22:23  caoanda  阅读(223)  评论(0编辑  收藏  举报