POJ1741:Tree——题解+树分治简要讲解

http://poj.org/problem?id=1741

题目大意:给一棵树,求点对间距离<=k的个数。

————————————————————

以这道题为例记录一下对于树分治的理解。

树分治分为两类,一类是基于点的分治,一类是基于边的分治。

后者与树链剖分很相似,但是一般用不上,这里讲的是前者。

我们一般进行树分治找的点都是这棵树的重心(即子树最大者最小的点),我们每次操作都做与这个点相关的路径,然后删除这个点再重新寻找。

分重心的好处在于我们近似的将树分成了两份,类似于二分,其深度不超过O(logn)(其实有严格证明的,但是我太弱了,不会写)

分完重心的操作大致三种

1.找u,v,其中u,v在重心s的同一棵子树上(这种情况直接忽略,因为看下面的操作我们就能明白我们可以递归的完成这个操作)

2.找u,v,其中u,v在重心s的两棵子树上。

3.找u,查找u到重心s的路径。

我们发现3操作和2操作很相似,我们直接讨论2操作。

显然我们在2操作的路径当中不可避免的要经过s,所以我们从s开始bfs,求出每个点i到s的距离dis[i],我们的路径长度即为dis[u]+dis[v]。

3操作同理只是变成了dis[u]+dis[s],其中dis[s]=0.

这里提供一种简要算法:我们在求完dis之后对我们求的dis排序,这样我们就可以快速的求出点对距离<=k的个数。

但是这样就不可避免的要判重,为什么呢?

废话你这样排不就有可能把1操作的一部分点对先算了一遍,这样明显会导致答案变大。

那怎么办呢?我们对于每一棵子树,再删掉我们通过2操作得到的点对即可。

(现将s删掉,s的儿子dis[u]不变的情况下以u为起点bfs求点对,则这些点对就是在同一棵子树当中被计算的重复的点对,减去即可。)

#include<cmath>
#include<cstdio>
#include<queue>
#include<cctype>
#include<cstring>
#include<algorithm>
using namespace std;
typedef long long ll;
const int N=10001;
inline int read(){
    int X=0,w=0; char ch=0;
    while(!isdigit(ch)) {w|=ch=='-';ch=getchar();}
    while(isdigit(ch)) X=(X<<3)+(X<<1)+(ch^48),ch=getchar();
    return w?-X:X;
}
struct node{
    int w;
    int to;
    int nxt;
}edge[N*2];
int cnt,n,k,head[N],q[N],dis[N],size[N],son[N],d[N],fa[N];
ll ans;
bool vis[N];
void add(int u,int v,int w){
    cnt++;
    edge[cnt].to=v;
    edge[cnt].w=w;
    edge[cnt].nxt=head[u];
    head[u]=cnt;
    return;
}
int calcg(int st){
    int r=0,g,maxn=n;
    q[++r]=st;
    fa[st]=0;
    for(int l=1;l<=r;l++){
    int u=q[l];
    size[u]=1;
    son[u]=0;
    for(int i=head[u];i;i=edge[i].nxt){
        int v=edge[i].to;
        if(vis[v]||v==fa[u])continue;
        fa[v]=u;
        q[++r]=v;
    }
    }
    for(int l=r;l>=1;l--){
    int u=q[l],v=fa[u];
    if(r-size[u]>son[u])son[u]=r-size[u];
    if(son[u]<maxn)g=u,maxn=son[u];
    if(!v)break;
    size[v]+=size[u];
    if(size[u]>son[v])son[v]=size[u];
    }
    return g;
}
inline ll calc(int st,int L){
    int r=0,num=0;
    q[++r]=st;
    dis[st]=L;
    fa[st]=0;
    for(int l=1;l<=r;l++){
    int u=q[l];
    d[++num]=dis[u];
    for(int i=head[u];i;i=edge[i].nxt){
        int v=edge[i].to;
        int w=edge[i].w;
        if(vis[v]||v==fa[u])continue;
        fa[v]=u;
        dis[v]=dis[u]+w;
        q[++r]=v;
    }
    }
    ll ecnt=0;
    sort(d+1,d+num+1);
    int l1=1,r1=num;
    while(l1<r1){
    if(d[l1]+d[r1]<=k){
        ecnt+=r1-l1;
        l1++;
    }else r1--;
    }
    return ecnt;
}
void solve(int u){
    int g=calcg(u);
    vis[g]=1;
    ans+=calc(g,0);
    for(int i=head[g];i;i=edge[i].nxt){
    int v=edge[i].to;
    int w=edge[i].w;
    if(!vis[v])ans-=calc(v,w);
    }
    for(int i=head[g];i;i=edge[i].nxt){
    int v=edge[i].to;
    if(!vis[v])solve(v);
    }
    return;
}
int main(){
    while(scanf("%d%d",&n,&k)!=EOF&&n+k){
    cnt=ans=0;
    memset(head,0,sizeof(head));
    memset(vis,0,sizeof(vis));
    for(int i=1;i<n;i++){
        int u=read();
        int v=read();
        int w=read();
        add(u,v,w);
        add(v,u,w);
    }
    solve(1);
    printf("%lld\n",ans);
    }
    return 0;
}

 

posted @ 2017-12-14 07:59  luyouqi233  阅读(206)  评论(0编辑  收藏  举报