【树形+点分治】POJ 1741 Tree

通道:http://poj.org/problem?id=1741

题意:有多少对[u,v]的距离小于K

思路:

将无根树转化成有根树进行观察。满足条件的点对有两种情况:两个点的路径横跨树根,两个点位于同一颗子树中。

如果我们已经知道了此时所有点到根的距离a[i],a[x] + a[y] <= k的(x, y)对数就是结果,这个可以通过排序之后O(n)的复杂度求出。然后根据分治的思想,分别对所有的儿子求一遍即可,但是这会出现重复的——当前情况下两个点位于一颗子树中,那么应该将其减掉(显然这两个点是满足题意的,为什么减掉呢?因为在对子树进行求解的时候,会重新计算)。在进行分治时,为了避免树退化成一条链而导致时间复杂度变为O(N^2),每次都找树的重心,这样,所有的子树规模就会变的很小了。时间复杂度O(Nlog^2N)。

代码:

#include <cstdio>
#include <cmath>
#include <cstring>
#include <vector>
#include <algorithm>

using namespace std;

typedef long long ll;

const int MAX_N = 100007; 

struct Node {
    int v, w, nxt;
    Node () {
        
    }
    Node (int _v, int _w, int _n) {
        v = _v;
        w = _w;
        nxt = _n;
    }
};

int n, k;
int head[MAX_N], edgecnt;
Node G[MAX_N << 1];
bool del[MAX_N];

void Clear() {
    edgecnt = 0;
    memset(head, -1, sizeof head);
    memset(del, 0, sizeof del);
}

void add(int u, int v, int w) {
    G[edgecnt] = Node(v, w, head[u]);
    head[u] = edgecnt++;
}

int son[MAX_N], opt[MAX_N];
vector<int> alln;

void dfs(int u,int fa) {
    alln.push_back(u);
    son[u] = 1; opt[u] = 0;
    for(int i = head[u]; ~i; i = G[i].nxt) {
        int v = G[i].v;
        if(del[v] || v == fa) continue;
        dfs(v, u);
        son[u] += son[v];
        opt[u] = max(opt[u], son[v]);
    }
}

int getCenter(int u) {
    alln.clear();
    dfs(u, -1);
    int mx = 0, ans = -1;
    int sz = alln.size();
    for(int i = 0; i < sz; ++i)  {
        int v = alln[i];
        if(ans == -1) ans = v, mx = max(opt[v], sz - son[v]);
        else  {
            if(max(opt[v], sz - son[v]) < mx) {
                mx = max(opt[v], sz - son[v]);
                ans = v;
            }
        }
    }
    return ans;
}

int tot, D[MAX_N];
void getDist(int u, int fa, int w) {
    D[tot++] = w;
    for(int i = head[u]; ~i; i = G[i].nxt) {
        int v = G[i].v;
        if(del[v] || v == fa) continue;
        getDist(v, u, w + G[i].w);
    } 
}

ll calc() {
    sort(D, D + tot);
    ll ans = 0;
    int l = 0, r = tot - 1;
    while (l < r) {
        if (D[l] + D[r] <= k) ans += r - l++;
        else --r;
    }
    return ans;
}

ll ans;
void work(int u) {
    u = getCenter(u);
    tot = 0; getDist(u, -1, 0);
    ans += calc();
    for(int i = head[u]; ~i; i = G[i].nxt) {
        int v = G[i].v;
        if(del[v]) continue;
        tot = 0;
        getDist(v, u, G[i].w);
        ans -= calc();
    }
    del[u] = true;
    for(int i = head[u]; ~i; i =G[i].nxt) {
        int v = G[i].v;
        if(del[v]) continue;
        work(v);
    }
}

int main() {
    while (2 == scanf("%d%d", &n, &k)) {
        if (0 == n && 0 == k) break;
        Clear();
        for (int i = 1; i < n; ++i) {
            int u, v, w;
            scanf("%d%d%d", &u, &v, &w);
            add(u, v, w), add(v, u, w);
        }
        ans = 0;
        work(1);
        printf("%lld\n", ans);
    }
    return 0;
}
View Code

 

posted @ 2014-11-03 19:26  mithrilhan  阅读(159)  评论(0编辑  收藏  举报