WQS 二分
\(\rm WQS\) 二分适用于一类带凸性性质的 \(\rm DP\) 方程的优化,可以快速的求出某一个状态的值。
引入:
- Problem: 有 \(n\) 个物品,恰好选出其中 \(m\) 个物品,最大化价值总和。
这个问题容易通过贪心解决。但是我们现在要考虑 \(\rm DP\) 的做法。若我们令 \(f_{i, j}\) 为考虑前 \(i\) 个物品,选了 \(m\) 个物品时最大价值总和。不难有 \(O(n^2)\) 做法。
接下来是\(\rm WQS\) 二分的优化方法,可以将算法时间复杂度优化到 \(O(n \log V)\)。
我们令 \(f_i\) 为恰好选了 \(i\) 个物品时的最大价值总和。
我们发现若没有 \(m\) 的限制,则该问题是简单的。
如果我们将所有的 \((i, g_i)\) 在平面坐标系上表示出来,不难发现它组成了一个上凸包的样子。此时我们拿一条斜率为 \(\rm slope\) 的直线去切这个凸包,若第一个切到的点为 \(p\),则其他的点都在 \(p\) 的下方。换句话说, \(p\) 点是凸包上的点中使该直线的截距最大的点。
于是我们进而考虑最大化该直线的截距,若我们取的点是 \((p, g_p)\),截距则为 \(c = g_p - slope \times p\)。此时有关键的一步:不妨将所有的物品价值都减掉 \(slope\)。则问题被转化为了没有 \(m\) 的限制的情况。于是不难求出 \(p\),以及对应的 \(g_p\)。
但是我们要求的是 \(m\) 的情况啊,那我们让直线切到的点是 \((m, g_m)\) 不就好了吗。由凸包的美好性质,不难发现当 \(p \le m\) 时应该减小斜率,否则增大斜率。那么这个东西具有单调性,用二分不难在 \((\log V)\) 时间解决。
例题:P5633 最小度限制生成树
显然答案具有下凸性,反证法不难证明。
于是我们二分斜率 \(slope\),让所有跟 \(s\) 相连的边都减掉 \(slope\),此时我们就把 \(k\) 的限制消掉了,做一遍 \(\rm Kruskal\) 即可。
注意有一个特殊的情况:当凸包上有三点共线的情况时,我们可能切不到 \(k\) 所对应的点,而是切到两端的点,于是我们要尽量让选的边数尽量小。
代码注意细节。
code:
qwq
#include<bits/stdc++.h>
#define int long long
#define eq(i) ((edges[i].u == s) || (edges[i].v == s))
using namespace std;
const int N = 5e5 + 10;
struct edge{
int u, v, w;
}edges[N];
int n, m, s, k;
int fa[N], gp, res, cnt, inf = 5e4;
int findfa(int u){return fa[u] = (fa[u] == u) ? u : findfa(fa[u]);}
bool cmp(struct edge e1, struct edge e2){
if(e1.w != e2.w) return e1.w < e2.w;
return (e1.u == s) || (e1.v == s);//注意此处我们让含有 s 的尽量排在前面
}
void calc(int slope){
for(int i = 1; i <= n; i++) fa[i] = i;
for(int i = 1; i <= m; i++) if(eq(i)) edges[i].w -= slope;
sort(edges + 1, edges + m + 1, cmp); cnt = gp = res = 0;
for(int i = 1; i <= m; i++){
int fu = findfa(edges[i].u), fv = findfa(edges[i].v);
if(fu == fv) continue;
cnt++; gp += eq(i); res += edges[i].w;
fa[fu] = fv;
}
for(int i = 1; i <= m; i++) if(eq(i)) edges[i].w += slope;
}
signed main(){
ios::sync_with_stdio(0);
cin.tie(0); cout.tie(0);
cin >> n >> m >> s >> k;
for(int i = 1; i <= m; i++){
int u, v, w; cin >> u >> v >> w;
edges[i] = {u, v, w};
}
calc(inf); if(k > gp){cout << "Impossible"; return 0;}
calc(-inf);if(k < gp){cout << "Impossible"; return 0;}
int l = -inf, r = inf;
while(l + 1 < r){
int mid = (l + r) >> 1; calc(mid);
if(gp < k) l = mid;
else r = mid;
}
calc(r); cout << res + r * k;
return 0;
}