[Bzoj2152]聪聪可可
题目链接:https://www.lydsy.com/JudgeOnline/problem.php?id=2152
问题为求树上有多少个点对的路径边权和为3的倍数。由于只关心余数,先将每条边权对3取余,然后进行点分治。对每个分治中心,统计其所在连通块内到中心的路径边权和模3分别为0、1、2的节点个数,分别记为O[0]、O[1]、O[2]。
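则两端点到中心的距离之和为3的倍数的有序点对数(包含 $u=v$ 的情况)为 $O[0]^2 + 2\cdot O[1]\cdot O[2]$。其中两端点位于中心同一棵子树内的点对,其真实路径并不经过中心,因此还需对中心的每个子节点 $y$,以边权 $w(x,y)$ 作为初始距离做同样的统计并减去(容斥)。最终答案为该计数除以 $n^2$ 后约分得到的分数。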
// BZOJ 2152 (Congcong and Keke): count ordered vertex pairs (u, v), u == v
// allowed, whose tree-path weight is a multiple of 3, and print that count
// over n^2 as a reduced fraction. Solved with centroid decomposition; edge
// weights are stored modulo 3 because only the residue matters.
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <iostream>
using namespace std;
typedef long long ll;

const int maxn = 2e5 + 10;  // capacity for nodes and for edges (2 per tree edge)
const int inf = 1e9 + 7;

// Array-backed adjacency lists: head[v] -> edge index -> edge[i].next -> ...
struct node {
    int e, w, next;
} edge[maxn];
int head[maxn], len;

// Clears all adjacency lists.
void init() {
    memset(head, -1, sizeof(head));
    len = 0;
}

// Appends the directed edge s -> e with weight w.
void add(int s, int e, int w) {
    edge[len].e = e;
    edge[len].w = w;
    edge[len].next = head[s];
    head[s] = len++;
}

// Greatest common divisor; 64-bit so that n * n cannot overflow.
ll gcd(ll a, ll b) {
    return b == 0 ? a : gcd(b, a % b);
}

int n;
int root;       // centroid found by the latest getroot() call
int sum;        // size of the component getroot() is currently searching
ll ans;         // ordered pairs whose path weight is divisible by 3
int d[maxn];    // d[v]: path weight (mod 3) from the current origin to v
int vis[maxn];  // vis[v] = 1 once v has been used as a centroid
int f[maxn];    // f[v]: largest remaining piece if v were removed
int son[maxn];  // son[v]: subtree size of v in the current getroot() DFS
ll o[3];        // o[r]: number of nodes with d[v] == r seen by the last getd()

// Finds the centroid of the unvisited component containing x (whose size
// must be in `sum`) and stores it in `root`. Relies on f[0] == inf so that
// any real node beats the placeholder root 0.
void getroot(int x, int fa) {
    son[x] = 1;
    f[x] = 0;
    for (int i = head[x]; i != -1; i = edge[i].next) {
        int y = edge[i].e;
        if (y == fa || vis[y]) continue;
        getroot(y, x);
        son[x] += son[y];
        f[x] = max(f[x], son[y]);
    }
    f[x] = max(f[x], sum - son[x]);
    if (f[x] < f[root]) root = x;
}

// Walks the unvisited component below x, propagating d[] along edges and
// tallying each node's residue into o[].
void getd(int x, int fa) {
    o[d[x]]++;
    for (int i = head[x]; i != -1; i = edge[i].next) {
        int y = edge[i].e;
        if (y == fa || vis[y]) continue;
        d[y] = (d[x] + edge[i].w) % 3;
        getd(y, x);
    }
}

// Counts ordered pairs (u, v), u == v allowed, in x's unvisited component
// with (d[u] + d[v]) % 3 == 0, starting from d[x] = val. On return o[] holds
// the residue tallies, so o[0] + o[1] + o[2] is that component's size.
ll cal(int x, int val) {
    o[0] = o[1] = o[2] = 0;
    d[x] = val;
    getd(x, 0);
    return o[0] * o[0] + o[1] * o[2] * 2;
}

// Adds every pair whose path passes through centroid x, then recurses into
// each component left after removing x.
void solve(int x) {
    ans += cal(x, 0);
    vis[x] = 1;
    for (int i = head[x]; i != -1; i = edge[i].next) {
        int y = edge[i].e;
        if (vis[y]) continue;
        // Pairs inside y's subtree were counted above as if their path went
        // through x; remove them (inclusion-exclusion).
        ans -= cal(y, edge[i].w);
        // cal() just visited exactly y's component, so its tally is the exact
        // size. son[y] is not reliable here: it came from a DFS rooted
        // wherever the previous getroot() started, not at x.
        sum = (int)(o[0] + o[1] + o[2]);
        root = 0;
        getroot(y, 0);
        solve(root);
    }
}

int main() {
    if (scanf("%d", &n) != 1 || n <= 0) return 0;
    init();
    root = 0;
    ans = 0;
    memset(vis, 0, sizeof(vis));
    for (int i = 1; i < n; i++) {
        int x, y, z;
        if (scanf("%d%d%d", &x, &y, &z) != 3) return 0;
        z = ((z % 3) + 3) % 3;  // keep the residue in 0..2 even for negative input
        add(x, y, z);
        add(y, x, z);
    }
    f[0] = inf;
    sum = n;
    getroot(1, 0);
    solve(root);
    ll mu = (ll)n * n;
    ll zi = ans;
    ll k = gcd(mu, zi);
    printf("%lld/%lld\n", zi / k, mu / k);
    return 0;
}