POJ1741 Tree(树分治——点分治)题解
题意:给一棵树,问你最多能找到几个组合(u,v),使得两点距离不超过k。
思路:点分治,复杂度O(nlogn*logn)。看了半天还是有点模糊。
显然,所有满足要求的组合,连接这两个点,他们必然经过他们的最小公共子树。
代码:
#include<set> #include<map> #include<stack> #include<cmath> #include<queue> #include<vector> #include<cstdio> #include<cstring> #include<iostream> #include<algorithm> typedef long long ll; const int maxn = 10000 + 10; const int seed = 131; const ll MOD = 1e9 + 7; const int INF = 0x3f3f3f3f; using namespace std; struct Edge{ int v, w, next; }edge[maxn << 1]; int dis[maxn], sz[maxn], maxv[maxn]; //到root距离,子树大小(包括自己),最大孩子 int tot, num, ans, n, k, Max, root, head[maxn]; //root重心 bool vis[maxn]; void addEdge(int u, int v, int w){ edge[tot].v = v; edge[tot].w = w; edge[tot].next = head[u]; head[u] = tot++; } //子树大小 void dfs_sz(int u, int pre){ sz[u] = 1; maxv[u] = 0; for(int i = head[u]; i != -1; i = edge[i].next){ int v = edge[i].v; if(v == pre || vis[v]) continue; dfs_sz(v, u); sz[u] += sz[v]; if(maxv[u] < sz[v]) maxv[u] = sz[v]; } } //找以u为根的子树的重心 void dfs_root(int r, int u, int pre){ maxv[u] = max(maxv[u], sz[r] - sz[u]); //sz[r]-sz[u]是u上面部分的树的尺寸,跟u的最大孩子比,找到最大孩子的最小差值节点 if(maxv[u] < Max){ Max = maxv[u]; root = u; } for(int i = head[u]; i != -1; i = edge[i].next){ int v = edge[i].v; if(v == pre || vis[v]) continue; dfs_root(r, v, u); } } //离重心距离 void dfs_dis(int u, int d, int pre){ dis[num++] = d; for(int i = head[u]; i != -1; i = edge[i].next){ int v = edge[i].v; if(v == pre || vis[v]) continue; dfs_dis(v, d + edge[i].w, u); } } //经过u的满足条件的组合的数量 int cal(int u, int d){ int ret = 0; num = 0; dfs_dis(u, d, -1); sort(dis, dis + num); int i = 0, j = num - 1; while(i < j){ while(dis[i] + dis[j] > k && i < j) j--; ret += j - i; //i到i+1~j满足 i++; } return ret; } void dfs(int u){ Max = n; dfs_sz(u, -1); dfs_root(u, u, -1); ans += cal(root, 0); vis[root] = true; for(int i = head[root]; i != -1; i = edge[i].next){ int v = edge[i].v; if(!vis[v]){ ans -= cal(v, edge[i].w); dfs(v); } } } void init(){ tot = ans = 0; memset(head, -1, sizeof(head)); memset(vis, false, sizeof(vis)); } int main(){ while(scanf("%d%d", &n, &k) && n + k){ init(); int u, v, w; for(int i = 0; i < n - 1; i++){ scanf("%d%d%d", &u, &v, &w); addEdge(u, v ,w); addEdge(v, u, w); } dfs(1); printf("%d\n", ans); } return 0; }