树分治

分治

分治作为一种思想在算法中被广泛使用,比如归并排序、线段树之类的算法都是使用了分治的思想。
除了这些算法的底层原理,分治单独拿出来运用的其中一种场景是作为一种统计类算法使用。

统计类问题一般指不带修改,静态查询类的问题。但是不同于静态查询问题,统计类问题一般要求“全查”,例如枚举所有的子区间,枚举一颗树上所有的链,枚举一颗树上所有的子树。
dsu on tree/长链剖,就是一种统计类算法(全查所有子树)。

由于统计类问题往往要求“全查”,这个时候各种数据结构就有点力不从心了,因为只要枚举例如数组的子区间,复杂度就要退化到了N^2级别。
要讲树分治,我们先从数组分治讲起

数组上的统计问题

我们复习一下数组上进行分治的原理:
1、左右端点全部落在mid左侧的
2、左右端点全部落在mid右侧的
3、左端点落在mid左侧,右端点落在mid右侧
每次处理③,对于①②,以递归函数的方式实现
然后我们需要维护的东西,是两个集合,一个表示前缀,一个表示后缀。
用meet in middle的方式合并“集合”。

经典例题:

给一个数组,求所有子数组的gcd之和。
这个是个经典问题,首先我们取整个数组的中点mid画一条线。
数组中所有的子区间只有三种:
1、左右端点全部落在mid左侧的
2、左右端点全部落在mid右侧的
3、左端点落在mid左侧,右端点落在mid右侧
每次处理③,对于①②,以递归函数的方式实现
例如当前处理区间为
1 3 6 12 4 8 3 5
左区间[1 3 6 12]和右区间 [4 8 3 5]已经递归处理完毕
考虑如何处理③:
左区间[1 3 6 12]的后缀gcd为A=[1 3 6 12]
右区间[4 8 3 5 ]的前缀gcd为B=[4 4 1 1 ]
那么枚举A和B中的任意一对的gcd便是左端点落在mid左侧,右端点落在mid右侧的gcd之和
但这是总时间复杂度为\(O(n*n^2)\),考虑如何优化
gcd有一个特性,例如A和B中所示,连续gcd一定是递减的,要不然是相等,不然就至少减半
那么我们每次对A和B去重,A和B是log级别的,用两个map分别储存每个gcd值的个数
这样枚举是\(log^2 n\)
总时间复杂度为\(O(n*log^2 n)\)

有没有感觉像线段树?
有就对了,线段树也是一种“分治”的过程。
但是不太一样的地方在于这种数组分治统计信息往往不是“区间信息”和并,可能是“集合”的信息。
这部分有点类似meet in middle(折半搜索)的意思在里面。

这里提前说一句,要学分治必须会这些东西,分治一般不单独用。
map,set,lower_bound,二分答案,双指针尺取法,单调栈,容斥。
一个分治要是不套几个这玩意它就不是分治了。
这就是为什么有的时候说分治难写。

树上分治

树分治有两种,一般常见的是点分治,还有一种不太常用的边分治,主要讲解点分治。

树的重心

定义无根树的重心为:选择重心作为整个树的根节点,使得整棵树变为有根树后,可以最小化根节点的所有直接子树中尺寸的最大值。
这个地方可以用一个树形DP简单求解,甚至不需要换根(当然换根法也能做)。
树的重心有一个重要性质:
选择树的重心作为根节点时,根节点的直接子树尺寸不大于N/2,N表示整个树的尺寸。
这个证明起来比较简单,用反证法就可以了。
假设树的重心作为根节点时存在一个节点的尺寸大于N/2,那么把根节点移动到该子树所在的根节点时一定更优。

选择重心以后树分治就按照数组分治中的思路做就可以了。
1、处理跨重心的树链。
2、递归处理子树。

模板

点击查看代码
struct Tree_Div{
    int tot,head[maxn],nx[maxn],to[maxn];
    //ll w[maxn];   边权
    void add(int x,int y,int z){
        to[++tot]=y;nx[tot]=head[x];head[x]=tot;w[tot]=z;
    }

