点分治
终于是对点分治理解稍微深入一些了
一般而言,淀粉质点分治可以用于解决树上符合条件的路径的数量统计一类的问题
主要是一种思想,不太好干写,干脆就用一道题来当例子吧
POJ 1741 Tree
个人认为这是一道比较适合理解点分治思想的题目
TLDR
有若干组数据,每组数据输入一棵树和一个值\(k\),求树上距离小于等于\(k\)的点对有多少对
Solution
\(O(n^3)\)和\(O(n^2)\)的思路非常显然,然而我们当然不满足于此
此题运用点分治可以做到复杂度为\(O(n\space log\space n)\)
点分治是怎么做的?
在一棵无根树中,容易发现,对于树上的一个点\(u\),经过\(u\)的路径不外乎有两种情况:
- \(u\)是路径的一个端点
- \(u\)是路径中间的某一个点
如果我们钦定\(u\)为根,要求经过它的路径的长度,对于第一类路径,直接从根往下递归就可以了
而对于第二类路径,则可以将路径看成是由来自根的两棵子树的两条路径拼接在一起
其实就是两条第一类路径拼在一起
要让递归的层数最小,我们在选根的时候,要使深度最大的子树的深度尽可能小
是为求树的重心
这样,我们的递归层数就能够保持在\(log\space n\)级别
要得到所有的路径,只讨论一个点当然不行,所以我们实际上要把树上所有的点都当做根讨论一遍
\(O(n^2)\)?
非也非也
当我们讨论完一个点之后,下一个要讨论的点肯定在根的子树里,那我们可以这么做:
每次讨论完一个根\(u\)之后,都在这个根的所有子树中再找一个重心\(v\),把它当做这个子树的根,再次进行讨论
重复这个过程,就相当于将树层层分割,是为分治
如果这么做,那么我们在子树中对\(v\)进行讨论的时候,就不需要再将经过\(u\)的路径纳入讨论范围了
有点像把子树从整棵树中断开了,只是并没有真的这么做罢了
这就是我把子树加粗的原因
所以,对于每个点,我们都要讨论一次,而每次讨论的递归层数都是\(log\space n\)级别的,所以点分治的时间复杂度是\(O(n\space log\space n)\)
这就是点分治,以点作为分治的依据,“自顶向下”进行分治
回到这道题,根据点分治的思想,在讨论一个点的时候,我们可以递归求出所有第一类路径的长度并将它们记录下来,然后我们可以将所有的这些第一类路径拼起来,再统计答案即可
统计答案的话sort一下就可以,具体见代码。
平衡树当然也可以用
Code
#include <cstdio>
#include <cstdlib>
#include <algorithm>
#include <cstring>
#include <cmath>
#include <iostream>
#define maxn 10005
using namespace std;
const int inf=0x7f7f7f7f;
int n,K;
int sum,mxson[maxn],siz[maxn],dis[maxn],dep[maxn],root,ans,vst[maxn];
struct edge
{
int u,v,w,nxt;
}g[maxn*2];
int head[maxn],ecnt;
void eADD(int u,int v,int w)
{
g[++ecnt].u=u;
g[ecnt].v=v;
g[ecnt].w=w;
g[ecnt].nxt=head[u];
head[u]=ecnt;
}
void getroot(int u,int fa)//求重心
{
siz[u]=1,mxson[u]=0;
for(register int i=head[u];i;i=g[i].nxt)
{
int v=g[i].v;
if(v==fa||vst[v])continue;
getroot(v,u);
siz[u]+=siz[v];
mxson[u]=max(mxson[u],siz[v]);
}
mxson[u]=max(mxson[u],sum-siz[u]);
if(mxson[u]<mxson[root])root=u;
}
void getdis(int u,int fa)//递归求长度
{
dis[++dis[0]]=dep[u];
for(register int i=head[u];i;i=g[i].nxt)
{
int v=g[i].v;
if(v==fa||vst[v])continue;
dep[v]=dep[u]+g[i].w;
getdis(v,u);
}
}
int calc(int u,int now)//计算答案
{
dep[u]=now,dis[0]=0;
getdis(u,0);
sort(dis+1,dis+dis[0]+1);
int re=0;
for(register int l=1,r=dis[0];l<r;)
{
if(dis[l]+dis[r]<=K)re+=r-l,++l;
else --r;
}
return re;
}
void solve(int u)//分治
{
ans+=calc(u,0);
vst[u]=1;//标记已经作为根讨论过
for(register int i=head[u];i;i=g[i].nxt)
{
int v=g[i].v;
if(vst[v])continue;
ans-=calc(v,g[i].w);
sum=siz[v];
root=0;
getroot(v,0);
solve(root);
}
}
void Inti()
{
memset(g,0,sizeof(g));
memset(head,0,sizeof(head));
memset(vst,0,sizeof(vst));
ecnt=ans=0;
}
int main()
{
scanf("%d%d",&n,&K);
while(n)
{
Inti();
for(register int i=1;i<n;++i)
{
int u,v,w;
scanf("%d%d%d",&u,&v,&w);
eADD(u,v,w),eADD(v,u,w);
}
sum=n;mxson[0]=inf;root=0;
getroot(1,0);
solve(root);
printf("%d\n",ans);
scanf("%d%d",&n,&K);
}
return 0;
}
Some details
等下,这行东西是在干嘛
ans-=calc(v,g[i].w);
还记得我在上文中把两棵子树加粗了吗
如果直接从根开始递归,所得到的第一类路径可能来自于同一棵子树,甚至有一大堆重边,所以我们还要从每个子树的根(不是重心)出发再做一遍calc,并且减去这些情况
这步操作很重要,这样做的原因务必需要理解
Luogu P3806 【模板】点分治1
Solution
差不多也是套框架,我采用的方法是开个桶,在求得dis数组之后二分看点是否存在
lower_bound的上下界需要留意一下不要问我怎么知道的
Code
#include <cstdio>
#include <algorithm>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <cmath>
#define maxn 10005
#define maxm 105
using namespace std;
typedef long long ll;
const int inf=0x7f7f7f7f;
int n,m;
int K[maxm],cnt[maxm];
int mxsiz[maxn],dep[maxn],dis[maxn],vst[maxn],siz[maxn],root,sum;
struct edge
{
int u,v,w,nxt;
}g[maxn*2];
int head[maxn],ecnt;
void eADD(int u,int v,int w)
{
g[++ecnt].u=u;
g[ecnt].v=v;
g[ecnt].w=w;
g[ecnt].nxt=head[u];
head[u]=ecnt;
}
void getroot(int u,int fa)
{
siz[u]=1;mxsiz[u]=0;
for(register int i=head[u];i;i=g[i].nxt)
{
int v=g[i].v;
if(v==fa||vst[v])continue;
getroot(v,u);
siz[u]+=siz[v];
mxsiz[u]=max(mxsiz[u],siz[v]);
}
mxsiz[u]=max(mxsiz[u],sum-siz[u]);
if(mxsiz[u]<mxsiz[root])root=u;
}
void getdep(int u,int fa)
{
dis[++dis[0]]=dep[u];
for(register int i=head[u];i;i=g[i].nxt)
{
int v=g[i].v;
if(vst[v]||v==fa)continue;
dep[v]=dep[u]+g[i].w;
getdep(v,u);
}
}
void calc(int u,int now,int type)
{
dep[u]=now,dis[0]=0;
getdep(u,0);
if(dis[0]==1)
return;
sort(dis+1,dis+dis[0]+1);
for(register int i=1;i<=dis[0];++i)
for(register int j=1;j<=m;++j)
if(*(lower_bound(dis+i+1,dis+dis[0]+1,K[j]-dis[i]))==K[j]-dis[i])
cnt[j]+=type;
}
void solve(int u)
{
calc(u,0,1);
vst[u]=1;
for(register int i=head[u];i;i=g[i].nxt)
{
int v=g[i].v;
if(vst[v])continue;
calc(v,g[i].w,-1);
sum=siz[v];
root=0;
getroot(v,0);
solve(root);
}
}
int main()
{
scanf("%d%d",&n,&m);
for(register int i=1;i<n;++i)
{
int u,v,w;
scanf("%d%d%d",&u,&v,&w);
eADD(u,v,w),eADD(v,u,w);
}
for(register int i=1;i<=m;++i)
scanf("%d",&K[i]);
sum=n;mxsiz[0]=inf;
getroot(1,0);
solve(root);
for(register int i=1;i<=m;++i)
if(cnt[i])
printf("AYE\n");
else
printf("NAY\n");
return 0;
}
Luogu P2634 [国家集训队]聪聪可可
Solution
还是套框架,这道题的话只需要分别记录一下对\(3\)取模为\(0,1,2\)的路径数量,那么最终符合条件的路径的条数显然为:
记得要对答案约分
Code
#include <cstdio>
#include <cstdlib>
#include <algorithm>
#include <cstring>
#include <cmath>
#include <iostream>
#define maxn 20005
using namespace std;
typedef long long ll;
const int inf=0x7f7f7f7f;
int n,K;
int sum,f[maxn],siz[maxn],dis[3],dep[maxn],root,ans,vst[maxn];
struct edge
{
int u,v,w,nxt;
}g[maxn*2];
int head[maxn],ecnt;
void eADD(int u,int v,int w)
{
g[++ecnt].u=u;
g[ecnt].v=v;
g[ecnt].w=w;
g[ecnt].nxt=head[u];
head[u]=ecnt;
}
void getroot(int u,int fa)
{
siz[u]=1,f[u]=0;
for(register int i=head[u];i;i=g[i].nxt)
{
int v=g[i].v;
if(v==fa||vst[v])continue;
getroot(v,u);
siz[u]+=siz[v];
f[u]=max(f[u],siz[v]);
}
f[u]=max(f[u],sum-siz[u]);
if(f[u]<f[root])root=u;
}
void getdis(int u,int fa)
{
++dis[dep[u]%3];
for(register int i=head[u];i;i=g[i].nxt)
{
int v=g[i].v;
if(v==fa||vst[v])continue;
dep[v]=dep[u]+g[i].w;
getdis(v,u);
}
}
int calc(int u,int now)
{
dep[u]=now,dis[0]=dis[1]=dis[2]=0;
getdis(u,0);
return dis[1]*dis[2]*2+dis[0]*dis[0];
}
void solve(int u)
{
ans+=calc(u,0);
vst[u]=1;
for(register int i=head[u];i;i=g[i].nxt)
{
int v=g[i].v;
if(vst[v])continue;
ans-=calc(v,g[i].w);
sum=siz[v];
root=0;
getroot(v,0);
solve(root);
}
}
int gcd(int a,int b)
{
return b?gcd(b,a%b):a;
}
int main()
{
scanf("%d",&n);
for(register int i=1;i<n;++i)
{
int u,v,w;
scanf("%d%d%d",&u,&v,&w);
eADD(u,v,w),eADD(v,u,w);
}
sum=n;f[0]=inf;root=0;
getroot(1,0);
solve(root);
int d=gcd(ans,n*n);
printf("%d/%d\n",ans/d,n*n/d);
return 0;
}