Tree
参考 xk 老哥的博客:POJ 1741 Tree 点分治
找重心:
void getrt(int fa,int u,int num) //num指的是这个节点的子树中有多少个节点
{
siz[u]=1;
int maxnum=0; //记录最大子树的节点个数
for(int i=head[u];~i;i=e[i].next)
{
int v=e[i].to;
if(vis[v]||v==fa) continue;
getrt(u,v,num);
siz[u]+=siz[v];
maxnum=max(maxnum,siz[v]);
}
maxnum=max(maxnum,num-siz[u]); //num-siz[u]表示的是某节点的反向子树(反向指的是沿着递归方向反方向)
if(maxnum<rtsiz) rtsiz=maxnum,rt=u; //更新重心和重心的最大子树
}
找到重心之后 dfs 计算子树上每个点距离重心的距离:
int d[maxn],dcnt;
void dfs(int fa,int u,int w)
{
d[++dcnt]=w;
siz[u]=1;
for(int i=head[u];~i;i=e[i].next)
{
int v=e[i].to;
if(vis[v]||fa==v) continue;
dfs(u,v,w+e[i].w);
siz[u]+=siz[v];
}
}
根据每个点到重心的距离进行排序,并计算有多少满足条件的点:
int cal()
{
sort(d+1,d+1+dcnt);
int l=1,r=dcnt,ret=0;
while(l<r)
if(d[l]+d[r]<=m) ret+=r-l,l++;
else r--;
return ret;
}
重点来了!分治
void solve(int u,int num)
{
if(num<=1) return;
rtsiz=inf;
getrt(0,u,num);
vis[rt]=1;
dcnt=0;
dfs(0,rt,0);
ans+=cal();
/*第一次cal的时候可能会把根节点的同一颗子树上的两个点d[l]+d[r]<=m记录进去,
但实际上这样的两个点不应该在本次cal中统计进去,因为在后面的分治中会进行删去
本次cal与下面for中的cal加起来统计的是经过该点u的满足条件的路径*/
for(int i=head[rt];~i;i=e[i].next)
{
int v=e[i].to;
if(vis[v]) continue;
dcnt=0;
dfs(0,v,e[i].w);
ans-=cal();
/*在这个地方把不经过点u但满足条件的路径进行删除*/
}
for(int i=head[rt];~i;i=e[i].next)
{
int v=e[i].to;
if(vis[v]) continue;
solve(v,siz[v]);
/*进行分治,也就是向下进行查询*/
}
}
代码:
// Created by CAD on 2019/8/14.
#include <iostream>
#include <cstdio>
#include <algorithm>
#define inf 0x3f3f3f3f
using namespace std;
const int maxn=1e5+100;
int siz[maxn],vis[maxn],head[maxn],tot;
int n,m,ans;
struct edge{
int to,next,w;
}e[maxn<<1];
void add(int u,int v,int w)
{
e[++tot].to=v;
e[tot].w=w;
e[tot].next=head[u];
head[u]=tot;
}
int rtsiz,rt;
void getrt(int fa,int u,int num)
{
siz[u]=1;
int maxnum=0;
for(int i=head[u];~i;i=e[i].next)
{
int v=e[i].to;
if(vis[v]||v==fa) continue;
getrt(u,v,num);
siz[u]+=siz[v];
maxnum=max(maxnum,siz[v]);
}
maxnum=max(maxnum,num-siz[u]);
if(maxnum<rtsiz) rtsiz=maxnum,rt=u;
}
int d[maxn],dcnt;
void dfs(int fa,int u,int w)
{
d[++dcnt]=w;
siz[u]=1;
for(int i=head[u];~i;i=e[i].next)
{
int v=e[i].to;
if(vis[v]||fa==v)continue;
dfs(u,v,w+e[i].w);
siz[u]+=siz[v];
}
}
int cal()
{
sort(d+1,d+1+dcnt);
int l=1,r=dcnt,ret=0;
while(l<r)
if(d[l]+d[r]<=m) ret+=r-l,l++;
else r--;
return ret;
}
void solve(int u,int num)
{
if(num<=1) return;
rtsiz=inf;
getrt(0,u,num);
vis[rt]=1;
dcnt=0;
dfs(0,rt,0);
ans+=cal();
for(int i=head[rt];~i;i=e[i].next)
{
int v=e[i].to;
if(vis[v]) continue;
dcnt=0;
dfs(0,v,e[i].w);
ans-=cal();
}
for(int i=head[rt];~i;i=e[i].next)
{
int v=e[i].to;
if(vis[v]) continue;
solve(v,siz[v]);
}
}
int main()
{
int u,v,w;
while(~scanf("%d%d",&n,&m)&&n+m)
{
tot=ans=0;
for(int i=1;i<=n;++i)
vis[i]=0,head[i]=-1;
for(int i=1;i<n;++i)
scanf("%d%d%d",&u,&v,&w),add(u,v,w),add(v,u,w);
solve(1,n);
cout<<ans<<endl;
}
return 0;
}
CAD加油!欢迎跟我一起讨论学习算法,QQ:1401650042