三角果计数
题目描述
给一个n 个节点的树,三角果定义为一个包含 \(3\) 个节点的集合,且他们两两之间的最短路长度 \(a, b, c\) 能够构成一个三角形。
计算这棵树上有多少个不同的三角果。
两个三角果不同当且仅当 \({a_1, b_1, c_1} != {a_2, b_2, c_2}\)
思路
简单画一下会发现当三个点在一条线上时一定会发生两边之和等于第三遍的情况。其他的情况一定是满足条件的.
一种较为容易想到的做法是,用 \(dfs\) 进行树的遍历,枚举到每个节点时我们算出 不包含这个节点的所有方案,即把这个点
\(root\) 当做树根,去计算所有的合法方案。 由于树上两点 \(u,v\) 的最短路径一定是 \(u-lca(u,v)-v\) 因此我们只需要
在 \(root\) 的三个不同的子树任选一个点,即可构成一个合法方案。 因为在某个子树中选择多个节点会在 \(dfs\) 到子树的时候会求出,这样可以避免重复计算。 (并且在此题中边权没有用)
(由于每次以\(root\)为根的计算方案数的时候,三个点的 \(lca\) 均为 \(root\),以此进行区分可以得出不重不漏)
上述问题可以等价为以每个点为根,在其所有子树中选择三个不同的子树,每个子树选择一个节点的方案数之和
即
立方和公式中
因此我们可以转化一下等式
CODE
#include <bits/stdc++.h>
#define rep(i, a, b) for(int i = (a); i <= (b); i ++ )
using namespace std;
typedef long long LL;
typedef pair<int, int> PII ;
template <typename T> void chkmax(T &x, T y) { x = max(x, y); }
template <typename T> void chkmin(T &x, T y) { x = min(x, y); }
const int N = 1e5 + 10, M = 2e5 + 10;
int h[N], ne[M], e[M], idx;
int sz[N];
int n;
void add(int a, int b) {
e[idx] = b, ne[idx] = h[a], h[a] = idx ++;
}
LL dfs(int u, int fa) {
sz[u] = 1;
LL ans = 0, sum1 = 0, sum2 = 0, sum = n - 1;
vector<int> son;
for(int i = h[u]; ~i; i = ne[i]) {
int j = e[i];
if(j == fa) continue;
ans += dfs(j, u);
sz[u] += sz[j];
son.push_back(sz[j]);
}
if(n - sz[u]) {
son.push_back(n - sz[u]);
}
if(son.size() < 3) return ans;
for(auto x: son) {
sum2 += 3LL * x * x * (sum - x);
sum2 += 1LL * x * x * x;
}
ans += (sum * sum * sum - sum1 - sum2) / 6;
return ans;
}
int main() {
scanf("%d", &n);
rep(i, 0, n) h[i] = -1;
rep(i, 1, n - 1) {
int u, v, w; scanf("%d%d%d", &u, &v, &w);
add(u, v); add(v, u);
}
dfs(1, -1);
printf("%lld", ans);
return 0;
}
对此问题的第二种解决方案
我们再对问题进行一下抽象
对于每个点为 \(root\) , 我们有 \(n\) 个子树,每个子树的节点个数是 \(a_i\)
我们一定不能选择重复的子树,因此我们可以枚举 \(a_j\),进行计算,这样我们左边可选方案数就是
\(\sum_{i=1}^{j-1}a_i\) 右边可选方案数就是 \(\sum_{k=j+1}^{n}a_k\)
我们在进行 \(dfs\) 遍历 \(root\) 的所有儿子的时候,可以维护一个前缀计算的和 \(pre\),很显然 \(pre = sz[root]\),\(sz[root]\) 是当前已经枚举过的所有子树的 \(a_i\) 的和,而后缀 \(sam = n - 1 - pre - sz[j]\) , j 是当前节点.
CODE
#include <bits/stdc++.h>
#define rep(i, a, b) for(int i = (a); i <= (b); i ++ )
using namespace std;
typedef long long LL;
typedef pair<int, int> PII ;
template <typename T> void chkmax(T &x, T y) { x = max(x, y); }
template <typename T> void chkmin(T &x, T y) { x = min(x, y); }
const int N = 1e5 + 10, M = 2e5 + 10;
int h[N], ne[M], e[M], idx;
int sz[N];
int n;
LL ans;
void add(int a, int b) {
e[idx] = b, ne[idx] = h[a], h[a] = idx ++;
}
void dfs(int u, int fa) {
for(int i = h[u]; ~i; i = ne[i]) {
int j = e[i];
if(j == fa) continue;
dfs(j, u);
ans += 1LL * sz[u] * sz[j] * (n - sz[u] - sz[j] - 1);
sz[u] += sz[j];
}
sz[u] ++;
}
int main() {
scanf("%d", &n);
rep(i, 0, n) h[i] = -1;
rep(i, 1, n - 1) {
int u, v, w; scanf("%d%d%d", &u, &v, &w);
add(u, v); add(v, u);
}
dfs(1, -1);
printf("%lld", ans);
return 0;
}