WQS 二分学习笔记
1. 股票买卖问题
1.1 1.0 版本
考虑现在有 \(n\) 天,每天的股票价格 \(a_i\) 已知。你手上同时只能持有至多一张股票,且一笔买卖需要支付 \(c\) 的手续费。求最大收益。
1.1.1 解法 1:DP
我们不妨设 \(f(i,0/1)\) 表示前 \(i\) 天结束后手上是否持有股票。转移非常简单:
可以在 \(O(n)\) 内解决这个问题。
1.1.2 解法 2:贪心
不妨考虑直接贪心选择最优解。
-
如果上次操作是买入 \(q\),当前是 \(p\),则:如果 \(p < q\) 将 \(q\) 替换成 \(p\)。如果 \(p > q + c\) 则直接卖出 \(p\)。
-
如果上次操作是卖出 \(q\),当前是 \(p\),则:如果 \(p > q\) 则将 \(q\) 替换成 \(p\)。如果 \(p < q - c\) 则现在买入 \(p\)。
这个贪心可以得到最优解,时间复杂度同样是 \(O(n)\)。
1.1.3 解法 3:费用流
显然可以构图,变成求最大费用流,但是时间复杂度不好。
1.2 2.0 版本。
不妨考虑现在多了一个要求:至多执行 \(k\) 次买入,但是没有手续费。
如果还是 DP,时间复杂度变为 \(O(nk)\),我们希望更快。
1.2.1 凸性
不妨设 \(g(k)\) 表示恰好 \(k\) 次买入的最大收益,我们可以证明,\(g(k)\) 是一个上凸函数(可以理解为上凸壳),也就是 \(g(k) - g(k-1)\) 的值不增。
正确性可以根据这是一个费用流模型,而费用流每次找的增广路的长度是递减的,所以得证。很多凸性都可以通过费用流来证明。
1.2.2 凸包上二分。
现在我们只需要求出 \(k\) 对应的 \(g(k)\) 即可。
我们将 \((k,g(k))\) 视作一个点,由于其是凸包,我们可以想到用直线取切。
不妨设当前的斜率是 \(c\),则我们就是要找到 \(g(k) - k \times c\) 最大的点。
考虑实际意义,相当于 1.0 版本每一笔交易需要 \(c\) 的手续费!所有这个问题我们可以在 \(O(n)\) 内解决。
所以我们可以求出 \(k\),如果 \(k\) 小了,我们就让斜率变小,否则让斜率变大,通过二分找到我们想要的 \(k\),问题就能在 \(O(n \log n)\) 内完成。
1.2.3 细节
首先我们要考虑到可能存在三个点及以上个点共线的情况,为此,建议使用 DP,将 \((ans,mn,mx)\) 表示当前答案为 \(ans\),最少选 \(mn\) 个,最多选 \(mx\) 个。然后重载运算符就可以正常 \(DP\) 了。
所以我们每次二分得到的返回值是一个区间,如果在当前区间就直接结束,否则注意是按照斜率来二分。
这就是 WQS 二分,通过二分凸函数的斜率来求出答案。
下面给出 P6821 [PA2012] Tanie linie 这道题转化成前缀和后就是一个股票买卖问题:
#include <iostream>
#include <cstdio>
#include <vector>
#include <cstring>
#include <algorithm>
using namespace std;
const int N = 1e6 + 5;
const long long inf = 1e9;
int n, k;
long long a[N] = {0};
struct Node {
long long x;
int mn, mx;
Node (long long _x = -inf, int _mn = 2e9, int _mx = 0) :
x(_x), mn(_mn), mx(_mx) {}
} f[N][2];
Node operator+(Node u, Node v) {
return Node(u.x + v.x, u.mn + v.mn, u.mx + v.mx);
}
Node operator*(Node u, Node v) {
if (u.x > v.x)
return u;
if (u.x < v.x)
return v;
return Node(u.x, min(u.mn, v.mn), max(u.mx, v.mx));
}
Node cal(long long c) {
f[0][0] = Node(0ll, 0, 0);
f[0][1] = Node();
for (int i = 1; i <= n; i++) {
f[i][0] = (f[i - 1][1] + Node(a[i] - c, 1, 1)) * f[i - 1][0];
f[i][1] = (f[i - 1][0] + Node(-a[i], 0, 0)) * f[i - 1][1];
}
return f[n][0];
}
void read(int &x) {
int D = 1;
x = 0;
char ch = getchar();
while (ch < '0' || ch > '9') {
if (ch == '-')
D = -1;
ch = getchar();
}
while (ch >= '0' && ch <= '9')
x = x * 10 + ch - '0', ch = getchar();
x = x * D;
}
void readll(long long &x) {
int D = 1;
x = 0;
char ch = getchar();
while (ch < '0' || ch > '9') {
if (ch == '-')
D = -1;
ch = getchar();
}
while (ch >= '0' && ch <= '9')
x = x * 10 + ch - '0', ch = getchar();
x = x * D;
}
int main() {
read(n), read(k);
n++;
for (int i = 2; i <= n; i++)
readll(a[i]), a[i] += a[i - 1];
long long l = -inf, r = inf;
while (l + 1 < r) {
long long mid = (l + r) / 2;
Node t = cal(mid);
if (t.mn <= k && k <= t.mx) {
l = mid;
break;
}
else if (t.mx < k)
r = mid;
else
l = mid + 1;
}
Node t = cal(l);
long long ans = t.x + 1ll * k * l;
Node mx = cal(0);
if (mx.mx <= k)
ans = max(mx.x, ans);
cout << ans << endl;
return 0;
}