洛谷P1364 医院设置 题解 搜索/树形DP(暂时存在一点问题)
题目链接:https://www.luogu.com.cn/problem/P1364
因为题目数据比较小所有可以用搜索解决。
实现代码如下:
#include <bits/stdc++.h>
using namespace std;
const int maxn = 110;
vector<int> g[maxn];
int n, v[maxn], ans = INT_MAX;
int dfs(int u, int p, int d) {
int res = v[u] * d;
int sz = g[u].size();
for (int i = 0; i < sz; i ++) {
int v = g[u][i];
if (v == p) continue;
res += dfs(v, u, d+1);
}
return res;
}
int main() {
cin >> n;
for (int i = 1; i <= n; i ++) {
int a, b;
cin >> v[i] >> a >> b;
if (a) {
g[i].push_back(a);
g[a].push_back(i);
}
if (b) {
g[i].push_back(b);
g[b].push_back(i);
}
}
for (int i = 1; i <= n; i ++) ans = min(ans, dfs(i, -1, 0));
cout << ans << endl;
return 0;
}
时间复杂度 \(O(n^2)\)。
但是其实这道题目也可以用树形DP来解决。
我们这里设:
- \(v[i]\) 表示点 \(i\) 的人数;
- \(s[i]\) 表示以点 \(i\) 为根节点的总人数;
- \(f[i]\) 表示以点 \(i\) 为根节点的子树中的所有节点到点 \(i\) 的距离之和;
- \(g[i]\) 表示所有节点到点 \(i\) 的距离之和。
则,可以推导状态转移方程:
\[s[u] = v[u] + \sum_v {s[v]}
\]
\[f[u] = \sum_v f[v] + s[v]
\]
其中, \(v\) 是 \(u\) 的任意子节点。
另外,我们这里是不把他当做一棵二叉树的,就当做一棵普通的树,并且我假设点 \(1\) 就是根节点。
那么此时:
当 \(u = 1\) 时, \(f[u]\) 就是所有节点到点 \(1\) 的距离,即:
\[g[u] = f[u]
\]
当 \(u \ne 1\) 时,(我们设 \(u\) 的父节点为 \(p\))所有节点到点 \(u\) 的距离应该是
\[f[u] + g[p] - f[u] - s[u] + s[p] - s[u] = g[p] + s[p] - 2 \times s[u]
\]
咦,好像没算对。有大神算出来了麻烦留言一下,我算错了好像(但是我又懒不想算了囧~)
错误实现代码如下(有时间再修正):
#include <bits/stdc++.h>
using namespace std;
const int maxn = 110;
vector<int> G[maxn];
int n, v[maxn], p[maxn], s[maxn], f[maxn], g[maxn], ans = INT_MAX;
void dfs(int u) {
s[u] = v[u];
int sz = G[u].size();
for (int i = 0; i < sz; i ++) {
int v = G[u][i];
if (v == p[u]) continue;
p[v] = u;
dfs(v);
s[u] += s[v];
f[u] += f[v] + s[v];
}
}
void dfs2(int u) { // 求g[u]
if (u == 1) g[u] = f[u];
else g[u] = g[p[u]] + s[p[u]] - 2 * s[u];
int sz = G[u].size();
for (int i = 0; i < sz; i ++) {
int v = G[u][i];
if (v == p[u]) continue;
dfs2(v);
}
}
int main() {
cin >> n;
for (int i = 1; i <= n; i ++) {
int a, b;
cin >> v[i] >> a >> b;
if (a) {
G[i].push_back(a);
G[a].push_back(i);
}
if (b) {
G[i].push_back(b);
G[b].push_back(i);
}
}
dfs(1);
dfs2(1);
for (int u = 1; u <= n; u ++) {
ans = min(ans, g[u]);
}
cout << ans << endl;
return 0;
}