POJ-1741 Tree (树上点分治)
题目大意:一棵带边权无根树,边权代表距离,求距离小于等于k的点对儿数。
题目分析:这两个点之间的路径只有两种可能,要么经过根节点,要么在一棵子树内。定义depth(i)表示点 i 到根节点的距离,belong(i)表示 i 所属的子树。如果路径经过根节点,那么满足depth(i)+depth(j)<=k并且belong(i)<>belong(j)的(i,j)为一个点对儿,如果在子树内,递归到子树即可。
总的过程就变成了这样的:
1、求出所有的depth;
2、求出满足depth(i)+depth(j)<=k并且belong(i)<>belong(j)的点对数;
3、递归到子树;
这道题的实现起来技巧性比较强:
1、在找点对儿(i,j)时,先将所有的depth排好序(快排的复杂度是O(NlogN)),然后就可以用O(N)的复杂度找出满足depth(i)+depth(j)<=k的点对儿数,不过这样找出的点对儿也包含belong(i)=belong(j)的,所以要减掉满足这一部分的点对儿数。
2、递归进行到每一棵子树时,都要以子树的重心为根节点开始进行上述的过程。这是因为要保证无论是什么样的树,都能以O(logN)的时间复杂度完成任务。
这样,总的时间复杂度为O(Nlog2N)。
# include<iostream> # include<cstdio> # include<cstring> # include<vector> # include<queue> # include<list> # include<set> # include<map> # include<string> # include<cmath> # include<cstdlib> # include<algorithm> using namespace std; # define LL long long const int N=1005; const int INF=1000000000; struct Edge { int to,w,nxt; }; Edge e[N*20]; int n,m,cnt,mi; int head[N*10]; int root,ans; int ms[N*10]; int size[N*10]; int depth[N*10]; bool del[N*10]; void add(int u,int v,int w) { e[cnt].to=v; e[cnt].w=w; e[cnt].nxt=head[u]; head[u]=cnt++; } void init() { ans=cnt=0; int a,b,c; memset(head,-1,sizeof(head)); memset(del,false,sizeof(del)); for(int i=1;i<n;++i){ scanf("%d%d%d",&a,&b,&c); add(a,b,c); add(b,a,c); } } void getSize(int u,int fa) { size[u]=1; ms[u]=0; for(int i=head[u];i!=-1;i=e[i].nxt){ int v=e[i].to; if(v==fa||del[v]) continue; getSize(v,u); size[u]+=size[v]; if(size[v]>ms[u]) ms[u]=size[v]; } } void getRoot(int r,int u,int fa) { ms[u]=max(ms[u],size[r]-size[u]); if(mi>ms[u]){ mi=ms[u]; root=u; } for(int i=head[u];i!=-1;i=e[i].nxt){ int v=e[i].to; if(v==fa||del[v]) continue; getRoot(r,v,u); } } void getDep(int u,int dep,int fa) { depth[cnt++]=dep; for(int i=head[u];i!=-1;i=e[i].nxt){ int v=e[i].to; if(v==fa||del[v]) continue; getDep(v,dep+e[i].w,u); } } int cal(int u,int d) { cnt=0; getDep(u,d,-1); sort(depth,depth+cnt); int l=0,r=cnt-1,res=0; while(l<r){ while(l<r&&depth[l]+depth[r]>m) --r; res+=r-l; ++l; } return res; } void dfs(int u) { mi=n; getSize(u,-1); getRoot(u,u,-1); ans+=cal(root,0); del[root]=true; for(int i=head[root];i!=-1;i=e[i].nxt){ int v=e[i].to; if(del[v]) continue; ans-=cal(v,e[i].w); dfs(v); } } void solve() { dfs(1); printf("%d\n",ans); } int main() { while(~scanf("%d%d",&n,&m)&&(n+m)) { init(); solve(); } return 0; }