【POJ1741】Tree-树的点分治

测试地址:Tree

题目大意:给定一棵有N(N≤10000)个节点的带边权的树,我们称一个点对是合法的当且仅当两个点不相同且它们之间的距离≤K,求合法点对的数目。

做法:既是树分治的论文题,又是男人八题的其中一题,妙啊......

这道题需要用到树的点分治。

因为,我们很容易想到O(N^2)的暴力,然而对于N=10000的数量级根本束手无策。

那么要怎么办呢?

我们将这棵树按根节点分治,那么两点间的路径就可以分成两种情况:过根节点的和不过根节点的。不过根节点的情况可以递归进子树求出,我们就只用考虑过根节点的情况了。设点i与根节点的距离为dis[i],那么我们要找的就是:满足dis[i]+dis[j]≤K且i,j不属于同一棵子树的点对(i,j)的数目。但是这样太难算了,我们可以把问题分解成两个问题,令X=满足dis[i]+dis[j]≤K的点对(i,j)数目,Y=满足dis[i]+dis[j]≤K且i,j属于同一棵子树的点对(i,j)数目,那么原问题答案就是X-Y。我们发现X和Y都可以转化成“给定A,求满足A[i]+A[j]≤K的点对(i,j)数目”这个问题,这个问题的解决就只需要将A排序,设B[i]为使得A[i]+A[x]≤K的最大的x,根据单调性,B[i]是单调不增的,因此解决这个问题的复杂度为O(NlogN),可以证明没有更好的办法了。

但是,如果遇到极端情况,分治的深度可能会达到N,那么就比O(N^2)的暴力还差了,所以我们可以每次分治时,找到树的重心(使得以它为树的根节点时,节点数最多的子树节点数最小的点)作为根节点,可以证明这样分治的深度不大于logN,那么问题的均摊复杂度就优化到了O(Nlog^2 N),可以通过此题。

犯二的地方:各种细节写错导致重心没找对......TLE到爆炸......

以下是本人代码:

#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <algorithm>
#define inf 1000000000
using namespace std;
int n,k,tot,first[10010],maxp,ans;
int a[10010],p[10010],fa[10010],siz[10010],dis[10010];
struct edge {int v,d,next;} e[20010];
bool vis[10010];

bool cmp(int a,int b)
{
  return a<b;
}

void insert(int a,int b,int d)
{
  e[++tot].v=b;
  e[tot].d=d;
  e[tot].next=first[a];
  first[a]=tot;
}

void dfs(int v)
{
  siz[v]=1;
  a[++a[0]]=dis[v];
  p[a[0]]=v;
  for(int i=first[v];i;i=e[i].next)
    if (e[i].v!=fa[v]&&!vis[e[i].v])
	{
	  fa[e[i].v]=v;
	  dis[e[i].v]=dis[v]+e[i].d;
	  dfs(e[i].v);
	  siz[v]+=siz[e[i].v];
	}
}

int find(int v)
{
  int s,mx=inf;
  a[0]=0;fa[v]=0;dis[v]=0;dfs(v);
  for(int i=1;i<=siz[v];i++)
  {
    int x=p[i],maxsiz=0,sumsiz=0;
	for(int j=first[x];j;j=e[j].next)
	  if (e[j].v!=fa[x])
	  {
	    maxsiz=max(maxsiz,siz[e[j].v]);
		sumsiz+=siz[e[j].v];
	  }
	maxsiz=max(maxsiz,siz[v]-sumsiz);
	if (maxsiz<mx) mx=maxsiz,s=x;
  }
  return s;
}

int work(int v,int start)
{
  a[0]=0;dis[v]=start;dfs(v);
  sort(a+1,a+siz[v]+1,cmp);
  int r=1,sum=0;
  while(a[1]+a[r]<=k&&r<=siz[v]) r++;
  r--;
  for(int i=1;i<=siz[v];i++)
  {
    while(a[i]+a[r]>k) r--;
	if (i>=r) break;
	sum+=r-i;
  }
  return sum;
}

void solve(int v)
{
  v=find(v);
  fa[v]=0;
  ans+=work(v,0);
  for(int i=first[v];i;i=e[i].next)
    if (!vis[e[i].v]) ans-=work(e[i].v,dis[e[i].v]);
  vis[v]=1;
  for(int i=first[v];i;i=e[i].next)
    if (!vis[e[i].v]) solve(e[i].v);
}

int main()
{
  while(scanf("%d%d",&n,&k)&&n)
  {
    tot=0;
	memset(first,0,sizeof(first));
    memset(vis,0,sizeof(vis));
	for(int i=1,x,y,d;i<n;i++)
	{
	  scanf("%d%d%d",&x,&y,&d);
	  insert(x,y,d),insert(y,x,d);
	}
    
	ans=0;
    solve(1);
	
	printf("%d\n",ans);
  }
  
  return 0;
}


posted @ 2017-04-26 17:15  Maxwei_wzj  阅读(69)  评论(0编辑  收藏  举报