[POJ1741]Tree

题目描述 Description
 Give a tree with n vertices,each edge has a length(positive integer less than 1001). 
Define dist(u,v)=The min distance between node u and v. 
Give an integer k,for every pair (u,v) of vertices is called valid if and only if dist(u,v) not exceed k. 
Write a program that will count how many pairs which are valid for a given tree. 
输入描述 Input Description

The input contains several test cases. The first line of each test case contains two integers n, k. (n<=10000) The following n-1 lines each contains three integers u,v,l, which means there is an edge between node u and v of length l. 
The last test case is followed by two zeros. 

输出描述 Output Description
For each test case output the answer on a single line.
样例输入 Sample Input
5 4
1 2 3
1 3 1
1 4 2
3 5 1
0 0
样例输出 Sample Output
8
数据范围及提示 Data Size & Hint
 

一道比较裸的点分治题。对于每一棵树,先找到他的重心,然后算出他的所有子孙到他的dis,统一放在dis数组中,下面只需在数组中找数对满足和为k就好了,这个能在O(n)的时间内解决。但是这样会有一个问题。对于在同一棵子树上的点,路径并不是从这个点跑到重心,再跑下来,于是我们需要去重。去重之后对重心打上标记,表示以后不能再用这个点了,之后递归的处理每一个子树。一道破题调了好久,后来发现是在getd函数中传参数时出了BUG,看来以后函数的参数还是要有顺序的!

  1 #include<iostream>
  2 #include<cmath>
  3 #include<algorithm>
  4 #include<cstring>
  5 #include<cstdio>
  6 #include<queue>
  7 using namespace std;
  8 typedef long long LL;
  9 #define Pi acos(-1.0)
 10 #define mem(a,b) memset(a,b,sizeof(a))
 11 inline int read()
 12 {
 13     int x=0,f=1;char c=getchar();
 14     while(!isdigit(c)){if(c=='-')f=-1;c=getchar();}
 15     while(isdigit(c)){x=x*10+c-'0';c=getchar();}
 16     return x*f;
 17 }
 18 const int maxn=10010;
 19 struct Edge
 20 {
 21     int u,v,w,next;
 22     Edge() {}
 23     Edge(int _1,int _2,int _3,int _4) : u(_1),v(_2),w(_3),next(_4) {}
 24 }e[2*maxn];
 25 int first[maxn],n,k,a,b,c,ans,size[maxn],masize[maxn],now_size,root,dis[maxn],end;
 26 bool vis[maxn]; 
 27 void addEdge(int i,int a,int b,int c)
 28 {
 29     e[i]=Edge(a,b,c,first[a]);
 30     first[a]=i;
 31 }
 32 void gets(int u,int pa)
 33 {
 34     size[u]=1;
 35     masize[u]=0;
 36     for(int i=first[u];i!=-1;i=e[i].next)
 37         if(!vis[e[i].v] && e[i].v!=pa)
 38         {
 39             gets(e[i].v,u);
 40             size[u]+=size[e[i].v];
 41             masize[u]=max(size[e[i].v],masize[u]);
 42         }
 43 }
 44 void getr(int r,int u,int pa)
 45 {
 46     masize[u]=max(masize[u],size[r]-size[u]);
 47     if(masize[u]<now_size)now_size=masize[u],root=u;
 48     for(int i=first[u];i!=-1;i=e[i].next)
 49         if(!vis[e[i].v] && e[i].v!=pa)getr(r,e[i].v,u);
 50 }
 51 void getd(int u,int pa,int d)
 52 {
 53     dis[end++]=d;
 54     for(int i=first[u];i!=-1;i=e[i].next)
 55         if(!vis[e[i].v] && e[i].v!=pa)getd(e[i].v,u,d+e[i].w);
 56 }
 57 int calc(int u,int d)
 58 {
 59     end=0;
 60     getd(u,-1,d);
 61     int ret=0,l=0,r=end-1;
 62     sort(dis,dis+end); 
 63     while(l<r)
 64     {
 65         while(dis[l]+dis[r]>k && l<r)r--;
 66         ret+=(r-l);l++;
 67     }
 68     return ret;
 69 }
 70 void dfs(int u)
 71 {
 72     now_size=n;
 73     gets(u,-1);
 74     getr(u,u,-1);
 75     ans+=calc(root,0);
 76     vis[root]=1;
 77     for(int i=first[root];i!=-1;i=e[i].next)
 78         if(!vis[e[i].v])
 79         {
 80             ans-=calc(e[i].v,e[i].w);
 81             dfs(e[i].v);
 82         }
 83     return;
 84 }
 85 void init(){mem(first,-1);mem(size,0);mem(masize,0);mem(vis,0);mem(dis,0);ans=0;}
 86 int main()
 87 {
 88     while(scanf("%d%d",&n,&k)!=EOF)
 89     {
 90         if(n==0 && k==0)break;
 91         init();
 92         for(int i=0;i<n-1;i++)
 93         {
 94             a=read();b=read();c=read();
 95             addEdge(i*2,a,b,c);addEdge(i*2+1,b,a,c);
 96         }
 97         dfs(1);
 98         printf("%d\n",ans);    
 99     }
100     return 0;
101 }

 

posted @ 2017-01-18 12:53  小飞淙的云端  阅读(153)  评论(0编辑  收藏  举报