poj 1741 树的分治

思路:这题我是看 漆子超《分治算法在树的路径问题中的应用》写的。

附代码:

#include<iostream>
#include<cstring>
#include<algorithm>
#include<cstdio>
#include<cmath>
#define Maxn 10010
#define Maxm 20010
#define inf 0x7fffffff
using namespace std;
int head[Maxn],vi[Maxn],e,ans,num,k,n;
int mx[Maxn],mi,dis[Maxn],root,size[Maxn];
struct Edge{
    int u,v,val,next;
}edge[Maxm];
void init()
{
    memset(vi,0,sizeof(vi));
    memset(head,-1,sizeof(head));
    memset(mx,0,sizeof(mx));
    memset(dis,0,sizeof(dis));
    e=ans=0;
}
void add(int u,int v,int val)
{
    edge[e].u=u,edge[e].v=v,edge[e].val=val,edge[e].next=head[u],head[u]=e++;
    edge[e].u=v,edge[e].v=u,edge[e].val=val,edge[e].next=head[v],head[v]=e++;
}
void dfssize(int u,int fa)
{
    int i,v;
    size[u]=1;
    mx[u]=0;
    for(i=head[u];i!=-1;i=edge[i].next)
    {
        v=edge[i].v;
        if(v!=fa&&!vi[v])
        {
            dfssize(v,u);
            size[u]+=size[v];
            if(size[v]>mx[u]) mx[u]=size[v];
        }
    }
}
void dfsroot(int r,int u,int fa)
{
    int v,i;
    if(size[r]-size[u]>mx[u]) mx[u]=size[r]-size[u];
    if(mx[u]<mi) mi=mx[u],root=u;
    for(i=head[u];i!=-1;i=edge[i].next)
    {
        v=edge[i].v;
        if(v!=fa&&!vi[v])
        {
            dfsroot(r,v,u);
        }
    }
}
void dfsdis(int u,int d,int fa)
{
    int i,v;
    dis[++num]=d;
    for(i=head[u];i!=-1;i=edge[i].next)
    {
        v=edge[i].v;
        if(v!=fa&&!vi[v])
        {
            dfsdis(v,d+edge[i].val,u);
        }
    }
}
int calc(int u,int d)
{
    int i,j,ret=0;
    num=0;
    dfsdis(u,d,0);
    sort(dis+1,dis+1+num);
    i=1;j=num;
    while(i<j)//单调求点对
    {
        while(dis[i]+dis[j]>k&&i<j)
            j--;
        ret+=j-i;
        i++;
    }
    return ret;
}
void dfs(int u)
{
    int i,v;
    mi=n;
    dfssize(u,0);
    dfsroot(u,u,0);
    ans+=calc(root,0);
    vi[root]=1;
    for(i=head[root];i!=-1;i=edge[i].next)
    {
        v=edge[i].v;
        if(!vi[v])
        {
            ans-=calc(v,edge[i].val);
            dfs(v);
        }
    }
}
int main()
{
    int i,j,u,v,val;
    while(scanf("%d%d",&n,&k)!=EOF,n||k)
    {
        init();
        for(i=1;i<n;i++)
        {
            scanf("%d%d%d",&u,&v,&val);
            add(u,v,val);
        }
        dfs(1);
        printf("%d\n",ans);
    }
    return 0;
}

 

posted @ 2013-08-14 20:44  fangguo  阅读(133)  评论(0编辑  收藏  举报