poj1741 Tree

题目描述:

给出一棵树,求距离不超过k的点对对数。

题解:

点分治板子题。

对于一棵树,我们可以$O(n)$时间求出其重心。(重心:切开这个点得到一堆树,所有树的最大大小最小)

如果我们将一棵树不断找重心->分开->找重心……,我们可以将这棵树分成若干部分。

这个东西叫点分治。

因此将一棵树不断分开后,每个点都会作一次中心。

如果我们将第一层找到的重心称为一级重心,第二层找到的重心称为二级重心……的话,

对于任意一条树链,经过的最高级重心只能有一个,而且从这个重心搜索一定可以搜到这条树链。

这就是点分治能处理树链问题的原因。

对于本题,我们可以求出以x为重心的树内距离x为i的有多少个点。

然后对于和$<=k$的更新答案。

但是还有一种情况,就是路径a为x->y->……,路径b为x->y->……,而a+b可以更新答案。

这样我们的答案其实是偏大的。

因此我们可以搞一下容斥,将y相同的直接搞掉,这样的话得出的就是答案了。

代码:

#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
#define N 10050
inline int rd()
{
    int f=1,c=0;char ch=getchar();
    while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
    while(ch>='0'&&ch<='9'){c=10*c+ch-'0';ch=getchar();}
    return f*c;
}
int n,k,hed[N],cnt;
struct EG
{
    int to,nxt,v;
}e[2*N];
void ae(int f,int t,int v)
{
    e[++cnt].to = t;
    e[cnt].nxt = hed[f];
    e[cnt].v = v;
    hed[f] = cnt;
}
int rt,sum,mx[N],mrk[N],siz[N],ans;
void get_rt(int u,int fa)
{
    mx[u] = 0;siz[u] = 1;
    for(int j=hed[u];j;j=e[j].nxt)
    {
        int to = e[j].to;
        if(to==fa||mrk[to])continue;
        get_rt(to,u);
        siz[u]+=siz[to];
        if(siz[to]>mx[u])mx[u]=siz[to];
    }
    mx[u]=max(mx[u],sum-siz[u]);
    if(mx[u]<mx[rt])rt=u;
}
int st[N],tl;
void dfs(int u,int fa,int dep)
{
    st[++tl] = dep;
    for(int j=hed[u];j;j=e[j].nxt)
    {
        int to = e[j].to;
        if(to==fa||mrk[to])continue;
        dfs(to,u,dep+e[j].v);
    }
}
int sol(int u,int fa,int dep)
{
    tl=0;
    dfs(u,fa,dep);
    sort(st+1,st+1+tl);
    int l = 1,r = tl,ret = 0;
    while(l<r)
    {
        if(st[l]+st[r]<=k)
        {
            ret+=r-l;
            l++;
        }else r--;
    }
    return ret;
}
void work()
{
    ans+=sol(rt,0,0);
    mrk[rt]=1;
    for(int j=hed[rt];j;j=e[j].nxt)
    {
        int to = e[j].to;
        if(mrk[to])continue;
        ans-=sol(to,rt,e[j].v);
    }
    for(int j=hed[rt];j;j=e[j].nxt)
    {
        int to = e[j].to;
        if(mrk[to])continue;
        rt=0,sum=siz[to];
        get_rt(to,0);
        work();
    }
}
void init()
{
    rt=sum=ans=cnt=0;
    memset(mrk,0,sizeof(mrk));
    memset(hed,0,sizeof(hed));
}
int main()
{
    mx[0] = 0x7fffffff;
    while(1)
    {
        n=rd(),k=rd();
        init();
        if(!n&&!k)break;
        for(int f,t,v,i=1;i<n;i++)
        {
            f = rd(),t = rd(),v = rd();
            ae(f,t,v),ae(t,f,v);
        }
        rt=0;sum=n;
        get_rt(1,0);
        work();
        printf("%d\n",ans);
    }
    return 0;
}

 

posted @ 2018-12-28 12:53  LiGuanlin  阅读(113)  评论(0编辑  收藏  举报