清北学堂模拟赛d4t5 b
分析:一眼树形dp题,就是不会写QAQ.树形dp嘛,定义状态肯定有一维是以i为根的子树,其实这道题只需要这一维就可以了.设f[i]为以i为根的子树中的权值和.先处理子树内部的情况,用一个数组son[i]表示以i为根的子树中,i能走到的节点个数,可以利用son数组和当前点的权值来更新f数组.
处理了每个子树内部的情况,接下来就要合并它们,将每一个根节点作为中间点,算一下中间点权值的贡献,利用乘法原理算出有多少对点对经过中间点,乘一下就ok了.
树形dp的基本状态定义要熟记,有些题目子树内部是互相独立的,可以在子树里面单独计算,最后再合并一下.
#include <cstdio> #include <cstring> #include <iostream> #include <algorithm> using namespace std; const int maxn = 300010; int n, a[maxn], head[maxn], to[maxn * 2], nextt[maxn * 2], tot = 1, w[maxn * 2]; long long ans, f[maxn], son[maxn]; void add(int x, int y, int z) { w[tot] = z; to[tot] = y; nextt[tot] = head[x]; head[x] = tot++; } void dfs(int u, int fa, int col) { long long res = 0; f[u] = a[u]; son[u] = 1; bool flag = 1; for (int i = head[u]; i; i = nextt[i]) { int v = to[i]; if (v == fa) continue; dfs(v, u, w[i]); if (col != w[i]) { flag = 0; son[u] += son[v]; f[u] += son[v] * a[u] + f[v]; } res += son[v] * a[u] + f[v]; } ans += res; if (flag) return; for (int i = head[u]; i; i = nextt[i]) { int v1 = to[i]; if (v1 != fa) for (int j = i; j; j = nextt[j]) //防止重复统计,所以j=i而不是j=head[u] { int v2 = to[j]; if (v2 != fa && w[i] != w[j]) ans += son[v1] * f[v2] + son[v2] * f[v1] + a[u] * son[v1] * son[v2]; } } } int main() { scanf("%d", &n); for (int i = 1; i <= n; i++) scanf("%d", &a[i]); for (int i = 1; i < n; i++) { int x, y, z; scanf("%d%d%d", &x, &y, &z); add(x, y, z); add(y, x, z); } dfs(1, 0, 0); printf("%lld\n", ans); return 0; }