    bool vis[maxn];             //是否是重心
    int root,sz[maxn],f[maxn];
    int sum;                    //当前子树大小
    void get_gravity(int u,int fa){
        sz[u]=1;f[u]=0;
        for(int i=head[u];i;i=nx[i]){
            int v=to[i];if(v==fa || vis[v])continue;
            get_gravity(v,u);f[u]=max(f[u],sz[v]);sz[u]+=sz[v];
        }
        f[u]=max(f[u],sum-sz[u]);
        if(f[u]<f[root])root=u;
        return ;
    }

//以下需要根据题目调整
    int top,st[maxn];
    void get_data(int u,int fa){        //得到子树信息
       // st[++top]=dis[u];
        for(int i=head[u];i;i=nx[i]){
            int v=to[i];if(v==fa || vis[v])continue;
           // dis[v]=dis[u]+w[i];
            get_data(v,u);
        }
        return ;
    }

    void get_ans(){  
        /*
        统计答案
        
        */

        /*
        合并信息

        */
        

        return ;
    }

    void calc(int u){                   //处理经过重心的答案
        for(int i=head[u];i;i=nx[i]){
            int v=to[i];if(vis[v])continue;
          //  top=0;dis[v]=w[i];
            get_data(v,u);
            get_ans();
        }
        return ;
    }

    void solve(int u){              //分治过程
        vis[u]=1;calc(u);
        for(int i=head[u];i;i=nx[i]){
            int v=to[i];if(vis[v])continue;
            root=0;f[root]=inf;sum=sz[v];get_gravity(v,u);
            solve(root);
        }
        return ;
    }

    void start(){
        root=0;f[root]=inf;sum=n;
        get_gravity(1,0);
        get_gravity(root,0);
        solve(root);
        return ;
    }

    void init(){
        tot=0;
        memset(head,0,sizeof(int)*(n+2));
        memset(vis,0,sizeof(int)*(n+2));
        return ;
    }

}T;

例题

1. 【模板】点分治1

点击查看代码
#include<functional>
#include<algorithm>
#include<iostream>
#include<cstdlib>
#include<cstring>
#include<complex>
#include<string>
#include<cstdio>
#include<vector>
#include<cmath>
#include<queue>
#include<deque>
#include<stack>
#include<map>
#include<set>
#define ll long long 
#define pa pair<int,int>
#define mp make_pair
#define pb push_back
using namespace std;
const int maxn=1e6+101;
const int MOD=998244353;
const int inf=2147483647;
const double eps=1e-12;

ll read(){
    ll x=0,f=1;char ch=getchar();
    for(;!isdigit(ch);ch=getchar())if(ch=='-')f=-1;
    for(;isdigit(ch);ch=getchar())x=x*10+ch-'0';
    return x*f;
}
int ans[maxn];
int n,m,ask[maxn];
struct Tree_Div{
    int tot,head[maxn],nx[maxn],to[maxn];
    ll w[maxn];
    void add(int x,int y,int z){
        to[++tot]=y;nx[tot]=head[x];head[x]=tot;w[tot]=z;
    }

    bool vis[maxn];             //是否是重心
    int root,sz[maxn],f[maxn];
    int sum;                    //当前子树大小
    void get_gravity(int u,int fa){
        sz[u]=1;f[u]=0;
        for(int i=head[u];i;i=nx[i]){
            int v=to[i];if(v==fa || vis[v])continue;
            get_gravity(v,u);f[u]=max(f[u],sz[v]);sz[u]+=sz[v];
        }
        f[u]=max(f[u],sum-sz[u]);
        if(f[u]<f[root])root=u;
        return ;
    }

//以下需要根据题目调整
    int top,st[maxn];
    ll dis[maxn];
    void get_data(int u,int fa){        //得到子树信息
        st[++top]=dis[u];
        for(int i=head[u];i;i=nx[i]){
            int v=to[i];if(v==fa || vis[v])continue;
            dis[v]=dis[u]+w[i];get_data(v,u);
        }
        return ;
    }

    bool book[maxn*10];
    int q[maxn],l;
    void get_ans(){  
        //统计答案
        for(int i=top;i;i--){
            for(int j=1;j<=m;j++){
                if(ask[j]>=st[i] && book[ask[j]-st[i]])ans[j]++;
            }
        }
        //合并信息
        for(int i=top;i;i--){
            if(st[i]<=maxn*10-1){
                q[++l]=st[i];
                book[st[i]]=1;
            }
        }
        return ;
    }

    void calc(int u){                   //处理经过重心的答案
        l=0;book[0]=1;
        for(int i=head[u];i;i=nx[i]){
            int v=to[i];if(vis[v])continue;
            top=0;dis[v]=w[i];get_data(v,u);
            get_ans();
        }
        for(int i=1;i<=l;i++)book[q[i]]=0;      
        return ;
    }

