2023.08.29T4 - light - solution
light
Problem
给定一棵 \(n\) 个节点的树,并给定每个点的点权 \(a_i\)。
定义一次操作为:
- 选择一个未被删除的节点 \(u\),\(w \leftarrow a_u\),\(\forall v, v \text{ is connected to } u, a_v \leftarrow a_u\),删除节点 \(u\)。
求所有 \(n!\) 种操作方案中,结果 \(w\) 的最大值。
Solution
首先转化一波题意:把原操作转化为 给边定向,则若一个点 \(u\) 能到达 \(c\) 个点,则它对 \(w\) 的贡献为 \(c\times a_u\)。
然后能想到应该是树形 dp,但这个 dp 设计不是我能想出来的。
记 \(dp_{i, j, k}\) 表示考虑点 \(i\) 的子树内的点的贡献,\(i\) 能到达的点有 \(j\) 个(包括子树外的),其中 \(i\) 能到达其子树内的 \(k\) 个点。
初始值 \(dp_{u, i, 1} = a_u\)(此时还没有考虑 \(u\) 向外的贡献),答案 \(\max(f_{1, i, i})\)。以下 \(v\) 为 \(u\) 的子节点:
第一个转移表示 \(u \to v\),第二个转移表示 \(v \to u\),第三个转移是类似 代价提前计算 的思想,即虽然考虑在 \(u\) 的子树内,但当前把 \(u\) 连向外面的贡献也在此时计算了。因为如果你事后再统计贡献,\(u\) 就被压在下面了,并不好处理。注意第三个转移要在把 \(u\) 的所有子节点遍历完后再进行。
该做法为洛谷题解区的做法,时空复杂度均为 \(O(n^3)\)。(时间复杂度的计算形如树形背包)
一道 dp 题调了我很久,根本在于确实是没有完全理解上述 dp 的精髓,而这种 dp 方式也实在是不寻常。我们设计出的 dp 状态中,预先确定了当前点 \(u\) 最终能到达的点的个数,而 dp 转移的实质是在用子树 \(u\) 内的点去 填充 这些可以到达的点。在当前点 \(u\) 的转移 过程中,\(u\) 的 dp 值是 不完全 的,即没有实时算出连向子树外的贡献,而是在所有子节点 \(v\) 向 \(u\) 转移完后再把 \(u\) 的 dp 值更新正确。或者说,这是一种 阶段性 的 dp 转移。
如果你觉得我说了依托答辩,那确实是答辩,因为这个人调了 3h,精神有些失常。
#include<bits/stdc++.h>
#define LL long long
#define DB double
#define MOD 1000000007
#define ls(x) x << 1
#define rs(x) x << 1 | 1
#define lowbit(x) x & (-x)
#define PII pair<int, int>
#define MP make_pair
#define VI vector<int>
#define VII vector<int>::iterator
#define all(x) x.begin(), x.end()
#define EB emplace_back
#define SI set<int>
#define SII set<int>::iterator
#define QI queue<int>
using namespace std;
template<typename T> void chkmn(T &a, const T &b) { (a > b) && (a = b); }
template<typename T> void chkmx(T &a, const T &b) { (a < b) && (a = b); }
int inc(const int &a, const int &b) { return a + b >= MOD ? a + b - MOD : a + b; }
int dec(const int &a, const int &b) { return a - b < 0 ? a - b + MOD : a - b; }
int mul(const int &a, const int &b) { return 1LL * a * b % MOD; }
int sqr(const int &a) { return 1LL * a * a % MOD; }
void Inc(int &a, const int &b) { ((a += b) >= MOD) && (a -= MOD); }
void Dec(int &a, const int &b) { ((a -= b) < 0) && (a += MOD); }
void Mul(int &a, const int &b) { a = 1LL * a * b % MOD; }
void Sqr(int &a) { a = 1LL * a * a % MOD; }
int qwqmi(int x, int k = MOD - 2)
{
int res = 1;
while(k)
{
if(k & 1) Mul(res, x);
Sqr(x), k >>= 1;
}
return res;
}
const int N = 401;
const LL INF = 1e18;
int n, a[N];
vector<int> G[N]; int sz[N];
LL dp[N][N][N], tmp[N][N];
void dfs(int u, int fa)
{
sz[u] = 1;
for(int i = 0; i <= n; ++i)
for(int j = 0; j <= n; ++j)
dp[u][i][j] = -INF;
for(int i = 1; i <= n; ++i)
dp[u][i][1] = a[u];
for(auto v : G[u])
{
if(v == fa)
continue;
dfs(v, u);
for(int i = 1; i <= n; ++i)
for(int j = 1; j <= sz[u] && j <= i; ++j)
{
tmp[i][j] = dp[u][i][j];
dp[u][i][j] = -INF; // !!!
}
for(int i = 1; i <= n; ++i)
for(int j = 1; j <= sz[u] && j <= i; ++j)
for(int x = 1; x <= sz[v] && j + x <= i; ++x)
chkmx(dp[u][i][j + x], tmp[i][j] + dp[v][x][x] + 1LL * a[u] * x);
for(int i = 1; i <= n; ++i)
for(int j = 1; j <= sz[u] && j <= i; ++j)
for(int x = 1; x <= sz[v] && i + x <= n; ++x)
chkmx(dp[u][i][j], tmp[i][j] + dp[v][i + x][x]);
sz[u] += sz[v];
}
for(int i = 1; i <= n; ++i)
for(int j = 1; j <= sz[u] && j <= i; ++j)
dp[u][i][j] += 1LL * a[u] * (i - j);
}
int main()
{
scanf("%d", &n);
for(int i = 1; i <= n; ++i)
scanf("%d", &a[i]);
for(int i = 1; i < n; ++i)
{
int x, y;
scanf("%d %d", &x, &y);
G[x].EB(y), G[y].EB(x);
}
for(int i = 1; i <= n; ++i)
for(int j = 1; j <= n; ++j)
tmp[i][j] = -INF;
dfs(1, 0);
LL ans = -INF;
for(int i = 1; i <= n; ++i)
chkmx(ans, dp[1][i][i]);
printf("%lld\n", ans);
return 0;
}