[USACO10MAR] 伟大的奶牛聚集
题目类型:树形\(dp\)
传送门:>Here<
题意:给出一棵有边权树,每个节点有\(c[i]\)个人。现在要求所有人聚集到一个点去,代价为每个人走的距离之和。问选哪个点?
解题思路
暴力做法:枚举聚集点,再\(O(n)\)计算每个点到它的距离,还得用\(lca\)求,复杂度\(O(n^2logn)\)
暴力做法2:我们考虑\(O(n)\)维护一个数组\(t[i]\),表示节点\(i\)的子树内所有人到\(i\)的路程之和。易知根节点的\(t\)值就是聚集到根节点时的答案。转\(n\)次根重新遍历,打擂,复杂度\(O(n^2)\)
其实正解就是对暴力做法的一个改进。暴力做法之所以慢,是因为每到一个新的点\(t\)都要重新计算。没有充分利用历史信息。
我们发现对于所有根节点的子节点,不过是子树外的节点多走了这一条边,子树内的节点少走了这一条边。因此就可以完成\(O(1)\)转移了。$$dp[v] = dp[u] + (TotSize - size[v]) * cost[i] - size[v] * cost[i];$$
反思
这题不太像普通的树形\(dp\),一般的树形\(dp\)根节点的值都由子树转移来。这道题却让子树的值由根节点转移。逆向思维。
我做这道题的盲点在于我一直在考虑\(dp[i]\)表示所有节点到子树\(i\)内的一个节点的最小值。事实上子树内这个概念搞得非常玄也非常难搞,干脆定在\(i\)上有时候是一种更好的思路。如果我能够想到直接定在\(i\)上,也就不难想出转移了。
Code
inf开大,自己很快就调出来了。调试嘛,可能的,老犯的错误也就那么几种。
/*By DennyQi 2018*/
#include <cstdio>
#include <queue>
#include <cstring>
#include <algorithm>
using namespace std;
typedef long long ll;
#define int ll
const int MAXN = 100010;
const int MAXM = 200010;
const int INF = 106110956700000;
inline int Max(const int a, const int b){ return (a > b) ? a : b; }
inline int Min(const int a, const int b){ return (a < b) ? a : b; }
inline int read(){
int x = 0; int w = 1; register char c = getchar();
for(; c ^ '-' && (c < '0' || c > '9'); c = getchar());
if(c == '-') w = -1, c = getchar();
for(; c >= '0' && c <= '9'; c = getchar()) x = (x<<3) + (x<<1) + c - '0'; return x * w;
}
int N,x,y,z,TotSize;
int c[MAXN],t[MAXN],size[MAXN],dp[MAXN];
int first[MAXN],nxt[MAXM],to[MAXM],cost[MAXM],cnt;
inline void add(int u ,int v, int w){
to[++cnt] = v, cost[cnt] = w, nxt[cnt] = first[u], first[u] = cnt;
}
void Dfs(int u, int Fa){
int v;
size[u] = c[u];
for(int i = first[u]; i; i = nxt[i]){
if((v = to[i]) == Fa) continue;
Dfs(v, u);
size[u] += size[v];
t[u] += t[v] + size[v] * cost[i];
}
}
void Dp(int u, int Fa){
int v;
for(int i = first[u]; i; i = nxt[i]){
if((v = to[i]) == Fa) continue;
dp[v] = dp[u] + (TotSize - size[v]) * cost[i] - size[v] * cost[i];
Dp(v, u);
}
}
#undef int
int main(){
#define int ll
N = read();
for(int i = 1; i <= N; ++i){
c[i] = read();
TotSize += c[i];
}
for(int i = 1; i < N; ++i){
x = read(), y = read(), z = read();
add(x, y, z);
add(y, x, z);
}
Dfs(1, 0);
dp[1] = t[1];
Dp(1, 0);
int ans(INF);
for(int i= 1; i <= N; ++i){
ans = Min(ans, dp[i]);
}
printf("%lld", ans);
return 0;
}