P6748 Fallen Lord [树形DP]
Description
给定 \(n\) 个节点的树,每个点有点权 \(a_i\),求构造一组边权,使得每个点连接的边的边权的中位数不超过其点权,且每条边权不超过给定的 \(m\),输出边权之和的最大值。
一个升序序列 \(A=\{A_1,A_2,A_3...A_n\}\) 的中位数定义为 \(A_{\lfloor n/2\rfloor +1}\)。
\(n\le 5\times 10^5\),\(a_i\le m\le 10^9\)。
Solution
一个序列的中位数不超过 \(x\),相当于这个序列至少有 \(\lfloor n/2\rfloor+1\) 个元素不超过 \(x\)。
根据以上性质,可知一个边的边权有大于和不大于两种状态,且这个边权只和 两边的端点 以及 这些端点的其他边 有关,那么我们尝试去表达这种状态。
又根据题面,我们尝试考虑树形DP。
设 \(f(u,0/1)\) 表示以 \(u\) 为根的子树的答案,且边 \((u,fa_u)\) 的边权 不大于或大于 \(a_u\),
\(g(u,0/1)\) 表示以 \(u\) 为根的子树 加上 边 \((u,fa_u)\) 的答案,且边 \((u,fa_u)\) 的边权 不大于或大于 \(a_{fa_u}\),
至此,我们表示出了点与边的关系,然后考虑如何转移。
设 \(d_u\) 表示点 \(u\) 的度数,\(x=\lfloor d_u/2\rfloor+1\),\(v\in son_u\),
-
\(\sum g(v,0/1)\to f(u,0)\) :
- 由于 \((u,fa_u)\) 已经不超过 \(a_u\) 了,那么还需要 \(x-1\) 条边不超过 \(a_u\) 才能满足条件,也就是最多可以有 \(d_u-x\) 条边大于 \(a_u\),那么我们只需要从所有 \(g(u,0/1)\) 中选出一个最优的满足这个条件的方案即可。考虑按照 \(g(v,1)-g(v,0)\) 降序排序,贪心的选前面不超过 \(d_u-x\) 个且差值大于 \(0\) 的 \(g(v,1)\) 转移到 \(f(u,0)\),后面的全部选 \(g(v,0)\) 即可。
-
\(\sum g(v,0/1)\to f(u,1)\) 类似。
-
对于 \(g(u,0/1)\) 我们分类讨论 :
- \(a_u > a_{fa_u}\) :
- \(g(u,0)\) :\(a_u>a_{fa_u}\ge (u,fa_u)\),\(\to\ = f(u,0)+a_{fa_u}\)。
- \(g(u,1)\) :\(a_u\ge(u,fa_u)>{fa_u}\) 或 \(m\ge (u,fa_u) > a_u > a_{fa_u}\),\(\to\ =\max\{f(u,0)+a_u,\ f(u,1)+m\}\)。
- \(a_u\le a_{fa_u}\) :
- \(g(u,0)\) :\((u,fa_u)\le a_u\le a_{fa_u}\) 或 \(a_u < (u,fa_u)\le a_{fa_u}\),\(\to\ =\max\{f(u,0)+a_u,\ f(u,1)+a_{fa_u}\}\)。
- \(g(u,1)\) :\(a_u\le a_{fa_u}<(u,fa_u)\le m\),\(\to\ =f(u,1)+m\)。
- \(a_u > a_{fa_u}\) :
-
特殊情况:当 \(d_u\le 2\) 时,显然 \(f(u,1)\) 无法转移,设为 \(-\inf\) 即可。
-
答案:\(f(1,0)\)。
于是,这道题就做完了。
Code
const int N = 5e5 + 5;
const ll inf = 1e18;
int n;
ll m, a[N];
int deg[N];
vector <int> e[N];
void adde(int u, int v) {e[u].ps(v);}
ll f[N][2], g[N][2];
void dfs(int u, int fa){
vector <pair<ll, int> > tmp;
for(int v : e[u]){
if(v == fa) continue;
dfs(v, u);
tmp.ps(mk(g[v][1] - g[v][0], v));
}
sort(tmp.begin(), tmp.end(), greater <pair<ll, int> > ());
int k = deg[u] - deg[u] / 2 - 1;
for(int i = 0; i < tmp.size(); i++){
auto [w, v] = tmp[i];
if(i < k && w > 0) f[u][0] += g[v][1];
else f[u][0] += g[v][0];
if(i < k - 1 && w > 0) f[u][1] += g[v][1];
else f[u][1] += g[v][0];
}
if(deg[u] <= 2) f[u][1] = -inf;
if(a[u] > a[fa]){
g[u][0] = f[u][0] + a[fa];
g[u][1] = max(f[u][0] + a[u], f[u][1] + m);
}else{
g[u][0] = max(f[u][0] + a[u], f[u][1] + a[fa]);
g[u][1] = f[u][1] + m;
}
}
void Solve(){
cin >> n >> m;
for(int i = 1; i <= n; i++) cin >> a[i];
int u, v;
for(int i = 1; i <= n - 1; i++){
cin >> u >> v;
deg[u]++;
deg[v]++;
adde(u, v);
adde(v, u);
}
dfs(1, 0);
cout << f[1][0] << endl;
}