点分治浅谈
树上问题入门——点分治
1 何为点分治
树上有三种分治,点分治,边分支,链分治,第三者就是我们通常所说树链剖分。
点分治是一种解决树上统计问题的常用方法,本质思想是选择一点作为分治中心,将原问题划分成几个相同的子树上的问题进行递归解决。
题目中常见给出的是无根树,即所需维护的信息与谁是根无关。
常见的是统计有关树上路径的问题。
2 讲解
分治算法都偏向于应用。我们来一到例题。
我们先把所有询问离线下来。
很显然,我们取一点为根。那么路径可以被分为两种,一种是经过根节点的路径,另一种是不经过根节点的路径。
那么我们就先处理经过根节点的路径,剩下的递归去处理,这就把原问题划分成了若干个子问题。
那么我们选哪个点作为根节点呢?如果是随便一个点的话,我们可能会取到一条链上的端点,这样的递归层数是 \(O(n)\) 的。但如果我们每一次取树的重心的话,递归层数就变成了 \(O(\log n)\) ,这给我们留下来了处理每一层信息的时间。
一般来说,如果处理每一层信息的时间复杂度是 \(O(m)\) 的话,那么整个点分治的时间复杂度为 \(O(m\log n)\) 。
2.1 求重心
由于是建的双向边,所以注意判断父节点。
inline void get_zhongxin(int k,int fa){
size[k]=1;max_size[k]=-INF;
for(int x=head[k];x;x=li[x].next){
int to=li[x].to;
if(to==fa||vis[to]) continue;
get_zhongxin(to,k);size[k]+=size[to];
max_size[k]=Max(max_size[k],size[to]);
}
max_size[k]=Max(max_size[k],tree_sum-size[k]);
if(max_size[k]<max_size[root]) root=k;
}
注意,这里比平常的判断还多了一个条件:vis[to]
,这是为什么呢?这里 vis
表示的是这个节点是否被当做某一颗子树的重心。举个例子:
我们会先取 \(1\) 为重心,然后左子树上会取 \(4\) 为重心,然后我们会取 \(2,6\) 为重心。注意到 \(2\) 为 \(4\) 的父亲,所以如果不加上 vis
会有问题。
当我们每一次取重心时,整个树的结构实际上被改变了,这棵树被称作点分树。
2.2 题目讲解
那么剩下的就是关系到这个题的特殊性而非一般的点分治算法了。
这里我们在求完重心之后,算一遍其子树到重心的距离,子树之间是没有顺序的,所以我们在求完一颗子树后,我们检查一下每一个询问,看一看是否存在有其它一颗子树上的点到重心的路径与当前路径加起来正好等于询问。
这段代码如下:
l=r=0;q[++r]=0;
dist_appeared[0]=1;vis[k]=1;
for(int x=head[k];x;x=li[x].next){
int to=li[x].to;dist_tail=0;
if(to==fa||vis[to]) continue;
dist[to]=li[x].w;get_dist(to,-1);
for(int i=1;i<=dist_tail;i++){
for(int j=1;j<=m;j++){
if(ques[j]>=dq[i]) ans[j]|=dist_appeared[ques[j]-dq[i]];
}
}
for(int i=1;i<=dist_tail;i++){
if(dq[i]>=10000010) continue;
q[++r]=dq[i];dist_appeared[dq[i]]=1;
}
}
while(l<r) dist_appeared[q[l+1]]=0,l++;
注意这里消除影响的时候我们需要开一个队列来保证复杂度,不能直接用 memset
。
这里 getlist
的代码是:
inline void get_dist(int k,int fa){
dq[++dist_tail]=dist[k];
for(int x=head[k];x;x=li[x].next){
int to=li[x].to;
if(to==fa||vis[to]) continue;//
dist[to]=dist[k]+li[x].w;
get_dist(to,k);
}
}
所以不要忘了提前预处理。
做完后我们直接求一下子树的重心,点分治即可。
总代码:
#include<bits/stdc++.h>
#define dd double
#define ld long double
#define ll long long
#define uint unsigned int
#define ull unsigned long long
#define N 20010
#define M 110
using namespace std;
const int INF=0x3f3f3f3f;
template<typename T> inline T Max(T a,T b){
return a>b?a:b;
}
template<typename T> inline void read(T &x) {
x=0; int f=1;
char c=getchar();
for(;!isdigit(c);c=getchar()) if(c == '-') f=-f;
for(;isdigit(c);c=getchar()) x=x*10+c-'0';
x*=f;
}
bool dist_appeared[10000010],vis[N],ans[M];
int n,m,ques[M],size[N],max_size[N],tree_sum;
int l,r,q[N],dist[N],dq[N],dist_tail,root;
struct edge{
int to,next,w;
inline void intt(int to_,int ne_,int w_){
to=to_;next=ne_;w=w_;
}
};
edge li[N];
int head[N],tail;
inline void add(int from,int to,int w){
li[++tail].intt(to,head[from],w);
head[from]=tail;
}
inline void get_zhongxin(int k,int fa){
size[k]=1;max_size[k]=-INF;
for(int x=head[k];x;x=li[x].next){
int to=li[x].to;
if(to==fa||vis[to]) continue;
get_zhongxin(to,k);size[k]+=size[to];
max_size[k]=Max(max_size[k],size[to]);
}
max_size[k]=Max(max_size[k],tree_sum-size[k]);
if(max_size[k]<max_size[root]) root=k;
}
inline void get_dist(int k,int fa){
dq[++dist_tail]=dist[k];
for(int x=head[k];x;x=li[x].next){
int to=li[x].to;
if(to==fa||vis[to]) continue;//
dist[to]=dist[k]+li[x].w;
get_dist(to,k);
}
}
inline void dfz(int k,int fa){
l=r=0;q[++r]=0;
dist_appeared[0]=1;vis[k]=1;
for(int x=head[k];x;x=li[x].next){
int to=li[x].to;dist_tail=0;
if(to==fa||vis[to]) continue;
dist[to]=li[x].w;get_dist(to,-1);
for(int i=1;i<=dist_tail;i++){
for(int j=1;j<=m;j++){
if(ques[j]>=dq[i]) ans[j]|=dist_appeared[ques[j]-dq[i]];
}
}
for(int i=1;i<=dist_tail;i++){
if(dq[i]>=10000010) continue;
q[++r]=dq[i];dist_appeared[dq[i]]=1;
}
}
while(l<r) dist_appeared[q[l+1]]=0,l++;
for(int x=head[k];x;x=li[x].next){
int to=li[x].to;
if(to==fa||vis[to]) continue;
root=0;max_size[0]=INF;tree_sum=size[to];//
get_zhongxin(to,-1);get_zhongxin(root,-1);
dfz(root,k);
}
}
int main(){
read(n);read(m);
for(int i=1;i<=n-1;i++){
int from,to,w;read(from);read(to);read(w);
add(from,to,w);add(to,from,w);
}
for(int i=1;i<=m;i++) read(ques[i]);
tree_sum=n;root=0;max_size[0]=INF;get_zhongxin(1,-1);get_zhongxin(root,-1);
dfz(root,-1);
for(int i=1;i<=m;i++) if(ans[i]) printf("AYE\n");else printf("NAY\n");
return 0;
}
2.3 注意事项
实现起来有很多细节,所以一定要自己尝试,并多刷例题。
在找重心的时候注意我们要找两边重心,第二遍是因为要保证 \(size\) 的正确性, 不过据说也不用保证,证明在这里:链接
亲测有时快,有时慢,但都没有 TLE ,就很玄学。
3 例题
3.1 P4178 Tree
点分治本身不难想到,我们来想一下如何统计一层上的信息。
我们把所有子树中的节点到根的距离算出来,同时记录一下每一个节点属于哪一个儿子,特别的,根节点属于自己。然后我们按照距离从小到大排序。
考虑左右指针 \(l,r\) ,如果有 \(d_r+d_l>k\) 我们就向左移动右指针,否则,我们累加 \(r-l-cnt_{l}\) ,其中 \(cnt_l\) 表示在 \((l,r]\) 中与 \(l\) 子树相同的树的个数。
不难证明上述算法的正确性:其中 \(r-l\) 中所有与 \(l\) 不在同一颗子树上的节点是需要算进答案里的。显然上面算法是正确的。不过排序后利用单调性很难想到,双指针的题做的还是少。其实双指针就是利用的单调性来累加答案。
需要注意的是,\(cnt\) 数组在双指针扫的时候是一起对应修改的,但是扫完后由于一些不知道的原因并不是全都是 \(0\) ,所以我们需要给数组清零。
#include<bits/stdc++.h>
#define dd double
#define ld long double
#define ll long long
#define uint unsigned int
#define ull unsigned long long
#define N 40100
#define M number
using namespace std;
const int INF=0x3f3f3f3f;
inline int Max(int a,int b){
return a>b?a:b;
}
template<typename T> inline void read(T &x) {
x=0; int f=1;
char c=getchar();
for(;!isdigit(c);c=getchar()) if(c == '-') f=-f;
for(;isdigit(c);c=getchar()) x=x*10+c-'0';
x*=f;
}
bool vis[N];
int root,size[N],maxsize[N],treesum,ans;
int dist[N],ftail,cnt[N],K;
struct node{
int belong,val;
inline node(){}
inline node(int belong,int val) : belong(belong),val(val) {}
inline bool operator < (const node &b) const {
return val<b.val;
}
};
node f[N];
struct edge{
int to,next,w;
inline void intt(int to_,int ne_,int w_){
to=to_;next=ne_;w=w_;
}
};
edge li[N<<1];
int head[N],tail;
inline void add(int from,int to,int w){
li[++tail].intt(to,head[from],w);
head[from]=tail;
}
inline void getzhongxin(int k,int fa){
size[k]=1;maxsize[k]=-INF;
for(int x=head[k];x;x=li[x].next){
int to=li[x].to;
if(to==fa||vis[to]) continue;
getzhongxin(to,k);
size[k]+=size[to];maxsize[k]=Max(maxsize[k],size[to]);
}
maxsize[k]=Max(maxsize[k],treesum-size[k]);
if(maxsize[k]<maxsize[root]) root=k;
}
inline void solvezhongxin(int k){
root=0;maxsize[0]=INF;
getzhongxin(k,-1);getzhongxin(root,-1);
// printf("nowroot:%d\n",root);
}
inline void getdist(int k,int fa,int belong){
f[++ftail]=node(belong,dist[k]);cnt[belong]++;
for(int x=head[k];x;x=li[x].next){
int to=li[x].to;
if(to==fa||vis[to]) continue;
dist[to]=dist[k]+li[x].w;
getdist(to,k,belong);
}
}
inline void compeat(int l,int r){
while(l<r){
while(f[r].val+f[l].val>K) cnt[f[r].belong]--,r--;
ans+=r-l-cnt[f[l].belong];
l++;cnt[f[l].belong]--;
}
// printf("nowans:%d\n",ans);
}
inline void dfz(int k,int fa){
vis[k]=1;ftail=0;f[++ftail]=node(k,0);dist[k]=0;
for(int x=head[k];x;x=li[x].next){
int to=li[x].to;
if(to==fa||vis[to]) continue;
dist[to]=li[x].w;getdist(to,-1,to);
}
sort(f+1,f+ftail+1);int l=1,r=ftail;compeat(l,r);
// for(int i=1;i<=ftail;i++){
// printf("f:val:%d\n",f[i].val);
// printf("f:belong%d\n",f[i].belong);
// }
for(int x=head[k];x;x=li[x].next){
int to=li[x].to;
if(to==fa||vis[to]) continue;
cnt[to]=0;
}
for(int x=head[k];x;x=li[x].next){
int to=li[x].to;
if(to==fa||vis[to]) continue;
treesum=size[to];solvezhongxin(to);
dfz(root,k);
}
}
int n,k;
int main(){
// freopen("my.out","w",stdout);
read(n);
for(int i=1;i<=n-1;i++){
int from,to,w;read(from);read(to);read(w);
add(from,to,w);add(to,from,w);
}
read(K);
treesum=n;solvezhongxin(1);dfz(root,-1);
printf("%d",ans);
return 0;
}
时间复杂度 \(O(n\log n)\)
3.2 P2634 [国家集训队]聪聪可可
不难想到点分治,也不难想到如何统计一层上的信息。需要注意的是这个是有序数对。在下面的程序中,我们统计的是无序数对且并没有考虑兄弟二人选的是同一点的情况,这样比较好统计一点,我们这些信息最后在累加即可。
#include<bits/stdc++.h>
#define dd double
#define ld long double
#define ll long long
#define uint unsigned int
#define ull unsigned long long
#define N 40100
#define M number
using namespace std;
const int INF=0x3f3f3f3f;
inline int Max(int a,int b){
return a>b?a:b;
}
template<typename T> inline void read(T &x) {
x=0; int f=1;
char c=getchar();
for(;!isdigit(c);c=getchar()) if(c == '-') f=-f;
for(;isdigit(c);c=getchar()) x=x*10+c-'0';
x*=f;
}
bool vis[N];
int root,size[N],maxsize[N],treesum,ans;
int dist[N],ftail,cnt[N],K;
struct node{
int belong,val;
inline node(){}
inline node(int belong,int val) : belong(belong),val(val) {}
inline bool operator < (const node &b) const {
return val<b.val;
}
};
node f[N];
struct edge{
int to,next,w;
inline void intt(int to_,int ne_,int w_){
to=to_;next=ne_;w=w_;
}
};
edge li[N<<1];
int head[N],tail;
inline void add(int from,int to,int w){
li[++tail].intt(to,head[from],w);
head[from]=tail;
}
inline void getzhongxin(int k,int fa){
size[k]=1;maxsize[k]=-INF;
for(int x=head[k];x;x=li[x].next){
int to=li[x].to;
if(to==fa||vis[to]) continue;
getzhongxin(to,k);
size[k]+=size[to];maxsize[k]=Max(maxsize[k],size[to]);
}
maxsize[k]=Max(maxsize[k],treesum-size[k]);
if(maxsize[k]<maxsize[root]) root=k;
}
inline void solvezhongxin(int k){
root=0;maxsize[0]=INF;
getzhongxin(k,-1);getzhongxin(root,-1);
// printf("nowroot:%d\n",root);
}
inline void getdist(int k,int fa,int belong){
f[++ftail]=node(belong,dist[k]);cnt[belong]++;
for(int x=head[k];x;x=li[x].next){
int to=li[x].to;
if(to==fa||vis[to]) continue;
dist[to]=dist[k]+li[x].w;
getdist(to,k,belong);
}
}
inline void compeat(int l,int r){
while(l<r){
while(f[r].val+f[l].val>K) cnt[f[r].belong]--,r--;
ans+=r-l-cnt[f[l].belong];
l++;cnt[f[l].belong]--;
}
// printf("nowans:%d\n",ans);
}
inline void dfz(int k,int fa){
vis[k]=1;ftail=0;f[++ftail]=node(k,0);dist[k]=0;
for(int x=head[k];x;x=li[x].next){
int to=li[x].to;
if(to==fa||vis[to]) continue;
dist[to]=li[x].w;getdist(to,-1,to);
}
sort(f+1,f+ftail+1);int l=1,r=ftail;compeat(l,r);
// for(int i=1;i<=ftail;i++){
// printf("f:val:%d\n",f[i].val);
// printf("f:belong%d\n",f[i].belong);
// }
for(int x=head[k];x;x=li[x].next){
int to=li[x].to;
if(to==fa||vis[to]) continue;
cnt[to]=0;
}
for(int x=head[k];x;x=li[x].next){
int to=li[x].to;
if(to==fa||vis[to]) continue;
treesum=size[to];solvezhongxin(to);
dfz(root,k);
}
}
int n,k;
int main(){
// freopen("my.out","w",stdout);
read(n);
for(int i=1;i<=n-1;i++){
int from,to,w;read(from);read(to);read(w);
add(from,to,w);add(to,from,w);
}
read(K);
treesum=n;solvezhongxin(1);dfz(root,-1);
printf("%d",ans);
return 0;
}