POJ1741:Tree——题解+树分治简要讲解
http://poj.org/problem?id=1741
题目大意:给一棵树,求点对间距离<=k的个数。
————————————————————
以这道题为例记录一下对于树分治的理解。
树分治分为两类,一类是基于点的分治,一类是基于边的分治。
后者与树链剖分很相似,但是一般用不上,这里讲的是前者。
我们一般进行树分治找的点都是这棵树的重心(即子树最大者最小的点),我们每次操作都做与这个点相关的路径,然后删除这个点再重新寻找。
分重心的好处在于我们近似的将树分成了两份,类似于二分,其深度不超过O(logn)(其实有严格证明的,但是我太弱了,不会写)
分完重心的操作大致三种
1.找u,v,其中u,v在重心s的同一棵子树上(这种情况直接忽略,因为看下面的操作我们就能明白我们可以递归的完成这个操作)
2.找u,v,其中u,v在重心s的两棵子树上。
3.找u,查找u到重心s的路径。
我们发现3操作和2操作很相似,我们直接讨论2操作。
显然我们在2操作的路径当中不可避免的要经过s,所以我们从s开始bfs,求出每个点i到s的距离dis[i],我们的路径长度即为dis[u]+dis[v]。
3操作同理只是变成了dis[u]+dis[s],其中dis[s]=0.
这里提供一种简要算法:我们在求完dis之后对我们求的dis排序,这样我们就可以快速的求出点对距离<=k的个数。
但是这样就不可避免的要判重,为什么呢?
废话你这样排不就有可能把1操作的一部分点对先算了一遍,这样明显会导致答案变大。
那怎么办呢?我们对于每一棵子树,再删掉我们通过2操作得到的点对即可。
(现将s删掉,s的儿子dis[u]不变的情况下以u为起点bfs求点对,则这些点对就是在同一棵子树当中被计算的重复的点对,减去即可。)
#include<cmath> #include<cstdio> #include<queue> #include<cctype> #include<cstring> #include<algorithm> using namespace std; typedef long long ll; const int N=10001; inline int read(){ int X=0,w=0; char ch=0; while(!isdigit(ch)) {w|=ch=='-';ch=getchar();} while(isdigit(ch)) X=(X<<3)+(X<<1)+(ch^48),ch=getchar(); return w?-X:X; } struct node{ int w; int to; int nxt; }edge[N*2]; int cnt,n,k,head[N],q[N],dis[N],size[N],son[N],d[N],fa[N]; ll ans; bool vis[N]; void add(int u,int v,int w){ cnt++; edge[cnt].to=v; edge[cnt].w=w; edge[cnt].nxt=head[u]; head[u]=cnt; return; } int calcg(int st){ int r=0,g,maxn=n; q[++r]=st; fa[st]=0; for(int l=1;l<=r;l++){ int u=q[l]; size[u]=1; son[u]=0; for(int i=head[u];i;i=edge[i].nxt){ int v=edge[i].to; if(vis[v]||v==fa[u])continue; fa[v]=u; q[++r]=v; } } for(int l=r;l>=1;l--){ int u=q[l],v=fa[u]; if(r-size[u]>son[u])son[u]=r-size[u]; if(son[u]<maxn)g=u,maxn=son[u]; if(!v)break; size[v]+=size[u]; if(size[u]>son[v])son[v]=size[u]; } return g; } inline ll calc(int st,int L){ int r=0,num=0; q[++r]=st; dis[st]=L; fa[st]=0; for(int l=1;l<=r;l++){ int u=q[l]; d[++num]=dis[u]; for(int i=head[u];i;i=edge[i].nxt){ int v=edge[i].to; int w=edge[i].w; if(vis[v]||v==fa[u])continue; fa[v]=u; dis[v]=dis[u]+w; q[++r]=v; } } ll ecnt=0; sort(d+1,d+num+1); int l1=1,r1=num; while(l1<r1){ if(d[l1]+d[r1]<=k){ ecnt+=r1-l1; l1++; }else r1--; } return ecnt; } void solve(int u){ int g=calcg(u); vis[g]=1; ans+=calc(g,0); for(int i=head[g];i;i=edge[i].nxt){ int v=edge[i].to; int w=edge[i].w; if(!vis[v])ans-=calc(v,w); } for(int i=head[g];i;i=edge[i].nxt){ int v=edge[i].to; if(!vis[v])solve(v); } return; } int main(){ while(scanf("%d%d",&n,&k)!=EOF&&n+k){ cnt=ans=0; memset(head,0,sizeof(head)); memset(vis,0,sizeof(vis)); for(int i=1;i<n;i++){ int u=read(); int v=read(); int w=read(); add(u,v,w); add(v,u,w); } solve(1); printf("%lld\n",ans); } return 0; }