【题解】P3959 - [NOIP2017 提高组] 宝藏
题目大意
给出一个包含 \(n\) 个结点,\(m\) 条无向边的无向图。现在任意选定一初始结点,拓展 \(n - 1\) 条边求出原图的一棵生成树。假设从初始结点 \(u\) 到达当前结点 \(v\) 需要经过 \(K\) 个结点(包含初始结点但不包含当前结点),且 \(v\) 与上一个结点之间的无向边长度为 \(L\),则拓展这条无向边的代价为 \(L \times K\)。试求出最小的代价之和,数据保证原图连通但可能有重边。
\(1 \leq n \leq 12, 0 \leq m \leq 10^3, v \leq 5 \times 10^5\)
解题思路
观察数据范围,结合题目需要求最优解可知这道题应该用 状压 \(dp\) 来解决。题意为求出原图中的一棵生成树,且我们计算代价的时候需要当前已经经过的结点数量,故而我们应该在状态中表示出拓展到当前结点时已经经过的点数。因为最终的形态是以初始结点为根的生成树,所以经过的结点数量可以对应成在该生成树中的高度。由此有 \(dp_{i, j}\) 表示 当前生成树中最后拓展的若干结点深度 为 \(j\),且 已经拓展到的结点集合 为 \(i\) 的最小代价和。
考虑状态转移方程。对于状态 \(dp_{i, j}\),我们可以从任意一个满足 \(k\) 是 \(j\) 的子集且 \(k\) 可以通过原图中的边合法地拓展到 \(j\) 的状态 \(dp_{i - 1, k}\) 转移过来,设当前拓展的无向边长度为 \(w\),则这次拓展的代价为 \(w \times i\)。
初始时我们可以任意选中初始结点,因此边界条件为 \(dp_{0, (1 << i)} = 0\)
因此可以总结算法流程:
-
读入,对于每个结点集合 \(S\),预处理出它可以通过连接原图中的无向边拓展到的结点集合 \(g_S\)
-
状压 \(dp\)
-
因为树根的深度设为 \(0\),所以答案为全集在 \([0, n)\) 深度中的最小值,即 \(\min\limits_{i = 0}^n dp_{(1 << n) - 1, i}\)
注意实现如果暴力枚举子集会超时,这里有一个枚举子集的技巧:枚举 \(i\) 的子集可以通过 for(int s = i - 1; s; s = (s - 1) & i)
来实现。可以证明这样枚举子集的时间复杂度是 \(O(3^n)\) 的。
算法总时间复杂度为 \(O(3^n \times n^2)\),时间复杂度在于枚举状态。
参考代码
#include <cstdio>
#include <cstring>
#include <algorithm>
#define I_love using
#define Yuezheng namespace
#define Ling std
#define int long long
I_love Yuezheng Ling;
const int maxn = 15;
const int maxs = (1 << maxn);
const int inf = 0x3f3f3f3f;
int n, m;
int g[maxs];
int dis[maxn][maxn], dp[maxs][maxn];
signed main() {
// freopen("P3959_2.in", "r", stdin);
int u, v, w, ans = inf;
memset(dp, 0x3f, sizeof(dp));
memset(dis, 0x3f, sizeof(dis));
scanf("%lld%lld", &n, &m);
for (int i = 1; i <= m; i++) {
scanf("%lld%lld%lld", &u, &v, &w);
u--, v--;
dis[u][v] = dis[v][u] = min(dis[u][v], w);
}
for (int i = 1; i < (1 << n); i++) {
for (int j = 0; j < n; j++) {
if (i & (1 << j)) {
dis[j][j] = 0;
for (int k = 0; k < n; k++) {
if (dis[j][k] != inf) {
g[i] |= (1 << k);
}
}
}
}
}
for (int i = 0; i < n; i++) {
dp[1 << i][0] = 0;
}
for (int i = 2; i < (1 << n); i++) {
for (int j = i - 1; j; j = (j - 1) & i) {
if ((g[j] | i) == g[j]) {
int k = i ^ j, sum = 0, minn;
for (int u = 0; u < n; u++) {
if (k & (1 << u)) {
minn = inf;
for (int v = 0; v < n; v++) {
if (j & (1 << v)) {
minn = min(minn, dis[v][u]);
}
}
sum += minn;
}
}
for (int d = 1; d < n; d++) {
if (dp[j][d - 1] != inf) {
dp[i][d] = min(dp[i][d], dp[j][d - 1] + sum * d);
}
}
}
}
}
for (int i = 0; i < n; i++) {
ans = min(ans, dp[(1 << n) - 1][i]);
}
printf("%lld\n", ans);
return 0;
}