POJ1741 tree 点分治

题意:给定一棵树,求树中长度小于等于K的路径条数。

题解:

分治重在思想,没有固定的算法——对于本题中某一子树上的任意一条路径,其要么是经过该子树的根,要么不经过,不经过的情况分治到子树上,我们只用考虑经过子树的根的情况。

由于经过了根节点,所以该路径的起始点一定是在根节点的两颗子树上,但直接求解方案数较为困难,我们可以进行如下转化:

|dist(i,j)<=k且i,j不在同一棵子树上|=|dist(i,j)<=k|-Σ|dist(i,j)<=k且i,j在同一棵子树上|

第一项把所有子节点到根的距离放在一起排序后可以用O(n)的时间解决,第二项把该子树的某一子树中所有子节点到根的距离放在一起排序后也可以用O(n)的时间解决。

为了防止出现一条链的情况,每次分治子树的时候从树的重心向下分治,这样就能将总的分治层数降到logN

而每一层最多向下扩展N个节点,因此瓶颈就是排序算法的选择,使用快速排序是NlogN,而基数排序可以降到N,因此总的复杂度使用快排就是Nlog^2N

#include <cstdio>
#include <climits>
#include <cstring>
#include <cstdlib>
#include <iostream>
#include <algorithm>
using namespace std;

const int MAXN=30000+2;
struct EDGE{
    int u,w;
    EDGE *next;
    EDGE(){}
    EDGE(int _u,int _w,EDGE *_next):u(_u),w(_w),next(_next){}
}mem[2*MAXN];
struct NODE{
    int c,m;
    EDGE *child;
}node[MAXN];
int N,K,ans,cnt,root,n,m,dist[MAXN];
bool flag[MAXN];

void Insert(int u,int v,int w){ node[u].child=&(mem[cnt++]=EDGE(v,w,node[u].child));}

void Find_Size(int u,int f){
    node[u].c=1,node[u].m=0;
    for(EDGE *p=node[u].child;p;p=p->next)
        if(p->u!=f && !flag[p->u]){
            Find_Size(p->u,u);
            node[u].c+=node[p->u].c;
            node[u].m=max(node[u].m,node[p->u].m);
        }
}

void Find_Root(int r,int u,int f){
    node[u].m=max(node[u].m,node[r].c-node[u].c);
    if(node[u].m<m) m=node[u].m,root=u;
    for(EDGE *p=node[u].child;p;p=p->next)
        if(p->u!=f && !flag[p->u]) Find_Root(r,p->u,u);
}

void Find_Dist(int u,int d,int f){
    dist[++n]=d;
    for(EDGE *p=node[u].child;p;p=p->next)
        if(p->u!=f && !flag[p->u]) Find_Dist(p->u,d+p->w,u);
}

int Calc(int u,int d){
    int ret=0;
    n=0,Find_Dist(u,d,0);
    sort(dist+1,dist+n+1);

    for(int i=1,j=n;i<j;i++){
        while(dist[i]+dist[j]>K && i<j) j--;
        ret+=j-i;
    }
    return ret;
}

void DFS(int u){
    m=INT_MAX;
    Find_Size(u,0),Find_Root(u,u,0);
    ans+=Calc(root,0),flag[root]=1;
    for(EDGE *p=node[root].child;p;p=p->next)
        if(!flag[p->u]){
            ans-=Calc(p->u,p->w);
            DFS(p->u);
        }
}

int main(){
    while(scanf("%d %d",&N,&K)!=EOF){
        if(!N) break;

        ans=cnt=0;
        memset(flag,0,sizeof(flag));
        memset(node,0,sizeof(node));

        for(int i=1,u,v,w;i<N;i++){
            scanf("%d %d %d",&u,&v,&w);
            Insert(u,v,w),Insert(v,u,w);
        }

        DFS(1);
        printf("%d\n",ans);
    }

    return 0;
}
View Code

 

posted @ 2017-02-27 00:34  WDZRMPCBIT  阅读(141)  评论(0编辑  收藏  举报