点分治学习笔记
参考蓝书发篇学习笔记。。。
一.算法梗概:
点分治是一种用于在一棵树上,无对路劲进行修改的操作,对某些具有限定条件的路径进行静态统计的算法。
点分治一般用来处理无根树,我们可以随意认定根节点。
二.实现过程:
我们拿一道例题来说一下:
P4178 Tree
我们认定根节点为 \(root\),那么对于 \(root\) 而言,树上的路径有两种:
1.经过 \(root\) 的路径;
2.不经过 \(root\) 但包含在 \(root\) 的某棵子树内。
对于路径种类1,我们可以从 \(root\) 点出发,对整棵树进行 \(\text{dfs}\),求出点 \(i\) 到 \(root\) 的距离 \(dis_i\),同时可以求出 \(b_i\),表示 \(i\) 属于 \(root\) 的哪一棵子树。特别地,\(b_{root}=root\)。
代码如下:
点击查看代码
inline void getdis(int x,int fa,int d,int from)
{
a[++now]=x,b[x]=from,dis[x]=d;
for(int i=head[x];i;i=e[i].nxt)
{
int y=e[i].to,z=e[i].len;
if(y==fa || vis[y]) continue;
getdis(y,x,d+z,from);
}
return;
}
而我们要统计的,就是满足如下所有条件的点对 \((x,y)\) 的数量:(1)\(b_x\ne b_y\);(2)\(dis_x+dis_y\leqslant k\)
对于路径种类2,我们可以分治一下,将 \(root\) 的每棵子树递归处理。
那么我们最常见的 \(calc\) 函数的写法,就是指针扫描数组的方法:
将树上每个节点放到一个数组 \(a\) 里去,然后按照节点的 \(dis\) 值排序。显然,\(l\) 在向右扫描的过程中,恰好使得 \(d_{a_l}+d_{a_r}\leqslant k\) 的 \(r\) 是从右向左单调递减。那么我们用 \(cnt_s\) 来统计 \(l+1\sim r\) 之间满足 \(b_{a_i}\) 的 \(i\) 的个数,那么,当某条路径的某一端为 \(a_l\) 时,另一端的合法的个数就为 \(r-l-cnt_{b_{a_l}}\)。
代码如下:
点击查看代码
inline int calc(int x)
{
tot=0,a[++tot]=x,b[x]=x,dis[x]=0,now[b[x]]=1;
for(int i=head[x];i;i=e[i].nxt)
{
int y=e[i].to,z=e[i].len;
if(vis[y]) continue;
getdis(y,x,z,y);
}
sort(a+1,a+tot+1,cmp);
int l=1,r=tot,res=0;
while(l<=r)//一定要注意,一定不要写成l<r,因为这样会导致l=r时直接退出,但是有一个没有减掉
{
while(l<r && dis[a[l]]+dis[a[r]]<=k)
{
res+=r-l+1-now[b[a[l]]];
now[b[a[l]]]--;l++;
}
now[b[a[r]]]--;r--;
}
return res;
}
若递归的深度为 \(dep\),那么算法的时间复杂度就为 \(\mathcal{O}(dep· n\log n)\)
但是我们想一种情况,若树的形态为一条链,那么最坏情况下,每次根都选到链的端点,那么递归深度就需要 \(n\) 层,算法时间复杂度就退化成 \(\mathcal{O}(n^2\log n)\)。所以,我们要对根的选择进行一个优化,每次都找到树的重心作为根节点。
代码如下:
点击查看代码
inline void getrt(int x,int fa,int tot)
{
siz[x]=1,hson[x]=0;
for(int i=head[x];i;i=e[i].nxt)
{
int y=e[i].to;
if(y==fa || vis[y]) continue;
getrt(y,x,tot);
siz[x]+=siz[y];
hson[x]=max(hson[x],siz[y]);
}
hson[x]=max(hson[x],tot-siz[x]);
if(!root || hson[x]<hson[root]) root=x;
return;
}
解释:因为此时 \(root\) 的每棵子树的大小都不会超过整棵树的一半,那么就限制了递归层数最多为 \(\mathcal{O}(\log n)\),那么现在算法的时间复杂度就变成了 \(\mathcal{O}(n\log^2n)\)。
完整代码:
点击查看代码
#include<bits/stdc++.h>
#define int long long
using namespace std;
const int MAXN=4e4+5;
inline int read()
{
int x=0,f=1;char ch=getchar();
while(!isdigit(ch))
{
if(ch=='-') f=-1;
ch=getchar();
}
while(isdigit(ch))
{
x=(x<<1)+(x<<3)+(ch^48);
ch=getchar();
}
return x*f;
}
int n,k;
struct edge
{
int to,nxt,len;
}e[MAXN<<1];
int head[MAXN],cnt;
inline void add(int x,int y,int z)
{
e[++cnt].to=y;
e[cnt].len=z;
e[cnt].nxt=head[x];
head[x]=cnt;
return;
}
int root,tot;
int siz[MAXN],hson[MAXN];
int dis[MAXN],a[MAXN],b[MAXN];
bool vis[MAXN];
int now[MAXN];
inline void getrt(int x,int fa,int tot)
{
siz[x]=1,hson[x]=0;
for(int i=head[x];i;i=e[i].nxt)
{
int y=e[i].to;
if(y==fa || vis[y]) continue;
getrt(y,x,tot);
siz[x]+=siz[y];
hson[x]=max(hson[x],siz[y]);
}
hson[x]=max(hson[x],tot-siz[x]);
if(!root || hson[x]<hson[root]) root=x;
return;
}
inline void getdis(int x,int fa,int d,int from)
{
a[++tot]=x,b[x]=from,dis[x]=d,now[b[x]]++;
for(int i=head[x];i;i=e[i].nxt)
{
int y=e[i].to,z=e[i].len;
if(y==fa || vis[y]) continue;
getdis(y,x,d+z,from);
}
return;
}
inline bool cmp(int a,int b) {return dis[a]<dis[b];}
int ans;
inline int calc(int x)
{
tot=0,a[++tot]=x,b[x]=x,dis[x]=0,now[b[x]]=1;
for(int i=head[x];i;i=e[i].nxt)
{
int y=e[i].to,z=e[i].len;
if(vis[y]) continue;
getdis(y,x,z,y);
}
sort(a+1,a+tot+1,cmp);
int l=1,r=tot,res=0;
while(l<=r)
{
while(l<r && dis[a[l]]+dis[a[r]]<=k)
{
res+=r-l+1-now[b[a[l]]];
now[b[a[l]]]--;l++;
}
now[b[a[r]]]--;r--;
}
return res;
}
inline void solve(int x)
{
vis[x]=true;ans+=calc(x);
for(int i=head[x];i;i=e[i].nxt)
{
int y=e[i].to;
if(vis[y]) continue;
root=0;
getrt(y,0,siz[y]);
solve(root);
}
return;
}
signed main()
{
n=read();
for(int i=1;i<n;i++)
{
int x=read(),y=read(),z=read();
add(x,y,z),add(y,x,z);
}
k=read();
hson[0]=n,getrt(1,0,n);solve(root);
printf("%lld\n",ans);
return 0;
}
典型例题
例一 P3806 【模板】点分治1
纯纯的板子,只是从统计数量变成了是否存在的问题。
点击查看代码
#include<bits/stdc++.h>
using namespace std;
const int MAXN=1e4+5;
inline int read()
{
int x=0,f=1;char ch=getchar();
while(!isdigit(ch))
{
if(ch=='-') f=-1;
ch=getchar();
}
while(isdigit(ch))
{
x=(x<<1)+(x<<3)+(ch^48);
ch=getchar();
}
return x*f;
}
struct edge
{
int to,nxt,len;
}e[MAXN<<1];
int head[MAXN],cnt;
inline void add(int x,int y,int z)
{
e[++cnt].to=y;
e[cnt].len=z;
e[cnt].nxt=head[x];
head[x]=cnt;
return;
}
int n,m,root;
int siz[MAXN],hson[MAXN];
bool vis[MAXN],flag[MAXN];
int a[MAXN],b[MAXN],dis[MAXN],tot;
inline void getrt(int x,int fa,int tot)
{
siz[x]=1,hson[x]=0;
for(int i=head[x];i;i=e[i].nxt)
{
int y=e[i].to;
if(y==fa || vis[y]) continue;
getrt(y,x,tot);
siz[x]+=siz[y];
hson[x]=max(siz[y],hson[x]);
}
hson[x]=max(hson[x],tot-siz[x]);
if(!root || hson[x]<hson[root]) root=x;
return;
}
inline void getdis(int x,int fa,int d,int from)
{
a[++tot]=x,b[x]=from,dis[x]=d;
for(int i=head[x];i;i=e[i].nxt)
{
int y=e[i].to,z=e[i].len;
if(y==fa || vis[y]) continue;
getdis(y,x,d+z,from);
}
return;
}
inline bool cmp(int a,int b)
{
return dis[a]<dis[b];
}
int ask[MAXN];
inline void calc(int x)
{
tot=0,a[++tot]=x,b[x]=x,dis[x]=0;
for(int i=head[x];i;i=e[i].nxt)
{
int y=e[i].to,z=e[i].len;
if(vis[y]) continue;
getdis(y,x,z,y);
}
sort(a+1,a+tot+1,cmp);
for(int i=1;i<=m;i++)
{
int l=1,r=tot;
if(flag[i]) continue;
while(l<r)
{
if(dis[a[l]]+dis[a[r]]>ask[i]) r--;
else if(dis[a[l]]+dis[a[r]]<ask[i]) l++;
else if(b[a[l]]==b[a[r]])
{
if(dis[a[r]]==dis[a[r-1]]) r--;
else l++;
}
else {flag[i]=true;break;}
}
}
}
inline void solve(int x)
{
vis[x]=true;calc(x);
for(int i=head[x];i;i=e[i].nxt)
{
int y=e[i].to;
if(vis[y]) continue;
root=0;
getrt(y,0,siz[y]);
solve(root);
}
return;
}
int main()
{
n=read(),m=read();
for(int i=1;i<=n-1;i++)
{
int x=read(),y=read(),z=read();
add(x,y,z),add(y,x,z);
}
for(int i=1;i<=m;i++)
{
ask[i]=read();
if(!ask[i]) flag[i]=true;
}
hson[0]=n;getrt(1,0,n);solve(root);
for(int i=1;i<=m;i++)
{
if(flag[i]) printf("AYE\n");
else printf("NAY\n");
}
return 0;
}