Day8 - F - Tree POJ - 1741

Given a tree with n vertices, where each edge has a length (a positive integer less than 1001).
Define dist(u, v) as the minimum distance between nodes u and v.
Given an integer k, a pair (u, v) of vertices is called valid if and only if dist(u, v) does not exceed k.
Write a program that counts the number of valid pairs in the given tree.

Input

The input contains several test cases. The first line of each test case contains two integers n and k (n <= 10000). Each of the following n-1 lines contains three integers u, v, l, meaning that there is an edge of length l between nodes u and v.
The last test case is followed by two zeros.

Output

For each test case output the answer on a single line.

Sample Input

5 4
1 2 3
1 3 1
1 4 2
3 5 1
0 0

Sample Output

8

思路 (Approach): a textbook centroid-decomposition (点分治) problem; two reference blog posts:
https://blog.csdn.net/qq_39553725/article/details/77542223
https://www.cnblogs.com/bztMinamoto/p/9489473.html
using LL = long long;
using PLL = pair<LL, LL>;

// Array capacity: n <= 10000 vertices, plus a little slack.
const int maxm = 10005;

// One directed half of an undirected edge in a head/next linked adjacency list.
struct Node {
    int v;     // vertex this half-edge points to
    int next;  // index of the next half-edge leaving the same vertex (0 = end of list)
    int val;   // edge length
};
// Every undirected edge is stored twice (u->v and v->u), hence 2 * maxm slots.
Node Nodes[maxm * 2];

int head[maxm];     // head[u]: first half-edge index out of u (0 = none)
int cnt;            // number of half-edges stored so far (indices start at 1)
int siz[maxm];      // subtree sizes filled in by getroot()
int mxson[maxm];    // mxson[u]: largest piece left if u were removed
int dis[maxm];      // dis[1..points]: distances collected by getdist()
int root;           // best centroid candidate found by getroot()
int mxsum;          // mxson value of the current best candidate
int rootsum;        // size of the component currently being searched
int points;         // number of valid entries in dis[]
int n, k;           // vertex count and distance bound of the current test case
bool vis[maxm];     // vis[u]: u has already been used as a centroid
LL ans;             // running count of valid pairs

// Reset all per-test-case state: adjacency heads, centroid marks,
// the half-edge counter and the answer accumulator.
void init() {
    memset(head, 0, sizeof head);
    memset(vis, 0, sizeof vis);
    cnt = 0;
    ans = 0;
}

void addedge(int u, int v, int val) {
    Nodes[++cnt].v = v;
    Nodes[cnt].val = val;
    Nodes[cnt].next = head[u];
    head[u] = cnt;
}

// DFS from u (coming from fa) over the unvisited component. Computes siz[u]
// and mxson[u], the largest piece remaining if u were removed. Keeps the
// node with the smallest such piece in root/mxsum. Reads rootsum as the
// component size.
void getroot(int u, int fa) {
    siz[u] = 1;
    int heaviest = 0;
    for (int e = head[u]; e != 0; e = Nodes[e].next) {
        const int to = Nodes[e].v;
        if (to != fa && !vis[to]) {
            getroot(to, u);
            siz[u] += siz[to];
            if (siz[to] > heaviest) heaviest = siz[to];
        }
    }
    // The part of the component outside u's subtree is also a piece.
    const int above = rootsum - siz[u];
    if (above > heaviest) heaviest = above;
    mxson[u] = heaviest;
    if (heaviest < mxsum) {
        mxsum = heaviest;
        root = u;
    }
}

// Append dist (u's distance) to dis[], then recurse into every unvisited
// neighbour other than fa, adding the connecting edge length.
void getdist(int u, int fa, int dist) {
    dis[++points] = dist;
    for (int e = head[u]; e; e = Nodes[e].next) {
        const int to = Nodes[e].v;
        if (to != fa && !vis[to])
            getdist(to, u, dist + Nodes[e].val);
    }
}

// Gather distances for every node reachable from rt through unvisited nodes,
// with rt itself at distance val. Returns how many unordered pairs of
// collected entries have a distance sum <= k.
int solve(int rt, int val) {
    points = 0;
    getdist(rt, 0, val);
    sort(dis + 1, dis + points + 1);
    int pairs = 0;
    for (int lo = 1, hi = points; lo <= hi; ) {
        if (dis[lo] + dis[hi] > k) {
            --hi;
        } else {
            // dis[lo] pairs with every index in (lo, hi] because the array is sorted.
            pairs += hi - lo;
            ++lo;
        }
    }
    return pairs;
}

// Centroid-decomposition step for the component whose centroid is rt.
// Adds the number of valid pairs whose path passes through rt, marks rt as
// used, then recurses into the centroid of every remaining sub-component.
// Precondition: rootsum holds the size of rt's component, and siz[] was
// filled by the getroot() pass that selected rt.
void Divide(int rt) {
    // Saved locally because the recursion below overwrites the global rootsum.
    const int total = rootsum;
    // Count every pair in the component with dist <= k (rt is at distance 0).
    ans += solve(rt, 0);
    vis[rt] = true;
    for(int i = head[rt]; i; i = Nodes[i].next) {
        int v = Nodes[i].v;
        if(vis[v]) continue;
        // Remove pairs lying entirely inside v's side: they were counted above
        // via rt, but their real path does not pass through rt.
        ans -= solve(v, Nodes[i].val);
        // siz[] was measured from wherever getroot() started, not from rt.
        // If v was rt's child there, siz[v] is exact. If v was rt's parent,
        // siz[v] also contains rt's subtree, so v's side has total - siz[rt].
        // siz[rt] and siz of other neighbours stay valid across iterations:
        // deeper recursion only touches nodes in the sub-component it handles.
        rootsum = siz[v] < siz[rt] ? siz[v] : total - siz[rt];
        root = 0; mxsum = 0x3f3f3f3f;
        getroot(v, 0);
        Divide(root);
    }
}

// Read test cases until n + k == 0 (the "0 0" sentinel), build the tree,
// run centroid decomposition from the whole tree's centroid, print the count.
int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    while (cin >> n >> k && n + k) {
        init();
        for (int e = 1; e < n; ++e) {
            int u, v, val;
            cin >> u >> v >> val;
            addedge(u, v, val);
            addedge(v, u, val);
        }
        // Find the centroid of the full tree, then decompose from it.
        rootsum = n;
        mxsum = 0x3f3f3f3f;
        getroot(1, 0);
        Divide(root);
        cout << ans << "\n";
    }
    return 0;
}
View Code

 


posted @ 2020-01-25 13:05  GRedComeT  阅读(140)  评论(0编辑  收藏  举报