    void solve(int u){              //分治过程
        vis[u]=1;calc(u);
        for(int i=head[u];i;i=nx[i]){
            int v=to[i];if(vis[v])continue;
            root=0;f[root]=inf;sum=sz[v];get_gravity(v,u);
            solve(root);
        }
        return ;
    }

}T;
int main(){
    n=read();m=read();
    for(int i=1;i<n;i++){
        int x=read(),y=read(),z=read();
        T.add(x,y,z);T.add(y,x,z);
    }
    for(int i=1;i<=m;i++)ask[i]=read();
    
    T.sum=n;T.f[T.root]=inf;
    T.get_gravity(1,0);
    int now=T.root;
    T.root=0;T.sum=n;T.f[T.root]=inf;
    T.get_gravity(now,0);   //刚才的root是1子树的重心,可能不是真正的重心,重新找一遍
    T.solve(T.root);

    for(int i=1;i<=m;i++)printf(ans[i]?"AYE\n":"NAY\n");
    return 0;
}
解释为什么刚开始找两遍重心:https://liu-cheng-ao.blog.uoj.ac/blog/2969

2.Tree
给一颗树,每条边有边权,对树上所有长度小于k的路径求和并输出。
为什么说这个题经典呢,因为它套了双指针和容斥。这是在分治中非常非常常见的手法。
以至于你学这个模板的时候你就要会这两个东西。

点击查看代码
#include<functional>
#include<algorithm>
#include<iostream>
#include<cstdlib>
#include<cstring>
#include<complex>
#include<string>
#include<cstdio>
#include<vector>
#include<cmath>
#include<queue>
#include<deque>
#include<stack>
#include<map>
#include<set>
#define ll long long 
#define pa pair<int,int>
#define mp make_pair
#define pb push_back
using namespace std;
const int maxn=1e6+101;
const int MOD=998244353;
const int inf=2147483647;
const double eps=1e-12;

ll read(){
    ll x=0,f=1;char ch=getchar();
    for(;!isdigit(ch);ch=getchar())if(ch=='-')f=-1;
    for(;isdigit(ch);ch=getchar())x=x*10+ch-'0';
    return x*f;
}
int ans;
int n,m;
struct Tree_Div{
    int tot,head[maxn],nx[maxn],to[maxn];
    ll w[maxn];   //边权
    void add(int x,int y,int z){
        to[++tot]=y;nx[tot]=head[x];head[x]=tot;w[tot]=z;
    }

    bool vis[maxn];             //是否是重心
    int root,sz[maxn],f[maxn];
    int sum;                    //当前子树大小
    void get_gravity(int u,int fa){
        sz[u]=1;f[u]=0;
        for(int i=head[u];i;i=nx[i]){
            int v=to[i];if(v==fa || vis[v])continue;
            get_gravity(v,u);f[u]=max(f[u],sz[v]);sz[u]+=sz[v];
        }
        f[u]=max(f[u],sum-sz[u]);
        if(f[u]<f[root])root=u;
        return ;
    }

//以下需要根据题目调整
    int top,st[maxn];
    ll dis[maxn];
    void get_data(int u,int fa){        //得到子树信息
        st[++top]=dis[u];
        for(int i=head[u];i;i=nx[i]){
            int v=to[i];if(v==fa || vis[v])continue;
            dis[v]=dis[u]+w[i];
            get_data(v,u);
        }
        return ;
    }

    int get_ans(){  
        sort(st+1,st+top+1);
        int now=0;
        int l=1,r=top;
        while(r>l){
            if(st[r]+st[l]>m)r--;
            else now+=(r-l),l++;
        }
        return now;
    }

    int calc(int u,int len){                   //处理经过重心的答案
        top=0;dis[u]=len;
        get_data(u,0);
        return get_ans();   
    }

    void solve(int u){              //分治过程
        vis[u]=1;ans+=calc(u,0);
        for(int i=head[u];i;i=nx[i]){
            int v=to[i];if(vis[v])continue;
            ans-=calc(v,w[i]);
            root=0;f[root]=inf;sum=sz[v];
            get_gravity(v,u);solve(root);
        }
        return ;
    }

    void start(){
        root=0;f[root]=inf;sum=n;
        get_gravity(1,0);
        get_gravity(root,0);
        solve(root);
        return ;
    }

