[POJ1741]Tree
题目大意:
给你一棵带权树,求出树中距离$\leq k$的点对个数。
思路:
运用树上分治的思想,每次找出树的重心,考虑以下三种情况:
1.两个结点在不同子树内,且距离$\leq k$,则算入答案中;
2.两个结点距离$\leq k$,但属于同一棵子树中,需要被算入答案中,但考虑到以后会被子树的重心重新计算,故在这里忽略;
3.两个结点距离$>k$,显然需要忽略。
简而言之,就是每次统计最短路径经过重心的、距离$\leq k$的点对个数。
我们可以每次先DP求出这棵子树的重心,再以这个重心为根,遍历整棵子树,将得到的离重心的距离存入一个数组中,然后枚举子树中的每个点对,将距离和$\leq k$的计入答案。
考虑到要去除同一棵子树中的点对,我们需要在记录距离的同时,要记录每个结点所属的子树的编号,然后枚举的时候判断两个点是否属于统一子树即可。
但是这样还是会TLE。
考虑每次使用两个数组$a$和$b$,$a$存储当前子树各个结点的距离,$b$存储之前所有结点的距离,统计答案时只需要对$a$排序,然后枚举$b$中每个元素$b_i$,在$a$中二分查找小于等于$k-b_i$的元素个数即可。
树上分治是$O(\log n)$的,排序是$O(n\log n)$的,因此总的时间复杂度是$O(n\log^2 n)$。
1 #include<cstdio> 2 #include<cctype> 3 #include<vector> 4 #include<cstring> 5 #include<algorithm> 6 inline int getint() { 7 char ch; 8 while(!isdigit(ch=getchar())); 9 int x=ch^'0'; 10 while(isdigit(ch=getchar())) x=(((x<<2)+x)<<1)+(ch^'0'); 11 return x; 12 } 13 const int inf=0x7fffffff; 14 const int V=10001; 15 struct Edge { 16 int to,w; 17 Edge(const int to,const int w) { 18 this->to=to; 19 this->w=w; 20 } 21 }; 22 std::vector<Edge> e[V]; 23 inline void add_edge(const int u,const int v,const int w) { 24 e[u].push_back(Edge(v,w)); 25 } 26 int size[V]; 27 bool vis[V]; 28 int min_subtree_size,centroid,tree_size; 29 void get_centroid(const int x,const int par) { 30 size[x]=1; 31 int max=0; 32 for(unsigned i=0;i<e[x].size();i++) { 33 int &y=e[x][i].to; 34 if(vis[y]||y==par) continue; 35 get_centroid(y,x); 36 size[x]+=size[y]; 37 max=std::max(max,size[y]); 38 } 39 max=std::max(max,tree_size-size[x]); 40 if(max<min_subtree_size) { 41 min_subtree_size=max; 42 centroid=x; 43 } 44 } 45 int k; 46 std::vector<int> dis,tdis; 47 void get_dist(const int x,const int par,const int d) { 48 if(d<=k) tdis.push_back(d); 49 size[x]=1; 50 for(unsigned i=0;i<e[x].size();i++) { 51 int &y=e[x][i].to; 52 if(vis[y]||y==par) continue; 53 get_dist(y,x,d+e[x][i].w); 54 size[x]+=size[y]; 55 } 56 } 57 int ans=0; 58 inline void solve(const int x,const int sz) { 59 tree_size=sz; 60 min_subtree_size=inf; 61 get_centroid(x,0); 62 vis[centroid]=true; 63 /*dis.clear(); 64 dis.push_back(Vertex(0,centroid)); 65 for(unsigned i=0;i<e[centroid].size();i++) { 66 int &y=e[centroid][i].to; 67 if(vis[y]) continue; 68 get_dist(y,centroid,e[centroid][i].w,y); 69 } 70 std::sort(dis.begin(),dis.end()); 71 for(unsigned i=0;i<dis.size();i++) { 72 for(unsigned j=i+1;j<dis.size();j++) { 73 if(dis[i].d+dis[j].d>k) break; 74 if(dis[i].root!=dis[j].root) ans++; 75 } 76 }*/ 77 dis.clear(); 78 dis.push_back(0); 79 for(unsigned i=0;i<e[centroid].size();i++) { 80 int &y=e[centroid][i].to; 81 if(vis[y]) continue; 82 tdis.clear(); 83 get_dist(y,centroid,e[centroid][i].w); 84 std::sort(tdis.begin(),tdis.end()); 85 for(unsigned i=0;i<dis.size();i++) { 86 ans+=std::upper_bound(tdis.begin(),tdis.end(),k-dis[i])-tdis.begin(); 87 } 88 dis.insert(dis.end(),tdis.begin(),tdis.end()); 89 } 90 int cur=centroid; 91 for(unsigned i=0;i<e[cur].size();i++) { 92 int &y=e[cur][i].to; 93 if(vis[y]) continue; 94 solve(y,size[y]); 95 } 96 } 97 inline void init() { 98 ans=0; 99 memset(vis,0,sizeof vis); 100 for(int i=0;i<=V;i++) e[i].clear(); 101 } 102 int main() { 103 for(;;) { 104 int n=getint(); 105 k=getint(); 106 if(!n&&!k) return 0; 107 init(); 108 for(int i=1;i<n;i++) { 109 int u=getint(),v=getint(),w=getint(); 110 add_edge(u,v,w); 111 add_edge(v,u,w); 112 } 113 solve(1,n); 114 printf("%d\n",ans); 115 } 116 }