树的点分治讲解

树的点分治讲解

了解分治

  在学习树的点分治之前,我们需要了解一下分治,分治之中最经典的就是序列分治,看到这篇博客的各位应该都会序列分治,大致思想就是每一次把整个序列划分成几个小区间,并在这些小区间里面继续分治下去,最后再把这下小区间一一合并,处理结果与贡献,最后得出结果,这就是分治。

  在了解序列分治之后,我们就可以有讨论一下树的点分治,为什么叫树的点分治呢?本人理解,就是每一次将一棵树分化成不同的子树进行处理,从而使问题简单化。(如下图)可能有人会问,为什么下图中会选择三号点?这就是我下面要讲解的问题。

选择分划点

  大家仔细观察上面的图片会发现,三号点是整棵树的重心,又有人可能会问,什么是树的重心?树的重心可以理解成为,我们把一棵树由他的重心进行划分,分成的所有子树的大小中的最大是最小的,例如上图,3号点为重心,由它分划之后的最大的大小是4,如果由1号点进行分划,最大的大小是5,显然没有3号点小。因此我们选择了三号点。但是为什么每一次都要选择树的重心呢?如果我们每一次都处理出来尽可能多的信息,这样在最后合并的时候,或者是说统计结果是的所用步数就越小,如果每一次我们分的子树的时候所有的子树的大小都十分接近,那么我们程序运行的时间就一定小。这个时候我们查找树的重心就十分重要了。

查找重心

  我们查找的时候,每一个点都可以维护出来一个f数组,这个数组代表的是把当前点作为分划点之后,产生的所有的子树的最大大小,根据这个数组的定义就可以知道如何求出来,我们每一次把f数组都赋值成为零,再用max函数求出当前点所有的字数的大小中最大的是不是就可以了?当然我们每一次求的时候不要忘记把f数组整体大小减去当前整棵树的大小的差进行比较。

void get_root(int p,int from)
{
    size[p]=1;
    f[p]=0;
    for(int i=head[p];i;i=nxt[i])
        if((!vis[to[i]])&&to[i]!=from)    //由于在点分治过程之中,有的点已经搜索过,所以就不用在一次进行处理
        {
            get_root(to[i],p);
            size[p]+=size[to[i]];
            f[p]=max(f[p],size[to[i]]);
        }
    f[p]=max(f[p],all-size[p]);
    if(f[root]>f[p]) root=p;
}
int main()
{
    root=0,f[0]=n+1,all=n;
    get_root(1,0);                          //找整棵树的重心
    root=0,f[0]=n+1,all=size[p];
    get_root(p,0);                         //找以p节点为根的子树的重心  
}

求答案

  点分治之中最重要的是求答案的函数,在点分治的题目之中最常见的是求树上的长度的问题,我们每一次可以在划分点处进行统计答案,下面以求树上路径长度终小于等于k的个数有多少个。我们看下面的图解。我们是不是可以很容易求出来划分点到他的每一个点的距离,那么我们就可以求出像图中的一条路径长度,很显然就是一个子树中的点到划分点的距离加上另一颗子树中的点到划分点的距离之和,即为长度。这样我们每一回就可以求出以当前划分点为一个中间点的路径的长度,并进行统计。但是当然,如果我们每一次加和统计是不是有一些慢?所以我们每一次可以把所有的点的距离放在一起,O(n*logn)进行排序,并用双指针扫一下,是不是时间就降下来了?但是又有一个问题,如下右图的情况怎么办?这样是不是不对啊?这样很明显是多求出来了,但是我们可以用单步容斥的思想,我们每一次划分点的子树的时候减去下面的情况是不是就好了?如果还是不理解,我建议看下面代码来理解,毕竟我的语文功底很差。

void get_dis(int p,int from)
{
    dist[++tot]=dis[p];
    for(int i=head[p];i;i=nxt[i])
        if(to[i]!=from&&(!vis[to[i]]))
        {
            dis[to[i]]=dis[p]+val[i];
            get_dis(to[i],p);
        }
}
int calc(int p)
{
    tot=0,get_dis(p,0);
    sort(dist+1,dist+tot+1);
    int l=1,r=tot,sum=0;
    while(l<r)
    {
        if(dist[l]+dist[r]<=m)
            sum+=r-l,l++;
        else r--;
    }
    return sum;
}
void dfs(int p,int from)
{
    vis[p]=true,dis[p]=0;
    ans+=calc(p);
    for(int i=head[p];i;i=nxt[i])
        if(!vis[to[i]])
        {
            dis[to[i]]=val[i];
            ans-=calc(to[i]);
            root=0,all=size[to[i]];
            get_root(to[i],0),dfs(root,0);
        }
}

  在代码之中我又加上了dfs这个数组,其中每一次get_root就是每一次啊找当前划分点的子树之中的划分点,递归下去。

  如果有不会的,或是问题,可以发评论进行提问,我会解答。

poj1741&&bzoj1468就是求在一棵树之中的所有路径有多少条小于等于k的存在,这也就是我讲解的例题,代码如下。

#include <stdio.h>
#include <algorithm>
using namespace std;
#define N 40001
int n,m;
int head[N];
int to[N<<1];
int nxt[N<<1];
int val[N<<1];
int f[N],size[N];
int dis[N];
int dist[N];
int idx,root,all,ans,tot;
bool vis[N];
void add(int a,int b,int c)
{
    nxt[++idx]=head[a];
    head[a]=idx;
    to[idx]=b;
    val[idx]=c;
}
void get_root(int p,int from)
{
    f[p]=0,size[p]=1;
    for(int i=head[p];i;i=nxt[i])
        if(to[i]!=from&&(!vis[to[i]]))
        {
            get_root(to[i],p);
            size[p]+=size[to[i]];
            f[p]=max(f[p],size[to[i]]);
        }
    f[p]=max(f[p],all-size[p]);
    if(f[root]>f[p]) root=p;
}
void get_dis(int p,int from)
{
    dist[++tot]=dis[p];
    for(int i=head[p];i;i=nxt[i])
        if(to[i]!=from&&(!vis[to[i]]))
        {
            dis[to[i]]=dis[p]+val[i];
            get_dis(to[i],p);
        }
}
int calc(int p)
{
    tot=0,get_dis(p,0);
    sort(dist+1,dist+tot+1);
    int l=1,r=tot,sum=0;
    while(l<r)
    {
        if(dist[l]+dist[r]<=m)
            sum+=r-l,l++;
        else r--;
    }
    return sum;
}
void dfs(int p,int from)
{
    vis[p]=true,dis[p]=0;
    ans+=calc(p);
    for(int i=head[p];i;i=nxt[i])
        if(!vis[to[i]])
        {
            dis[to[i]]=val[i];
            ans-=calc(to[i]);
            root=0,all=size[to[i]];
            get_root(to[i],0),dfs(root,0);
        }
}
int main()
{
    scanf("%d",&n);
    int a,b,c;
    root=ans=idx=0,f[0]=n+1,all=n;
    if(n==m&&n==0) return 0;
    for(int i=1;i<n;i++)
    {
        scanf("%d%d%d",&a,&b,&c);
        add(a,b,c),add(b,a,c);
    }
    scanf("%d",&m);
    get_root(1,0),dfs(root,0);
    printf("%d\n",ans);
    for(int i=1;i<=n;i++)
        vis[i]=false,head[i]=0;
}
posted @ 2018-05-15 15:38  Yang1208  阅读(1097)  评论(0编辑  收藏  举报