    void init(){
        tot=0;
        memset(head,0,sizeof(int)*(n+2));
        memset(vis,0,sizeof(int)*(n+2));
        return ;
    }
}T;
int main(){
    while(1){
        n=read();m=read();
        if(n==0 && m==0)break;
        for(int i=1;i<n;i++){
            int x=read(),y=read(),z=read();
            T.add(x,y,z);T.add(y,x,z);
        }
        T.start();
        printf("%d\n",ans);
        ans=0;
        T.init();
    }
    return 0;
}

3.智乃的树分治(模板)
获取子树dis之外还要记录节点编号
除此之外就是双指针和容斥

点击查看代码
#include<functional>
#include<algorithm>
#include<iostream>
#include<cstdlib>
#include<cstring>
#include<complex>
#include<string>
#include<cstdio>
#include<vector>
#include<cmath>
#include<queue>
#include<deque>
#include<stack>
#include<map>
#include<set>
#define ll long long 
#define pa pair<int,int>
#define mp make_pair
#define pb push_back
#define fi first
#define se second
using namespace std;
const int maxn=1e6+101;
const int MOD=998244353;
const int inf=2147483647;
const double eps=1e-12;

ll read(){
    ll x=0,f=1;char ch=getchar();
    for(;!isdigit(ch);ch=getchar())if(ch=='-')f=-1;
    for(;isdigit(ch);ch=getchar())x=x*10+ch-'0';
    return x*f;
}
int ans[maxn];
int n,m,k;
struct Tree_Div{
    int tot,head[maxn],nx[maxn],to[maxn];
    ll w[maxn];   //边权
    void add(int x,int y,int z){
        to[++tot]=y;nx[tot]=head[x];head[x]=tot;w[tot]=z;
    }

    bool vis[maxn];             //是否是重心
    int root,sz[maxn],f[maxn];
    int sum;                    //当前子树大小
    void get_gravity(int u,int fa){
        sz[u]=1;f[u]=0;
        for(int i=head[u];i;i=nx[i]){
            int v=to[i];if(v==fa || vis[v])continue;
            get_gravity(v,u);f[u]=max(f[u],sz[v]);sz[u]+=sz[v];
        }
        f[u]=max(f[u],sum-sz[u]);
        if(f[u]<f[root])root=u;
        return ;
    }

//以下需要根据题目调整
    vector<pa>st;
    ll dis[maxn];
    void get_data(int u,int fa){        //得到子树信息
        st.pb(mp(dis[u],u));
        for(int i=head[u];i;i=nx[i]){
            int v=to[i];if(v==fa || vis[v])continue;
            dis[v]=dis[u]+w[i];
            get_data(v,u);
        }
        return ;
    }

    void get_ans(int u,int k){  
        int l=1,r=st.size();
        sort(st.begin(),st.end(),[](pa i,pa j){
            return i.fi<j.fi;
        });
        while(r>=l){
            auto i=st[l-1],j=st[r-1];
            if(i.fi+j.fi>m){
                ans[j.se]+=(l-1)*k;
                r--;
            }
            else {
                ans[i.se]+=(r-1)*k;
                l++;
            }
        }
        return ;
    }

    void calc(int u,int len,int k){                   //处理经过重心的答案
        dis[u]=len;st.clear();
        get_data(u,0);
        get_ans(u,k);   
        return ;
    }

    void solve(int u){              //分治过程
        vis[u]=1;calc(u,0,1);
        for(int i=head[u];i;i=nx[i]){
            int v=to[i];if(vis[v])continue;
            calc(v,w[i],-1);
            root=0;f[root]=inf;sum=sz[v];
            get_gravity(v,u);solve(root);
        }
        return ;
    }

    void start(){
        root=0;f[root]=inf;sum=n;
        get_gravity(1,0);
        get_gravity(root,0);
        solve(root);
        return ;
    }

    void init(){
        tot=0;
        memset(head,0,sizeof(int)*(n+2));
        memset(vis,0,sizeof(int)*(n+2));
        return ;
    }
}T;
int main(){
    n=read();m=read();
    for(int i=1;i<n;i++){
        int x=read(),y=read(),z=1;
        T.add(x,y,z);T.add(y,x,z);
    }
    T.start();
    for(int i=1;i<=n;i++)printf("%d ",ans[i]+1);    //加上dis(i,i)=0的情况
    return 0;
}

/*
7 2
1 2
2 3
2 4
2 5
5 6
5 7

5 7 5 5 7 4 4 
*/
posted @ 2022-07-20 13:49  I_N_V  阅读(38)  评论(0编辑  收藏  举报