点分治

 点分治

点分治是树上分治的一种(树上分治还有边分治),常用于解决和树上路径有关的问题。

因为树上路径有一条性质:树上的任何路径,要么经过根节点$rt$要么就全部在$rt$的一颗子树上。

正确性显而易见:树上两点的路径是唯一的,如果两点在$rt$的同一子树上,则路径完全在一颗子树上,如果在$rt$的不同子树,则必然经过$rt$。

有了这条性质,我们就可以对树上的路径进行分治:先将经过$rt$的路径处理完,此时$rt$的各个子树上的路径就互不影响了,故可以递归分治。

但怎么使得分治均匀呢?如果随意选择根节点,则如果树退化成链,则递归层数为$N$,且每次操作的节点数目都会非常多。

所以此时,我们选择树的重心,这样可以使每一次分治后剩下的最大子树的大小下降最快(重心的最大子树最小),每一次分治我们都找重心,操作本身复杂度为$O(NlogN)$,可以使分治底层执行次数降到$O(NlogN)$(主定理)。通常情况下,节点的子树超过$2$棵,则复杂度往往会低于$O(NlogN)$

void getrt(int x,int fa){
    siz[x]=1;
    maxs[x]=0;
    for(int i=head[x];i;i=nxt[i]){
        if(fa==to[i]&&vis[to[i]])
            continue;
        getrt(to[i],x);
        siz[x]+=siz[to[i]];
        maxs[x]=max(maxs[x],siz[to[i]]);
    }
    maxs[x]=max(maxs[x],sum-siz[x]);
    if(maxs[x]<maxs[rt])
        rt=x;
}
求树的重心代码(顺便更新根节点)

这有什么用呢?

举个例子,如果我们要统计一棵树内各个长度的路径的个数,就可以用点分治来做。

首先,我们将树的重心设为根,求出树上各点到根的距离($O(N)$),并统计每个长度的路径的数量($O(N^2)$),然后枚举每一个子树,递归执行同样的操作,并在同一个数组上统计数量

算法的核心在于分治:

就着代码讲:

void Divid(int x)
{
    ans+=solve(x,0);  ①统计经过RT的所有路径(不一定合法,下文会重点讲)
    vis[x] = 1;
    for (int i = head[x];i;i = edges[i].net)  枚举所有子树
    {
        edge v = edges[i];
        if(vis[v.to]) continue;  防止遍历到上层
        ans-=solve(v.to,edges[i].cost);  减掉①中统计的不合法路径*
        S = size[v.to]; root = 0;
        find(v.to,x);  找到新根
        Divid(root);  按子树分治
    }
}

 

*上面代码中提到求出的“经过RT的路径”不一定合法,要减掉一部分,为什么呢?在统计经过rt的路径时,我们将所有点都两两匹配了,此时同一棵子树中的点当然不能配对。

如图,我们将$RT \rightarrow B$和$RT \rightarrow C$的路径合并了,其中$RT \rightarrow A$部分的路径重复了

怎么去掉重复的部分呢?见代码,我们【将“经过A点的路径”(不一定合法)加上$RT \rightarrow A$的长度】这一部分产生的贡献去掉

如图,过A点的路径为$B \rightarrow A \rightarrow C$,统计结果为$A \rightarrow B+A \rightarrow C+2*(RT \rightarrow A)$(因为所有长度都加了$RT \rightarrow A$)

这样,我们就求出了所有经过RT的路径的贡献

 

接下来,就是逐步细化实现了。对于大部分点分治的题目,上面分治的代码都是差不多的,题目的差异在于solve函数。

下面,我以洛谷P3806 【模板】点分治1为例介绍一下实现的过程 题解

此题的题意很简单,就是询问树上长度为k的路径是否存在。看到询问路径,就很容易想到点分治,故我们可以套上面的模板。

这里介绍一下solve函数的实现

首先,我们要求出从当前根节点到子数中所有点的距离,然后将这些距离组合配对,将所有的和统计到答案中即可

