BZOJ 2152:聪聪可可(树上点分治)
题意
中文题意。
思路
和上一题类似,只不过cal()函数需要发生变化。
题目中要求是3的倍数,那么可以想到 (a + b) % 3 == 0
和 (a % 3 + b % 3) % 3 == 0
是一样的,因此,我们只要在每次计算路径长度的时候,把 dis[u]%3
放在一个桶里面,然后就可以转化为,一个简单的计数问题了。
tong[0]
对于答案的贡献:就像题目中一共有n^2个点对一样,一开始包括根结点本身1个点,有多少条路径,就有多少个点,因此是 tong[0]^2
。
tong[1] 和 tong[2]
对于答案的贡献:每个长度为1的路径,都可以和每个长度为2的路径匹配,而且因为是点对,(2,3)和(3,2)算两种,所以乘2。
#include <bits/stdc++.h>
using namespace std;
const int N = 2e4 + 10;
const int INF = 0x3f3f3f3f;
struct Edge {
int v, nxt, w;
} edge[N*2];
int dis[N], son[N], f[N], vis[N], tot, head[N], tong[3], root, sum, ans;
void Add(int u, int v, int w) {
edge[tot] = (Edge) { v, head[u], w }; head[u] = tot++;
edge[tot] = (Edge) { u, head[v], w }; head[v] = tot++;
}
void getroot(int u, int fa) {
son[u] = 1; f[u] = 0;
for(int i = head[u]; ~i; i = edge[i].nxt) {
int v = edge[i].v;
if(vis[v] || v == fa) continue;
getroot(v, u);
son[u] += son[v];
f[u] = max(f[u], son[v]);
}
f[u] = max(f[u], sum - son[u]);
if(f[u] < f[root]) root = u;
}
void getdeep(int u, int fa) {
tong[dis[u]]++;
for(int i = head[u]; ~i; i = edge[i].nxt) {
int v = edge[i].v, w = edge[i].w;
if(vis[v] || v == fa) continue;
dis[v] = (dis[u] + w) % 3;
getdeep(v, u);
}
}
int cal(int u, int now) {
dis[u] = now;
memset(tong, 0, sizeof(tong));
getdeep(u, 0);
// printf("tong %d : %d, %d, %d\n\n", u, tong[0], tong[1], tong[2]);
// 就像题目中一共有n^2个点对一样,一开始包括根结点本身1个点,有多少条路径,就有多少个点,因此是tong[0]^2
int res1 = tong[0] * tong[0];
// 对于每个长度为1的路径,都可以和每个长度为2的路径匹配,而且因为是点对,(2,3)和(3,2)算两种,所以乘2
int res2 = tong[1] * tong[2] * 2;
return res1 + res2;
}
int work(int u) {
// int now = cal(u, 0);
ans += cal(u, 0);
// ans += now;
// printf("work %d : %d\n\n", u, now);
vis[u] = 1;
for(int i = head[u]; ~i; i = edge[i].nxt) {
int v = edge[i].v, w = edge[i].w;
if(vis[v]) continue;
int now = cal(v, w);
// printf("delete %d -> %d : %d\n\n", u, v, cal(v, w));
ans -= cal(v, w);
sum = son[v];
getroot(v, root = 0);
// printf("root : %d\n\n", root);
work(root);
}
}
int main() {
int n;
while(~scanf("%d", &n)) {
memset(head, -1, sizeof(head)); tot = 0;
for(int i = 1; i < n; i++) {
int u, v, w;
scanf("%d%d%d", &u, &v, &w);
Add(u, v, w % 3);
}
sum = n, ans = root = 0, f[0] = INF;
getroot(1, root);
work(root);
// ans += n;
// printf("ans : %d\n", ans);
int tol = n * n;
int g = __gcd(ans, tol);
printf("%d/%d\n", ans / g, tol / g);
}
return 0;
}
/*
5
1 2 1
1 3 2
1 4 1
2 5 3
8
1 2 1
2 5 3
1 4 1
4 6 2
1 3 2
3 7 2
7 8 3
*/