POJ1741 Tree

Tree
Time Limit: 1000MS   Memory Limit: 30000K
     

Description

You are given a tree with n vertices; each edge has a length (a positive integer less than 1001).
Define dist(u,v) as the minimum distance between nodes u and v.
You are also given an integer k. A pair (u,v) of vertices is called valid if and only if dist(u,v) does not exceed k.
Write a program that counts how many valid pairs the given tree contains.

Input

The input contains several test cases. The first line of each test case contains two integers n and k (n <= 10000). Each of the following n-1 lines contains three integers u, v, l, meaning that there is an edge of length l between nodes u and v.
The last test case is followed by two zeros. 

Output

For each test case output the answer on a single line.

Sample Input

5 4
1 2 3
1 3 1
1 4 2
3 5 1
0 0

Sample Output

8
分析：点分治入门题。
   思路：利用分治思想，把点对分成“经过当前树根”和“不经过当前树根”两类分别讨论；
   经过树根的点对：dfs 求出所有节点到根的距离并排序，再利用单调性（双指针）在 O(N) 内统计距离和不超过 k 的点对；
   同时要减去两个端点位于同一棵子树中的点对，因为它们之间的路径实际上并不经过树根；
   不经过树根的点对：对每棵子树递归处理即可。注意每次都要先找子树的重心作为新的根，防止树退化成单链时递归层数过深。
代码:
#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <cmath>
#include <algorithm>
#include <climits>
#include <cstring>
#include <string>
#include <set>
#include <bitset>
#include <map>
#include <queue>
#include <stack>
#include <vector>
#include <cassert>
#include <ctime>
// Shorthand macros from the author's contest template. Several of them
// (e.g. inf, mod, pi, sys, ls, rs) are not used by the solution code visible here.
// rep(i,m,n): inclusive loop `for(i=m;i<=(int)n;i++)`; i must already be declared
// by the caller. n is not parenthesized, so for an argument like `v.size()-1` the
// cast applies to `v.size()` alone, giving ((int)v.size())-1 -- i.e. -1 (loop
// skipped) for an empty vector rather than a huge unsigned bound.
#define rep(i,m,n) for(i=m;i<=(int)n;i++)
#define inf 0x3f3f3f3f
#define mod 1000000007
#define vi vector<int>
#define pb push_back
#define mp make_pair
#define fi first
#define se second
#define ll long long
#define pi acos(-1.0)
#define pii pair<int,int>
#define sys system("pause")
#define ls (rt<<1)
#define rs (rt<<1|1)
#define all(x) x.begin(),x.end()
// maxn sizes the per-vertex arrays (the statement guarantees n <= 10000).
// N is not referenced by the visible code.
const int maxn=1e5+10;
const int N=1e4+10;
// File-scope using-directive: macros such as vi expand to unqualified
// vector<int>, and the code calls sort/max unqualified.
using namespace std;
// Number-theory helpers kept from the author's template; the tree solution
// below does not call them. (`long long` is spelled out; it is what `ll` means.)

// gcd(p, q): greatest common divisor by Euclid's algorithm.
long long gcd(long long p,long long q){return q==0?p:gcd(q,p%q);}

// qmul(p, q, mo): (p * q) % mo by binary ("double and add") multiplication, so
// the full product is never formed. Intended for non-negative p, q and mo > 0.
// p is reduced modulo mo first so that doubling it cannot overflow as long as
// 2*mo fits in long long; a negative q yields 0 instead of looping forever
// (q >>= 1 on a negative value never reaches 0).
long long qmul(long long p,long long q,long long mo)
{
    long long f=0;
    p%=mo;
    while(q>0)
    {
        if(q&1)f=(f+p)%mo;
        p=(p+p)%mo;
        q>>=1;
    }
    return f;
}

