poj1741 Tree

Tree
Time Limit: 1000MS   Memory Limit: 30000K
Total Submissions: 25816   Accepted: 8586

Description

Given a tree with n vertices, where each edge has a length (a positive integer less than 1001). 
Define dist(u,v) as the minimum distance between nodes u and v. 
Given an integer k, a pair (u,v) of vertices is called valid if and only if dist(u,v) does not exceed k. 
Write a program that counts how many valid pairs the given tree has. 

Input

The input contains several test cases. The first line of each test case contains two integers n and k (n <= 10000). Each of the following n-1 lines contains three integers u, v, l, meaning there is an edge between nodes u and v of length l. 
The last test case is followed by two zeros. 

Output

For each test case output the answer on a single line.

Sample Input

5 4
1 2 3
1 3 1
1 4 2
3 5 1
0 0

Sample Output

8

Source

题目大意:求树上距离 ≤ k的点对数.
分析:点分治模板题.大体步骤就是找重心,然后求跨过重心的答案,接下来对重心的每个子树进行分治.每次找重心和距离都必须判断当前点是否走过,不然可能会再次回到重心.
#include <vector>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>

using namespace std;

const int maxn = 20010;
// Adjacency lists stored in flat arrays: head[x] is the first edge slot leaving
// x (0 = none), nextt[e] links to the next slot of the same node, and to[e] /
// w[e] are the endpoint and length stored in slot e. Slots are handed out from
// tot, which starts at 1, and each undirected edge uses two slots. So the
// highest slot index is 2*(n-1), which is at most 19998 for n <= 10000.
// vis[x] is set once x has been processed as a centroid by dfs().
// d[x] holds the distance most recently recorded for x by getdep().
int head[maxn],to[maxn],nextt[maxn],w[maxn],tot = 1,vis[maxn],d[maxn];
// ans:      running count of valid pairs.
// root:     best centroid candidate found so far by getroot().
// sizee[x]: subtree size of x in the most recent getroot() traversal.
// f[x]:     size of the largest piece left if x were removed (index 0 is a sentinel).
// sum:      size assumed for the component being searched by getroot().
// k:        distance limit.
// n:        node count of the current test case.
int ans,root,sizee[maxn],f[maxn],sum,k,n;
// Scratch buffer of distances gathered by getdep() and counted by calc().
vector <int> q;

// Appends a directed edge x -> y of length z to x's adjacency list.
// A new slot is taken from tot and pushed onto the front of x's list.
void add(int x,int y,int z)
{
    const int slot = tot++;
    to[slot] = y;
    w[slot] = z;
    nextt[slot] = head[x];
    head[x] = slot;
}

// Centroid search over the active component containing u (nodes with
// vis == 0). The component is walked as a tree rooted at the first node passed
// in; fa is u's parent in that walk (0 for the start node).
//  - sizee[x] becomes the size of x's subtree in this walk.
//  - f[x] becomes the largest piece left if x were removed: either its biggest
//    child subtree, or the sum - sizee[x] nodes "above" it.
//  - root ends up at the first node (in post-order) with minimal f.
// Callers set sum to the component size and reset root = 0 with f[0] = sum.
// Index 0 is never a tree node here (0 also means "no parent"), so it acts as
// a sentinel that every real node beats.
void getroot(int u,int fa)
{
    sizee[u] = 1;
    int heaviest = 0;
    for (int e = head[u]; e != 0; e = nextt[e])
    {
        const int child = to[e];
        if (child == fa || vis[child])
            continue;
        getroot(child,u);
        sizee[u] += sizee[child];
        if (sizee[child] > heaviest)
            heaviest = sizee[child];
    }
    const int above = sum - sizee[u];
    f[u] = above > heaviest ? above : heaviest;
    if (f[u] < f[root])
        root = u;
}

// Depth-first walk over u's active component (never stepping back to fa or
// onto a node with vis set). For each node x reached, it stores x's distance
// in d[x] and appends it to q. p is u's own distance; each step adds the edge
// length w[e].
void getdep(int u,int fa,int p)
{
    q.push_back(d[u] = p);
    for (int e = head[u]; e; e = nextt[e])
        if (to[e] != fa && !vis[to[e]])
            getdep(to[e],u,p + w[e]);
}

// Gathers the distances of u's active component into q, starting from offset
// p at u, and returns how many index pairs i < j satisfy q[i] + q[j] <= k.
// The count uses a sort followed by a two-pointer sweep: whenever the
// smallest remaining value fits with the largest, it fits with everything in
// between as well.
int calc(int u,int p)
{
    q.clear();
    getdep(u,0,p);
    sort(q.begin(),q.end());

    int pairs = 0;
    int lo = 0;
    int hi = static_cast<int>(q.size()) - 1;
    while (lo < hi)
    {
        if (q[lo] + q[hi] > k)
        {
            --hi;
            continue;
        }
        pairs += hi - lo;
        ++lo;
    }
    return pairs;
}

// Centroid-decomposition driver.
// Preconditions: u is the centroid just chosen by getroot() for its active
// component, and sum still holds that component's size.
// It adds every pair whose distance sum through u is <= k, removes pairs that
// lie entirely inside a single neighbour's piece, and then recurses into the
// centroid of each piece.
void dfs(int u)
{
    vis[u] = 1;
    // Saved because the recursive calls below overwrite the global sum.
    const int compSize = sum;
    ans += calc(u,0);
    for (int i = head[u]; i ; i = nextt[i])
    {
        int v = to[i];
        if (vis[v])
            continue;
        // Pairs inside v's piece were counted above as if their path passed
        // through u; remove them (their distances carry the extra w[i] offset).
        ans -= calc(v,w[i]);
        // sizee[] comes from the getroot() walk that selected u. That walk may
        // have entered u from v, in which case sizee[v] also includes u's
        // subtree. There, the real size of v's piece is compSize - sizee[u].
        // Pieces already processed in earlier iterations never touch sizee of
        // u or of v, so both values are still those of that walk.
        const int pieceSize = sizee[v] < sizee[u] ? sizee[v] : compSize - sizee[u];
        f[0] = sum = pieceSize;
        root = 0;
        getroot(v,0);
        dfs(root);
    }
}

// Processes test cases until scanf fails or the pair (n, k) sums to zero,
// which is the "0 0" terminator. For each case it rebuilds the tree, starts
// the decomposition from the centroid of the whole tree, and prints the
// number of valid pairs.
int main()
{
    for (;;)
    {
        if (scanf("%d%d",&n,&k) != 2 || n + k == 0)
            break;

        // Reset per-test state: empty adjacency lists, no centroid chosen yet.
        tot = 1;
        ans = 0;
        memset(vis,0,sizeof(vis));
        memset(head,0,sizeof(head));

        for (int e = 1; e < n; ++e)
        {
            int from,dest,len;
            scanf("%d%d%d",&from,&dest,&len);
            add(from,dest,len);
            add(dest,from,len);
        }

        // Find the centroid of the whole tree (index 0 is the sentinel), then decompose.
        root = 0;
        f[0] = sum = n;
        getroot(1,0);
        dfs(root);
        printf("%d\n",ans);
    }

    return 0;
}

 

posted @ 2018-01-04 22:54  zbtrs  阅读(172)  评论(0编辑  收藏  举报