void solve(int x,int len/*start dis*/,int w/*weight*/){/*O(N^2)*/
    tp=0;
    dis[x]=len;
    get_dis(x,0,len);
    for(int i=1;i<=tp;i++)
        for(int j=1;j<=tp;j++)
            if(i!=j)
                ans[st[i]+st[j]]+=w;
}

 

这是求距离的代码

void get_dis(int x,int fa,int len){
    if(len<=1e7)
        st[++tp]=len;
    for(int i=head[x];i;i=nxt[i]){
        if(to[i]==fa||vis[to[i]])
            continue;
        dis[to[i]]=len+val[i];
        get_dis(to[i],x,len+val[i]);
    }
}

 

于是,我们能得到以下代码,可以AC(是因为数据水)

 1 #include<bits/stdc++.h>
 2 using namespace std;
 3 typedef long long LL;
 4 const int MAXK=2e7,MAXN=1e4+7,MAXM=2e4+7;
 5 inline void Max(int &x,int y){
 6     x=x>y?x:y;
 7 }
 8 int sz,head[MAXN],to[MAXM],nxt[MAXM],val[MAXM];
 9 inline void add(int x,int y,int z){
10     nxt[++sz]=head[x]; head[x]=sz; to[sz]=y; val[sz]=z;
11     nxt[++sz]=head[y]; head[y]=sz; to[sz]=x; val[sz]=z;
12 }
13 int rt,siz[MAXN],maxson[MAXN],vis[MAXN],S;
14 void find(int x/*cur vertex*/,int fa/*father*/){/*find root*/
15     siz[x]=1;
16     maxson[x]=0;
17     for(int i=head[x];i;i=nxt[i]){
18         if(to[i]==fa||vis[to[i]])
19             continue;
20         find(to[i],x);
21         siz[x]+=siz[to[i]];
22         Max(maxson[x],siz[to[i]]);
23     }
24     Max(maxson[x],S-siz[x]);
25     if(maxson[x]<maxson[rt])
26         rt=x;
27 }
28 int dis[MAXN],st[MAXN],tp;
29 void get_dis(int x,int fa,int len){
30     if(len<=1e7)
31         st[++tp]=len;
32     for(int i=head[x];i;i=nxt[i]){
33         if(to[i]==fa||vis[to[i]])
34             continue;
35         dis[to[i]]=len+val[i];
36         get_dis(to[i],x,len+val[i]);
37     }
38 }
39 int ans[MAXK];
40 void solve(int x,int len/*start dis*/,int w/*weight*/){/*O(N^2)*/
41     tp=0;
42     dis[x]=len;
43     get_dis(x,0,len);
44     for(int i=1;i<=tp;i++)
45         for(int j=1;j<=tp;j++)
46             if(i!=j)
47                 ans[st[i]+st[j]]+=w;
48 }
49 int N,Q,K;
50 void divide(int x){
51     solve(x,0,1);
52     vis[x]=1;
53     for(int i=head[x];i;i=nxt[i]){
54         if(vis[to[i]])
55             continue;
56         solve(to[i],val[i],-1);
57         S=siz[x];
58         rt=0;
59         maxson[0]=N;
60         find(to[i],x);
61         divide(rt);
62     }
63 }
64 int main(){
65     scanf("%d%d",&N,&Q);
66     for(int i=1;i<N;i++){
67         int ii,jj,kk;
68         scanf("%d%d%d",&ii,&jj,&kk);
69         add(ii,jj,kk);
70     }
71     S=N;
72     maxson[0]=N;
73     rt=0;
74     find(1,0);
75     divide(rt);
76     while(Q--){
77         scanf("%d",&K);
78         puts(ans[K]?"AYE":"NAY");
79     }
80     return 0;
81 }
View Code

 

分析代码可以发现,实际上代码的时间复杂度为$\Theta(N^2 log N)$,在较强的数据中是会TLE的,于是我们要优化

我们发现我们对于所有可能的询问,都统计了答案,这其实是一种冗余。题目中的m非常小,我们其实可以根据询问来统计答案,效率会提高两个数量级

于是,我们很容易想到将询问离线,然后每次在表内统计一份结果,并且枚举所有的询问,在表内查询之前是否得到过$答案-当前结果$的值

新的solve函数

