【ybt金牌导航5-2-2】【luogu P2634】聪聪可可
聪聪可可
题目链接:ybt金牌导航5-2-2 / luogu P2634
题目大意
给你一个树,边有边权,问你有多少个点对,使得它们两个的距离是 3 的倍数。
输出它与所有对数的比值。
思路
容易想到用点分治,找有多少链长度模三余数是 \(0,1,2\) 的,然后 \(0\) 自己跟自己组,\(1\) 跟 \(2\) 组。
然后记得 \((a,b)\) 跟 \((b,a)\) 算不同的。
写完 T 了才发现之前自己的写法找到的不是重心,然后就去改了一下。
现在的这个速度可以,但之前的还是很慢。
代码
#include<cstdio>
#include<iostream>
#define ll long long
using namespace std;
struct node {
int x, to, nxt;
}e[80001];
int n, x, y, z, le[40001], KK, root;
int size[40001], num[5], f[40001], dis[40001];
bool fin[40001];
int ans, sum;
void add(int x, int y, int z) {
e[++KK] = (node){z, y, le[x]}; le[x] = KK;
e[++KK] = (node){z, x, le[y]}; le[y] = KK;
}
void dfs_find_root(int now, int father) {
int maxn = 0;
size[now] = 1;
f[now] = 0;
for (int i = le[now]; i; i = e[i].nxt)
if (!fin[e[i].to] && e[i].to != father) {
dfs_find_root(e[i].to, now);
size[now] += size[e[i].to];
f[now] = max(f[now], size[e[i].to]);
}
f[now] = max(f[now], sum - size[now]);
if (f[now] < f[root]) {
root = now;
}
}
void find_root(int now) {
root = 0;
dfs_find_root(now, 0);
}
void get_dis(int now, int father) {
num[dis[now]]++;
for (int i = le[now]; i; i = e[i].nxt)
if (!fin[e[i].to] && e[i].to != father) {
dis[e[i].to] = (dis[now] + e[i].x) % 3;
get_dis(e[i].to, now);
}
}
int work(int now, int diss) {
num[0] = num[1] = num[2] = 0;
dis[now] = diss;
get_dis(now, 0);
return num[0] * num[0] + num[1] * num[2] * 2;//记得 (a,b) 和 (b,a) 是算两次,而且是可以 (a,a) 算一次的
}
int get_size(int now, int father) {
int re = 1;
for (int i = le[now]; i; i = e[i].nxt)
if (e[i].to != father && !fin[e[i].to])
re += get_size(e[i].to, now);
return re;
}
void slove(int now) {
fin[root] = 1;
ans += work(root, 0);
for (int i = le[root]; i; i = e[i].nxt)
if (!fin[e[i].to]) {
ans -= work(e[i].to, e[i].x);
sum = get_size(e[i].to, 0);
find_root(e[i].to);
slove(root);
}
}
int gcd(int x, int y) {
if (!y) return x;
return gcd(y, x % y);
}
int main() {
scanf("%d", &n);
for (int i = 1; i < n; i++) {
scanf("%d %d %d", &x, &y, &z);
z %= 3;
add(x, y, z);
}
f[0] = 2147483647;
sum = n;
find_root(1);
slove(root);
int x = ans, y = n * n;
int GCD = gcd(x, y);
x /= GCD; y /= GCD;
printf("%d/%d", x, y);
return 0;
}