wqs二分
wqs二分
用来处理一类带有限制的问题,如恰好选 \(k\) 个,本质是通过二分来规避这个选取数量的限制。
使用前提:原问题具有凹凸性。设 \(g_i\) 表示选 \(i\) 个物品的答案,那么所有 \((i, g_i)\) 点组成一个凸包,满足 \(g'(k)\) 单调。
这类题目通常有以下特点:
- 如果不限制选的个数,那么很容易求出最优方案。
- 权值随选的物品增加而单调。
实现
先不考虑恰好选 \(k\) 个的限制,则 DP 的复杂度会降低。
考虑二分一个 \(mid\) ,表示一次选取附加的权值。则选取的次数越多,附加权值越大,选的就会越多/越少,根据选的数量来调整 \(mid\) ,最后调整到恰好选 \(k\) 个时减掉附加权值即为答案。
本质就是二分斜率使得凸包切线切点在 \(x = k\) 处,检查函数返回的是截距。
细节
如果斜率为整数时,答案的切线会同时切多个点,此时就会出现 \(mid\) 时切到的 \(x\) 坐标小于 \(k\) ,而 \(mid + 1\) 时切到的 \(x\) 坐标大于 \(k\) 。
注意到一个点对应的切线斜率是一段区间,而且我们只要保证这个斜率会切到 \(k\) 即可。由于 \(g(k + 1) - g(k)\) 为整数,所以这个斜率一定存在。
实现时每次判定时能选就选,这样就会切到最右边的点,此时为了让斜线的斜率逼近 \(g'(k)\) ,于是就要根据凸壳形状减小(下凸壳)或增大(上凸壳)切线斜率,并更新答案。
构造解
对于 wqs 二分能处理的大部分问题,其 DP 过程中的答案仍然是凸的。
如对于序列分段问题,DP 的就是每一个前缀的切点,每个切点的坐标是由前面的切点转移过来的。
所以对于每一个 DP 的子问题,我们都记录切线所切的区间。这样我们可以从最后开始,每次尝试找到上一个可行的转移点满足 \(k\) 要落在其切点的区间内。
不妨用 \([l_i, r_i]\) 记录第 \(i\) 个凸包的切线区间,那么我们只需要找到 \(k \in [l_j, r_j]\) 并且可以转移当前位置的 \(j\) 即可。
应用
求 \(deg_s = k\) 的情况下的最小生成树。
\(n \leq 5 \times 10^4, m \leq 5 \times 10^5, k \leq 100\)
二分一个 \(mid\) ,与 \(s\) 连接的边的边权都加上这个 \(mid\) ,则最终MST中与 \(s\) 连接的边的数量受 \(mid\) 影响,于是可以调整使得恰好 \(deg_s =k\) 。
时间复杂度 \(O(m \log m \log V)\) 。
P2619 [国家集训队] Tree I 做法也是类似的。
#include <bits/stdc++.h>
typedef long long ll;
using namespace std;
const int inf = 0x3f3f3f3f;
const int N = 5e4 + 7, M = 5e5 + 7;
struct Edge {
int u, v, w, c;
inline bool operator < (const Edge &rhs) const {
return w == rhs.w ? c > rhs.c : w < rhs.w;
}
} e[M];
struct DSU {
int fa[N];
inline void prework(int n) {
iota(fa +1 , fa + 1 + n, 1);
}
inline int find(int x) {
while (x != fa[x])
fa[x] = fa[fa[x]], x = fa[x];
return x;
}
inline void merge(int x, int y) {
fa[find(y)] = find(x);
}
} dsu;
int n, m, s, k, tot;
template <class T = int>
inline T read() {
char c = getchar();
bool sign = (c == '-');
while (c < '0' || c > '9')
c = getchar(), sign |= (c == '-');
T x = 0;
while ('0' <= c && c <= '9')
x = (x << 1) + (x << 3) + (c & 15), c = getchar();
return sign ? (~x + 1) : x;
}
inline pair<ll, int> Kruskal() {
sort(e + 1, e + 1 + m);
dsu.prework(n);
pair<ll, int> ans = make_pair(0ll, 0);
for (int i = 1; i <= m; ++i) {
if (dsu.find(e[i].u) == dsu.find(e[i].v))
continue;
dsu.merge(e[i].u, e[i].v);
ans.first += e[i].w, ans.second += e[i].c;
}
return ans;
}
inline pair<ll, int> check(int lambda) {
for (int i = 1; i <= m; ++i)
if (e[i].c)
e[i].w += lambda;
pair<ll, int> ans = Kruskal();
for (int i = 1; i <= m; ++i)
if (e[i].c)
e[i].w -= lambda;
return ans;
}
signed main() {
n = read(), m = read(), s = read(), k = read();
dsu.prework(n);
for (int i = 1; i <= m; ++i) {
e[i].u = read(), e[i].v = read(), e[i].w = read(), e[i].c = (e[i].u == s || e[i].v == s);
dsu.merge(e[i].u, e[i].v);
}
for (int i = 2; i <= n; ++i)
if (dsu.find(i) != dsu.find(1))
return puts("Impossible"), 0;
if (check(-inf).second < k || check(inf).second > k)
return puts("Impossible"), 0;
int l = -inf, r = inf, ans = 0;
while (l <= r) {
int mid = (l + r) >> 1;
if (check(mid).second >= k)
l = mid + 1, ans = mid;
else
r = mid - 1;
}
printf("%lld\n", check(ans).first - 1ll * k * ans);
return 0;
}
定义一段序列 \(a_{l \sim r}\) 的价值为:
\[((\sum_{i = l}^r a_i) + 1)^2 \]将 \(a_{1 \sim n}\) 分为 \(m\) 段,求每段价值和的最小值。
\(n \leq 10^5\)
首先由于 \((a + b)^2 \geq a^2 + b^2\) ,所以要分的段数要越多越好,答案关于 \(m\) 增加而减小,且因为先选减小值更小的地方分开更优,具有凸性。
于是可以 wqs 二分,判定部分可以斜率优化,时间复杂度 \(O(n \log V)\) 。
#include <bits/stdc++.h>
typedef long long ll;
using namespace std;
const int N = 1e5 + 7;
ll a[N], s[N], f[N];
int q[N], g[N];
int n, m;
template <class T = int>
inline T read() {
char c = getchar();
bool sign = (c == '-');
while (c < '0' || c > '9')
c = getchar(), sign |= (c == '-');
T x = 0;
while ('0' <= c && c <= '9')
x = (x << 1) + (x << 3) + (c & 15), c = getchar();
return sign ? (~x + 1) : x;
}
inline double slope(int i, int j) {
return ((f[i] + s[i] * s[i]) - (f[j] + s[j] * s[j])) / (s[i] - s[j]);
}
inline pair<ll, int> check(ll mid) {
for (int i = 1, head = 0, tail = 0; i <= n; ++i) {
while (head < tail && slope(q[head], q[head + 1]) < 2.0 * (s[i] + 1))
++head;
f[i] = f[q[head]] + (s[i] - s[q[head]] + 1) * (s[i] - s[q[head]] + 1) + mid;
g[i] = g[q[head]] + 1;
while (head < tail && slope(q[tail - 1], q[tail]) > slope(q[tail], i))
--tail;
q[++tail] = i;
}
return make_pair(f[n], g[n]);
}
signed main() {
n = read(), m = read();
for (int i = 1; i <= n; ++i)
s[i] = s[i - 1] + (a[i] = read());
ll l = 0, r = 1e18, ans = 0;
while (l <= r) {
ll mid = (l + r) >> 1;
if (check(mid).second <= m)
r = mid - 1, ans = mid;
else
l = mid + 1;
}
printf("%lld", check(ans).first - ans * m);
return 0;
}
CF802O April Fools' Problem (hard)
\(n\) 道题,第 \(i\) 天可以花费 \(a_i\) 准备一道题或花费 \(b_i\) 打印一道题或什么也不做,准备的题可以留到以后打印,求最少花费使得准备并打印 \(k\) 道题。
\(n \leq 5 \times 10^5\)
wqs二分配合反悔贪心即可,时间复杂度 \(O(n \log n \log V)\) 。
#include <bits/stdc++.h>
typedef long long ll;
using namespace std;
const int N = 5e5 + 7;
ll a[N], b[N];
int n, k;
template <class T = int>
inline T read() {
char c = getchar();
bool sign = (c == '-');
while (c < '0' || c > '9')
c = getchar(), sign |= (c == '-');
T x = 0;
while ('0' <= c && c <= '9')
x = (x << 1) + (x << 3) + (c & 15), c = getchar();
return sign ? (~x + 1) : x;
}
inline pair<ll, int> check(ll lambda) {
ll ans = 0;
int cnt = 0;
priority_queue<pair<ll, int> > q;
for (int i = 1; i <= n; ++i)
b[i] -= lambda;
for (int i = 1; i <= n; ++i) {
q.emplace(make_pair(-a[i], 1));
if (-b[i] > -q.top().first) {
ans += b[i] - q.top().first, cnt += q.top().second;
q.pop(), q.emplace(make_pair(b[i], 0));
}
}
for (int i = 1; i <= n; ++i)
b[i] += lambda;
return make_pair(ans, cnt);
}
signed main() {
n = read(), k = read();
for (int i = 1; i <= n; ++i)
a[i] = read();
for (int i = 1; i <= n; ++i)
b[i] = read();
ll l = -1e18, r = 1e18, ans = 1e18;
while (l <= r) {
ll mid = (l + r) >> 1;
if (check(mid).second >= k)
ans = mid, r = mid - 1;
else
l = mid + 1;
}
printf("%lld", check(ans).first + ans * k);
return 0;
}
现在一共有 \(n\) 只神奇宝贝,你有 \(a\) 个宝贝球和 \(b\) 个超级球,宝贝球抓到第 \(i\) 只神奇宝贝的概率是 \(p_i\),超级球抓到的概率则是 \(u_i\)。
不能往同一只神奇宝贝上使用超过一个同种的球,但是可以往同一只上既使用宝贝球又使用超级球(都抓到算一个)。
求合理分配下抓到神奇宝贝的总个数期望的最大值。
\(n \leq 2000\)
首先球全部用完一定最优。考虑 wqs 二分斜率 \(\lambda\) ,每用一个超级球答案就减去 \(\lambda\) ,DP 出超级球选 \(b\) 个的方案。
直接 DP 是 \(O(n^2 \log V)\) 的,但是可以进一步优化,考虑每个位置的四种选择:
- 用宝贝球和超级球:\(p + u - pu - \lambda\) 。
- 用宝贝球:\(p\) 。
- 用超级球: \(u - \lambda\) 。
- 不用球:\(0\) 。
考虑一个位置从不用宝贝球到用宝贝球,答案增加量就是:
对这个增加量排序后贪心即可,时间复杂度 \(O(n \log n \log V)\) 。
#include <bits/stdc++.h>
using namespace std;
const double eps = 1e-9;
const int N = 2e3 + 7;
struct Node {
double x;
int kb, ka;
inline bool operator < (const Node &rhs) const {
return x > rhs.x;
}
} nd[N];
double p[N], u[N];
int n, a, b;
inline pair<double, int> check(double lambda) {
double ans = 0;
int cnt = 0;
for (int i = 1; i <= n; ++i) {
if (p[i] + u[i] - p[i] * u[i] - lambda > p[i])
nd[i].ka = 1, nd[i].x = p[i] + u[i] - p[i] * u[i] - lambda;
else
nd[i].ka = 0, nd[i].x = p[i];
if (u[i] - lambda > 0)
nd[i].kb = 1, nd[i].x -= u[i] - lambda, ans += u[i] - lambda;
else
nd[i].kb = 0;
}
sort(nd + 1, nd + 1 + n);
for (int i = 1; i <= a; ++i)
ans += nd[i].x, cnt += nd[i].ka;
for (int i = a + 1; i <= n; ++i)
cnt += nd[i].kb;
return make_pair(ans, cnt);
}
signed main() {
scanf("%d%d%d", &n, &a, &b);
for (int i = 1; i <= n; ++i)
scanf("%lf", p + i);
for (int i = 1; i <= n; ++i)
scanf("%lf", u + i);
double l = 0, r = 1, ans = 1;
while (r - l > eps) {
double mid = (l + r) / 2;
if (check(mid).second >= b)
l = mid;
else
ans = r = mid;
}
printf("%.6lf", check(ans).first + ans * b);
return 0;
}
给定一棵 \(n\) 个点的树,在树上选出 \(k + 1\) 条互不相交链,最大化其权值之和。
\(n \leq 3 \times 10^5\)
根据 wqs 二分的经典套路,下面考虑没有 \(k\) 的限制怎么做。
设 \(f_{u, 0/1/2}\) 表示考虑到 \(u\) 且 \(deg_u = 0/1/2\) 时的答案,更进一步的:
- \(deg = 0\) :这个点没有连边。
- \(deg = 1\) :这个点连着一条未完成的链,该链还未计入答案。
- \(deg = 2\) :这个点连着一条连接两个不同子树的链。
首先约定在每个节点的全部转移结束时,进行一次更新:
这样就把 \(u\) 的全部最优解统计了出来。对于 \(deg = 0/2\) 的情况可以直接合并,而对于 \(deg = 1\) 的情况要先把这条链结束掉并统计答案。
则有转移方程:
第一行表示 \(u\) 不接到 \(v\) 上,直接继承 \(v\) 的最优解。
第二行表示 \(u\) 接到 \(v\) 上,继承 \(v\) 一条未完成的链,得到一条完成的链。
第一行表示 \(u\) 不接到 \(v\) 上,直接继承 \(v\) 的最优解。
第二行表示 \(u\) 接到 \(v\) 上,继承 \(v\) 一条未完成的链,得到一条未完成的链
这里 \(u\) 必须不接 \(v\) ,只能取 \(v\) 的最优解。
时间复杂度 \(O(n \log V)\) 。
#include <bits/stdc++.h>
typedef long long ll;
using namespace std;
const int N = 3e5 + 7;
struct Graph {
vector<pair<int, int> > e[N];
inline void insert(int u, int v, int w) {
e[u].emplace_back(v, w);
}
} G;
pair<ll, int> f[N][3];
int n, k;
template <class T = int>
inline T read() {
char c = getchar();
bool sign = (c == '-');
while (c < '0' || c > '9')
c = getchar(), sign |= (c == '-');
T x = 0;
while ('0' <= c && c <= '9')
x = (x << 1) + (x << 3) + (c & 15), c = getchar();
return sign ? (~x + 1) : x;
}
inline pair<ll, int> operator + (const pair<ll, int> &a, const pair<ll, int> &b) {
return make_pair(a.first + b.first, a.second + b.second);
}
inline pair<ll, int> operator + (const pair<ll, int> &a, const ll &b) {
return make_pair(a.first + b, a.second);
}
void dfs(int u, int fa, const pair<ll, int> lambda) {
f[u][0] = f[u][1] = f[u][2] = make_pair(0, 0);
f[u][2] = max(f[u][2], lambda);
for (auto it : G.e[u]) {
int v = it.first, w = it.second;
if (v == fa)
continue;
dfs(v, u, lambda);
f[u][2] = max(f[u][2], max(f[u][2] + f[v][0], f[u][1] + w + f[v][1] + lambda));
f[u][1] = max(f[u][1], max(f[u][1] + f[v][0], f[u][0] + w + f[v][1]));
f[u][0] = max(f[u][0], f[u][0] + f[v][0]);
}
f[u][0] = max(f[u][0], max(f[u][1] + lambda, f[u][2]));
}
inline pair<ll, int> check(ll lambda) {
dfs(1, 0, make_pair(lambda, 1));
return f[1][0];
}
signed main() {
n = read(), k = read() + 1;
for (int i = 1; i < n; ++i) {
int u = read(), v = read(), w = read();
G.insert(u, v, w), G.insert(v, u, w);
}
ll l = -1e12, r = 1e12, ans = 0;
while (l <= r) {
ll mid = (l + r) >> 1;
if (check(mid).second >= k)
ans = mid, r = mid - 1;
else
l = mid + 1;
}
printf("%lld", check(ans).first - ans * k);
return 0;
}
有 \(n\) 个村庄,放 \(m\) 个邮局,求每个村庄到最近邮局的距离和的最小值。
\(m \leq n \leq 5 \times 10^5\)
设 \(f_{i, j}\) 表示前 \(i\) 个村庄放 \(j\) 个邮局的最小距离和,\(w(l, r)\) 表示在 \([l, r]\) 范围村庄放一个邮局的最小距离和,则有:
决策单调性优化做到 \(O(n^2)\) 。
考虑用 wqs 二分规避 \(j\) 的限制,于是得到一个1D/1D 的 DP,并且也有决策单调性,可以二分栈做到 \(O(n \log n \log V)\) 。
#include <bits/stdc++.h>
typedef long long ll;
using namespace std;
const int N = 5e5 + 7;
struct Node {
int j, l, r;
} q[N];
ll s[N], f[N];
int a[N], g[N];
int n, m;
template <class T = int>
inline T read() {
char c = getchar();
bool sign = (c == '-');
while (c < '0' || c > '9')
c = getchar(), sign |= (c == '-');
T x = 0;
while ('0' <= c && c <= '9')
x = (x << 1) + (x << 3) + (c & 15), c = getchar();
return sign ? (~x + 1) : x;
}
inline ll w(int l, int r) {
int mid = (l + r) >> 1;
return (s[r] - s[mid]) - 1ll * a[mid] * (r - mid) + 1ll * a[mid] * (mid - l + 1) - (s[mid] - s[l - 1]);
}
inline ll calc(int i, int j) {
return f[j] + w(j + 1, i);
}
inline int BinarySearch(int l, int r, int i, int j) {
int ans = r + 1;
while (l <= r) {
int mid = (l + r) >> 1;
if (calc(mid, i) <= calc(mid, j))
ans = mid, r = mid - 1;
else
l = mid + 1;
}
return ans;
}
inline pair<ll, int> check(ll lambda) {
int head = 1, tail = 0;
q[++tail] = (Node) {0, 1, n};
for (int i = 1; i <= n; ++i) {
if (q[head].r == i - 1)
++head;
f[i] = calc(i, q[head].j) + lambda, g[i] = g[q[head].j] + 1;
int pos = n + 1;
while (head <= tail) {
if (calc(q[tail].l, i) <= calc(q[tail].l, q[tail].j))
pos = q[tail--].l;
else {
pos = BinarySearch(q[tail].l, q[tail].r, i, q[tail].j);
q[tail].r = pos - 1;
break;
}
}
if (pos != n + 1)
q[++tail] = (Node) {i, pos, n};
}
return make_pair(f[n], g[n]);
}
signed main() {
n = read(), m = read();
for (int i = 1; i <= n; ++i)
a[i] = read();
sort(a + 1, a + 1 + n);
for (int i = 1; i <= n; ++i)
s[i] = s[i - 1] + a[i];
ll l = 0, r = 1e12, ans = 0;
while (l <= r) {
ll mid = (l + r) >> 1;
if (check(mid).second >= m)
ans = mid, l = mid + 1;
else
r = mid - 1;
}
printf("%lld", check(ans).first - ans * m);
return 0;
}