void solve(int x,int len/*start dis*/,int w/*weight*/){/*O(N*M)*/
    ++timeclock;
    tp=0;
    dis[x]=len;
    get_dis(x,0,len);
    for(int i=1;i<=tp;i++)
        for(int j=1;j<=Q;j++){
            int ii=qry[j]-st[i];
            if(ii<0||date[ii]!=timeclock||(b[ii]==1&&ii==st[i]))
                continue;
            ans[j]+=w;
        }
}

 

最后的代码也就很简单了,用时为原来的几十分之一

 1 #include<bits/stdc++.h>
 2 using namespace std;
 3 typedef long long LL;
 4 const int MAXK=2e7,MAXN=1e4+7,MAXM=2e4+7,MAXQ=1e2+7;
 5 inline void Max(int &x,int y){
 6     x=x>y?x:y;
 7 }
 8 int sz,head[MAXN],to[MAXM],nxt[MAXM],val[MAXM];
 9 inline void add(int x,int y,int z){
10     nxt[++sz]=head[x]; head[x]=sz; to[sz]=y; val[sz]=z;
11     nxt[++sz]=head[y]; head[y]=sz; to[sz]=x; val[sz]=z;
12 }
13 int rt,siz[MAXN],maxson[MAXN],vis[MAXN],S;
14 void find(int x/*cur vertex*/,int fa/*father*/){/*find root*/
15     siz[x]=1;
16     maxson[x]=0;
17     for(int i=head[x];i;i=nxt[i]){
18         if(to[i]==fa||vis[to[i]])
19             continue;
20         find(to[i],x);
21         siz[x]+=siz[to[i]];
22         Max(maxson[x],siz[to[i]]);
23     }
24     Max(maxson[x],S-siz[x]);
25     if(maxson[x]<maxson[rt])
26         rt=x;
27 }
28 int dis[MAXN],st[MAXN],tp;
29 int qry[MAXQ],ans[MAXQ],date[MAXK],b[MAXK],timeclock;
30 int N,Q,K;
31 void get_dis(int x,int fa,int len){
32     if(len<=1e7){
33         st[++tp]=len;
34         if(date[len]==timeclock)
35             b[len]++;
36         else{
37             b[len]=1;
38             date[len]=timeclock;
39         }
40     }
41     for(int i=head[x];i;i=nxt[i]){
42         if(to[i]==fa||vis[to[i]])
43             continue;
44         dis[to[i]]=len+val[i];
45         get_dis(to[i],x,len+val[i]);
46     }
47 }
48 void solve(int x,int len/*start dis*/,int w/*weight*/){/*O(N*M)*/
49     ++timeclock;
50     tp=0;
51     dis[x]=len;
52     get_dis(x,0,len);
53     for(int i=1;i<=tp;i++)
54         for(int j=1;j<=Q;j++){
55             int ii=qry[j]-st[i];
56             if(ii<0||date[ii]!=timeclock||(b[ii]==1&&ii==st[i]))
57                 continue;
58             ans[j]+=w;
59         }
60 }
61 void divide(int x){
62     solve(x,0,1);
63     vis[x]=1;
64     for(int i=head[x];i;i=nxt[i]){
65         if(vis[to[i]])
66             continue;
67         solve(to[i],val[i],-1);
68         S=siz[x];
69         rt=0;
70         maxson[0]=N;
71         find(to[i],x);
72         divide(rt);
73     }
74 }
75 int main(){
76     scanf("%d%d",&N,&Q);
77     for(int i=1;i<N;i++){
78         int ii,jj,kk;
79         scanf("%d%d%d",&ii,&jj,&kk);
80         add(ii,jj,kk);
81     }
82     for(int i=1;i<=Q;i++){
83         scanf("%d",qry+i);
84     }
85     S=N;
86     maxson[0]=N;
87     rt=0;
88     find(1,0);
89     divide(rt);
90     for(int i=1;i<=Q;i++)
91         puts(ans[i]?"AYE":"NAY");
92     return 0;
93 }
View Code

 

posted @ 2019-06-09 21:46  guoshaoyang  阅读(677)  评论(0编辑  收藏  举报