树的点分治
本题目是他讲的第一个例题;
我的理解:每次都找树的重心,计算以重心为根的子树之间所贡献的答案。不断这样下去;如果这棵树是一条链,那么就和快排,归并的线性分治法一样了。如果不是一条链那么就相当于,选中一个点,标记为使用过。然后树会被这个节点划分成多棵子树。然后这样分治下去。思想好理解。但是代码不是很好想!详见注解。
#include <stdio.h>
#include <string.h>
#include <iostream>
#include <algorithm>
#include <math.h>
using namespace std;
typedef long long int LL;
const int maxn=20000+100;
struct Node
{
int to,val,next;
}edge[maxn];
int first[maxn],sz,ans,idx,root,n,k;
bool vis[maxn];
/* ----------------------*/
/**
* 邻接表部分
*/
void init()
{
memset(first,-1,sizeof(first));
memset(vis,0,sizeof(vis));
sz=0;
}
void addedge(int s,int t,int val)
{
edge[sz].val=val,edge[sz].to=t,edge[sz].next=first[s];
first[s]=sz++;
}
/*****
* mi,用来找树的重心的时候作比较用的;
* mx[i]数组 代表i节点的子树中最大的size;
* size[i]代表i节点的子树的节点数量;
* dis[i]代表i好节点到根的深度;
* ****************************/
int mi,mx[maxn],size[maxn],dis[maxn];
void dfssize(int x,int pre) //x节点这个子树求size,mx;
{
size[x]=1;
mx[x]=0;
for(int i=first[x];i!=-1;i=edge[i].next)
{
int to=edge[i].to,val=edge[i].val;
if(to!=pre&&!vis[to])
{
dfssize(to,x);
size[x]+=size[to];
if(size[to]>mx[x]) mx[x]=size[to];
}
}
}
void dfsroot(int rt,int x,int pre) //找rt这个 子树的重心
{
mx[x]=max(mx[x],size[rt]-size[x]);
if(mx[x]<mi) mi=mx[x],root=x;
for(int i=first[x];i!=-1;i=edge[i].next)
{
int to=edge[i].to;
if(to!=pre&&!vis[to]) dfsroot(rt,to,x);
}
}
void dfsdis(int x,int pre,int dd) //计算子树 x的 dis 数组;
{
dis[idx++]=dd;
for(int i=first[x];i!=-1;i=edge[i].next)
{
int to=edge[i].to,val=edge[i].val;
if(to!=pre&&!vis[to]) dfsdis(to,x,dd+val);
}
}
/**
* 计算 F(x)=(子树x的 depth(i)+depth(j)<=k 的对数;)
* calc(x,d);当d=0;返回的是x这颗子树F(x),
* 等d=某条边权值,calc(x,d) 返回值代表子树之间的F(x)
*/
int calc(int x,int d)
{
int res=0;
idx=0;
dfsdis(x,x,d);
sort(dis,dis+idx);
int i=0,j=idx-1;
while(i<j)
{
while(dis[i]+dis[j]>k&&i<j) j--;
res+=j-i;
i++;
}
return res;
}
void dfs(int x)
{
mi=n;
dfssize(x,x);
dfsroot(x,x,x);
ans+=calc(root,0);//当前节点
vis[root]=1;
for(int i=first[root];i!=-1;i=edge[i].next)
{
int to=edge[i].to,val=edge[i].val;
if(!vis[to])
{
ans-=calc(to,val);//减去当前节点的子树之间的
// printf(">> calc=%d\n",calc(root,val));
dfs(to);
}
}
}
int main()
{
while(scanf("%d%d",&n,&k)!=EOF)
{
if(n==0&&k==0) break;
init();
for(int i=1;i<n;i++)
{
int x,y,val;
scanf("%d%d%d",&x,&y,&val);
addedge(x,y,val);
addedge(y,x,val);
}
ans=0;
dfs(1);
printf("%d\n",ans);
}
return 0;
}