// qpow(p, q, mo): p^q % mo by binary exponentiation. Intended for non-negative
// p, q and mo > 0; intermediate products are below mo*mo, so mo must stay below
// about 3.03e9. p is reduced first so a large base cannot overflow p*p, and the
// accumulator starts at 1 % mo so that mo == 1 correctly yields 0.
long long qpow(long long p,long long q,long long mo)
{
    long long f=1%mo;
    p%=mo;
    while(q>0)
    {
        if(q&1)f=f*p%mo;
        p=p*p%mo;
        q>>=1;
    }
    return f;
}
// Global state shared by the solver (per-vertex arrays are indexed by vertex id
// 1..n and sized maxn, well above the n <= 10000 limit).
//   n      number of vertices; m is the distance bound (the statement's k).
//   k, t   globals not read anywhere in the visible code.
//   sz     size of the component getroot is currently searching.
//   root   best centroid candidate found by getroot (0 = sentinel).
//   ans    running count of valid pairs for the current test case.
//   s[x]   size of x's DFS subtree, filled by getroot.
//   p[x]   largest piece left if x were removed, filled by getroot.
//   q[x]   accumulated distance used by getdep (cal seeds q[start]).
int n,m,k,t,sz,root,ans,s[maxn],p[maxn],q[maxn];
// vis[x]: x has already been processed as a centroid, i.e. removed from the tree.
bool vis[maxn];
// dep: distances collected by getdep for one cal() call.
// e[x][i] / f[x][i]: the i-th neighbour of x and the length of that edge.
vi dep,e[maxn],f[maxn];
void getroot(int x,int y)
{
    int i;
    s[x]=1;p[x]=0;
    rep(i,0,e[x].size()-1)
    {
        int z=e[x][i],t=f[x][i];
        if(z==y||vis[z])continue;
        getroot(z,x);
        s[x]+=s[z];
        p[x]=max(p[x],s[z]);
    }
    p[x]=max(p[x],sz-s[x]);
    if(p[x]<p[root])root=x;
}
void getdep(int x,int y)
{
    dep.pb(q[x]);
    int i;
    rep(i,0,e[x].size()-1)
    {
        int z=e[x][i],t=f[x][i];
        if(z==y||vis[z])continue;
        q[z]=q[x]+t;
        getdep(z,x);
    }
}
// cal(x, y): gathers the distances of every vertex in x's current component,
// with x itself placed at distance y (q[x] = y), and returns how many unordered
// pairs of distinct collected vertices have distances summing to at most m.
// Side effect: leaves those distances in `dep`, sorted ascending.
int cal(int x,int y)
{
    q[x]=y;
    dep.clear();
    getdep(x,0);
    sort(dep.begin(),dep.end());
    int cnt=0;
    int lo=0,hi=(int)dep.size()-1;
    while(lo<hi)
    {
        if(dep[lo]+dep[hi]>m){--hi;continue;}
        cnt+=hi-lo;   // sorted: dep[lo] pairs validly with every dep[lo+1..hi]
        ++lo;
    }
    return cnt;
}
void gao(int x)
{
    int i;
    ans+=cal(x,0);
    vis[x]=true;
    rep(i,0,e[x].size()-1)
    {
        int y=e[x][i],z=f[x][i];
        if(!vis[y])
        {
            ans-=cal(y,z);
            p[0]=sz=s[y];
            getroot(y,root=0);
            gao(root);
        }
    }
}
// Driver: for each test case reads "n m" (m is the statement's k), then n-1
// edges "u v l", and prints the number of unordered vertex pairs whose tree
// distance is at most m. Input ends with "0 0".
int main(){
    int i;
    // Require both numbers instead of testing ~scanf(...) (i.e. != EOF): on a
    // non-numeric token scanf returns 0 without consuming it, and the EOF-only
    // test could keep looping on stale n and m without ever advancing.
    while(scanf("%d%d",&n,&m)==2)
    {
        if(!n&&!m)break;
        rep(i,1,n)e[i].clear(),f[i].clear();
        rep(i,1,n)vis[i]=false;
        rep(i,1,n-1)
        {
            int x,y,z;
            // Truncated/malformed edge list: stop rather than index the
            // adjacency lists with uninitialised x, y.
            if(scanf("%d%d%d",&x,&y,&z)!=3)return 0;
            e[x].pb(y);e[y].pb(x);
            f[x].pb(z);f[y].pb(z);   // f[v][j] is the length of the edge v -- e[v][j]
        }
        // Start from the centroid of the whole tree (size n); vertex 0 is the
        // sentinel with p[0] = sz that any real vertex beats.
        p[0]=sz=n;
        getroot(1,root=0);
        ans=0;
        gao(root);
        printf("%d\n",ans);
    }
    return 0;
}
posted @ 2017-08-23 21:05  mxzf0213  阅读(177)  评论(0编辑  收藏  举报