poj 2114 Boatherds (树分治)

链接:http://poj.org/problem?id=2114

题意:

对每个询问 k,判断树上是否存在距离恰好为 k 的点对(存在输出 AYE,否则输出 NAY);做法上是统计距离为 k 的点对数量,数量非零即存在。

思路:

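点分治:每次取当前连通块的重心 u 作为根,统计经过 u 且长度恰为 k 的路径数。做法是先把所有点到 u 的距离两两相加,数出和为 k 的点对;再用容斥减去落在 u 的同一棵子树内的点对(它们的真实路径不经过 u)。然后删去 u,对各子树递归处理。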

 

实现代码:

#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
// Shared type alias and constants (kept C++98-compatible for old judges).
typedef long long ll;          // 64-bit integer alias (currently unused)
const int inf = 0x7fffffff;    // INT_MAX; sentinel value stored in f[0]
const int M = 1e5+10;          // capacity of every per-node array

// One edge of the singly-linked adjacency lists ("forward star" layout).
//   to   - node the edge points to
//   next - index of the edge previously inserted for the same source node
//          (0 terminates the list, so edge ids start at 1)
//   w    - edge weight
struct node {
    int to;
    int next;
    int w;
};
node e[M<<2];  // edge pool; slot 0 is never used

int n, m;     // n: node count of the current tree; m: unused

int vis[M];   // vis[u] = 1 once u has been processed as a centroid by solve()
int dis[M];   // dis[0] = number of collected distances, dis[1..dis[0]] = the distances
int d[M];     // d[u] = distance of u from the start of the current get_dis() DFS
int siz[M];   // subtree sizes computed by the most recent get_root() DFS
int f[M];     // balance value get_root() minimises to choose root; f[0] is a sentinel
int cnt;      // number of edges currently stored in e[]
int head[M];  // head[u] = index of u's most recently added edge (0 = none)
int sum;      // size of the component get_root() is searching
int root;     // best node found so far by get_root()
int k;        // target distance of the current query
int ans;      // inclusion-exclusion count of pairs at distance exactly k

// Reset state for a new test case: clear centroid marks and adjacency heads,
// empty the edge pool and zero the answer.
void init(){
    memset(vis, 0, sizeof vis);
    memset(head, 0, sizeof head);
    ans = 0;
    cnt = 0;
}

void add(int u,int v,int w){
    e[++cnt].to = v;e[cnt].w = w;e[cnt].next = head[u];head[u] = cnt;
}

// Search for the centroid of the component containing u, without crossing
// nodes already marked in vis[]. Callers set `sum` to the component size and
// `root` to 0 beforehand (f[0] == inf, so any real node beats that sentinel).
// Afterwards:
//   siz[x] = size of x's subtree in this DFS,
//   f[x]   = size of the largest piece left if x were removed,
//   root   = a node with the smallest f[] (strict '<' keeps the first one found).
void get_root(int u,int fa){
    siz[u] = 1; f[u] = 0;
    for(int i = head[u];i;i=e[i].next){
        int v = e[i].to;
        if(v != fa&&!vis[v]){
            get_root(v,u);
            siz[u] += siz[v];
            // Largest single child subtree. The old code used siz[u] (the running
            // total), which ignores how balanced the children are. On a star it
            // picked a leaf, which made the decomposition depth O(n).
            f[u] = max(f[u],siz[v]);
        }
    }
    // The piece on the DFS-parent side of u is also left over when u is removed.
    f[u] = max(f[u],sum - siz[u]);
    if(f[u] < f[root]) root = u;
    return ;
}

// DFS from u (arrived from fa) over nodes not yet marked in vis[]. It extends
// d[] along each edge and appends every distance d[x] <= k to dis[1..dis[0]].
// The caller initialises d[u] and dis[0].
void get_dis(int u,int fa){
    if(d[u] <= k){
        ++dis[0];
        dis[dis[0]] = d[u];
    }
    for(int edge = head[u]; edge != 0; edge = e[edge].next){
        const int to = e[edge].to;
        if(to == fa || vis[to]) continue;
        d[to] = d[u] + e[edge].w;
        get_dis(to,u);
    }
}

// Count unordered pairs among the collected distances whose sum is exactly k.
// Distances come from a get_dis() DFS starting at u with d[u] = c (solve() uses
// c = 0 at a centroid and c = edge weight when starting from a child).
int cal(int u,int c){
    dis[0] = 0;
    d[u] = c;
    get_dis(u,0);
    const int total = dis[0];
    sort(dis + 1, dis + total + 1);

    // Two pointers over the sorted distances.
    int lo = 1, hi = total, pairs = 0;
    while (lo < hi) {
        const int s = dis[lo] + dis[hi];
        if (s < k) { ++lo; continue; }
        if (s > k) { --hi; continue; }
        if (dis[lo] == dis[hi]) {
            // Every value in [lo, hi] equals k/2, so any two of them form a pair.
            const int len = hi - lo + 1;
            pairs += len * (len - 1) / 2;
            break;
        }
        // Pair the run of equal values at the low end with the run at the high end.
        int nlo = lo, nhi = hi;
        while (dis[nlo] == dis[lo]) ++nlo;
        while (dis[nhi] == dis[hi]) --nhi;
        pairs += (nlo - lo) * (hi - nhi);
        lo = nlo;
        hi = nhi;
    }
    return pairs;
}

// One centroid-decomposition step at centroid u. First add the pairs whose
// distances to u sum to k. Then subtract, per child, the pairs lying entirely
// inside that child's subtree (their true path does not pass through u).
// Finally recurse on the centroid of each remaining child component.
void solve(int u){
    ans += cal(u,0);
    vis[u] = 1;
    for(int edge = head[u]; edge; edge = e[edge].next){
        const int child = e[edge].to;
        if(!vis[child]){
            ans -= cal(child, e[edge].w);
            // NOTE(review): siz[child] comes from the get_root() DFS that chose u.
            // For the neighbour that was u's DFS parent it is not the exact
            // component size; this only affects centroid choice, not the count.
            sum = siz[child];
            root = 0;
            get_root(child,0);
            solve(root);
        }
    }
}

// Input format handled below: a tree size n, terminated by n == 0 or end of input.
// For each node i, a list of (neighbour, weight) pairs terminated by 0 follows.
// Then a list of queries k follows, terminated by 0. For every query it prints
// AYE if some pair of nodes is at distance exactly k, otherwise NAY.
// A "." line is printed after each test case.
int main()
{
    // Compare scanf's result with 1: at end of input it returns EOF (-1), which
    // is truthy, so the old `scanf(...) && n` loops forever on stale values.
    while(scanf("%d",&n) == 1 && n){
        init();
        for(int i = 1;i <= n;i ++){
            int x,w;
            while(scanf("%d",&x) == 1 && x){
                if(scanf("%d",&w) != 1) return 0;  // truncated input: w would be garbage
                add(i,x,w); add(x,i,w);
            }
        }
        int x;
        while(scanf("%d",&x) == 1 && x){
            // solve() marks centroids in vis[], so each query starts from scratch.
            memset(vis,0,sizeof(vis));
            k = x; ans = root = 0;
            sum = n; f[0] = inf;
            get_root(1,0);
            solve(root);
            if(ans) printf("AYE\n");
            else printf("NAY\n");
        }
        printf(".\n");
    }
    return 0;
}

 

posted @ 2018-09-30 12:27  冥想选手  阅读(283)  评论(0)  编辑  收藏  举报