BZOJ2152 聪聪可可 点分治入门
点分治基本思路:
①对每个点, 计算经过它的满足题目要求的链的条数, 所有点的答案加起来刚好不重不漏地统计了整棵树中满足条件的点数
②计算方法:先算出一颗树上经过了根(根找重心,树比较平衡)的满足条件的链的数量(可以分治), 然后打上标记(每次递归下去时如果遇到标记就停止, 相当于将树从这个位置断开分成了多棵子树,分别处理这些子树), 减去仅在子树中(即没有过根)的满足条件的链(递归下去)
本题目统计cnt[i]表示现在为止(可以一边递归同时一边更新)到根距离 %3 = i的点数, 则可以构成的链有cnt[0]*(cnt[0]-1)(两不同的点分别到根拼成一条链)+cnt[1]*cnt[2]*2(两点分别到根拼成一条链,两点可以互换)+cnt[0](以根为端点的链) = cnt[0] * cnt[0] + cnt[1] * cnt[2] * 2(可能会有重复, 比如两点分别到根的路径在没到根时就已经合并到一起, 递归下去时会减掉这种情况)
#include<algorithm> #include<cstring> #include<cstdio> using namespace std; const int maxn = 20007, inf = 1e9; int n, tt, head[maxn], next[maxn << 1], to[maxn << 1], len[maxn << 1]; int siz, tot, root, p, size[maxn], dis[maxn], cnt[3]; bool been[maxn];//成为过重心 void addedge(int u, int v, int l) { next[++tt] = head[u]; to[tt] = v; len[tt] = l % 3; head[u] = tt; } int gcd(int x, int y) { if(!y) return x; return gcd(y, x % y); } void DFS(int k, int last) {//找出根(把重心作为根) size[k] = 1; int mxsz = 0; for(int i = head[k]; i; i = next[i]) { if(been[to[i]] || to[i] == last) continue; DFS(to[i], k); size[k] += size[to[i]]; mxsz = max(mxsz, size[to[i]]); } mxsz = max(mxsz, tot - size[k]); //若以该点为根, 所有先于该点DFS到的点构成一颗size = tot - size[k]的子树 if(!root || siz > mxsz) siz = mxsz, root = k;//更新重心(以重心为根) } void To_Root(int k, int last) {//到根的距离 cnt[dis[k] % 3]++; for(int i = head[k]; i; i = next[i]) { if(been[to[i]] || to[i] == last) continue; dis[to[i]] = dis[k] + len[i]; To_Root(to[i], k); } } int Calculate(int k, int l) {//计算 最短路径 经过某棵树中的节点k 的点对 数量, 其中k到该子树的根的距离为l(不是1是L) dis[k] = l; cnt[0] = cnt[1] = cnt[2] = 0; To_Root(k, 0); return cnt[0] * cnt[0] + cnt[1] * cnt[2] * 2; } int total; void Get_Ans(int x) { been[x] = true; total += Calculate(x, 0); for(int i = head[x]; i; i = next[i]) { if(been[to[i]]) continue; total -= Calculate(to[i], len[i]);//减去该树中满足最短路为3的倍数但最短路径不经过x的点对数 root = 0; tot = size[to[i]];//不知道为什么,这里root = 0不写, 上面找中心时if(!root||siz>mxsz)改成if(siz>mxsz)就会WA, 但是另一道点分治的题这么写却能AC DFS(to[i], x); Get_Ans(root); } } int main() { int u, v, l; scanf("%d", &n); for(int i = 1; i < n; i++) { scanf("%d%d%d", &u, &v, &l); addedge(u, v, l); addedge(v, u, l); } tot = n; DFS(1, 0); Get_Ans(root); int GCD = gcd(total, n * n); printf("%d/%d\n", total / GCD, n * n / GCD); return 0; }