「学习笔记」wqs二分/dp凸优化
【学习笔记】wqs二分/DP凸优化
## 从一个经典问题谈起:
有一个长度为 \(n\) 的序列 \(a\),要求找出恰好 \(k\) 个不相交的连续子序列,使得这 \(k\) 个序列的和最大
\(1 \leq k \leq n \leq 10^5, -10^9 \leq a_i \leq 10^9\)
先假装都会 \(1 \leq k \leq n \leq 1000\) 的 \(dp\) 做法以及 \(k = 1\) 的子问题
实际上这个问题还可以是个费用流模型:
对于序列中每一个点 \(i\) ,拆成两个点 \(i\) 和 \(i'\) ,连一条 \(i \rightarrow i'\) 流量为 \(1\) 费用为 \(a_i\) 的边
对于每一个 \(i\) ,连一条 \(S \rightarrow i\) 流量为 \(1\) 费用为 \(0\) 的边
对于每一个 \(i'\) ,连一条 \(i' \rightarrow T\) 流量为 \(1\) 费用为 \(0\) 的边
对于相邻的两个点 \(i\) 和 \(i + 1\) ,连一条 \(i'\) 到 \(i+1\) 流量为 \(1\) 费用为 \(0\) 的边
显然每次沿着最大费用路径单路增广一次的话就是选择了原问题的一个最大连续子序列
实际上这样增广 \(k\) 次后的结果就是答案,因为有反向边的存在所以选出来的区间不会相交
这个做法的复杂度其实并没有直接 \(dp\) 优,但是可以基于这个模型进行很多优化
线段树优化:\(\text{codeforces 280 D. k-Maximum Subsequence Sum}\)
把模型放到原问题上,每一次单路增广相当于是求全局的最大连续子段和然后将其取反
直接用线段树维护这两个操作,复杂度优化到 \(O(klogn)\)
数据结构的优化在这个问题上还算适用,但是对于问题的模型有一定的局限性
显然上述做法不是本文的重点,不妨继续考虑这个费用流做法
观察发现由于每次单路增广的是最长路,增广后的网络是之前网络的残余网络,所以每一次增广得到的费用都会比上一次得到的要少。
也就是说,如果设 \(f(x)\) 为增广 \(x\) 以后的总流量,\(f(x)\) 的函数图像是一个上凸包
实际上 \(f(x)\) 等价于选取了 \(x\) 个不相交的连续子序列的最大和,也就是原问题。
考虑除了用数据结构进行繁琐的维护以外,我们并没有什么办法高效的直接求出 \(f(x)\) 的每一项
但设 \(\max(f(x))\) 的值可以通过 \(O(n)\) 就可以求出,在这个问题里就是把所有 \(> 0\) 的数加起来
也就是说,我们可以简单的求出这个函数的极点的值,这启发我们可以通过对函数进行魔改,使得极点在 \(k\) 上
由于函数是上凸的,不妨设 \(f'(x) = f(x) + px\) ,显然当 \(p\) 的值增加时,极点的位置会左移
那么问题就转化为找到一个合适的斜率 \(p\) 使得 \(f'(x) = f(x) + px\) 的最大值在 \(x =k\) 时取到
也就是拿一条斜率为 \(p\) 的直线去且这个凸包使得切点恰好在 \(x = k\) 上,由于凸包的性质切线的斜率是单调的
那么不妨二分斜率 \(p\),对于 \(f'(x) = f(x) + px\) 的取值加以验证,而把这个函数放回到原问题上,就是每选一个区间需要 \(p\) 的额外代价
于是就可以 \(dp\) 出在数量不限,每选一个区间要 \(p\) 的额外代价的情况下,能获得的最优总代价是什么,最优解选了多少个区间
这对应的是 \(f'(x)\) 的最大值以及取到最大值的 \(x\) ,根据这个可以判断出接下来斜率该增大还是减小
如果某一时刻得到最大值取到的位置为 \(x = k\) ,那么原问题的答案就是 \(f'(x) - px\) ,转化回去即可
此外还需要考虑一个细节,这个所谓的上凸包其有些点的取值并非在凸包的顶点上而是在凸包的边上,这样的话直线只能切到这条边而不能切到点了
但是考虑此时的顶点是可以切到的,所以只需要在 \(dp\) 的时候记录最优解取到的最左/最右位置即可,最后同样能得到正确的斜率 \(p\) ,此时这条边上的取值是相同的
至此就用一个二分和一个不带限制的 \(dp\) 以 \(O(nlogk)\) 的限制解决了此题,事实上但凡答案的形态是凸的题目都可以尝试用这种方法解决,相较于数据结构有很大的优势
## 一个例题(为了贴代码):
「九省联考 2018」林克卡特树
在树上选取 \(k + 1\) 条点不相交的链,使得选取的边权和最大
类似的,问题可以转化为每次选树的直径,然后给树的直径取反,这样的话函数的上凸性就显然了
当然这里也可以用数据结构来维护这个模型,(LCT 优化费用流),不过实在太毒瘤了想必也没什么人会写吧
相反,wqs二分/dp凸优化(其实是一个东西)的做法在这里就十分清真
类似的,二分斜率以后等价于每选一条链要花费 \(p\) 的额外代价,然后进行简单的树 \(dp\) 就可以了
\(f[u][0]\) 表示 \(u\) 子树内 \(u\) 不是任意一条所选链上的点,能获得的最大收益
\(f[u][1]\) 表示 \(u\) 子树内 \(u\) 是一条所选链的端点,能获得的最大收益
\(f[u][2]\) 表示 \(u\) 子树内 \(u\) 是一条所选链的 \(lca\),能获得的最大收益 (转移请自行推导)
总复杂度 \(O(nlogk)\)
/*program by mangoyang*/
#include<bits/stdc++.h>
#define inf ((ll)(0x7f7f7f7f))
#define Max(a, b) ((a) > (b) ? (a) : (b))
#define Min(a, b) ((a) < (b) ? (a) : (b))
typedef long long ll;
using namespace std;
template <class T>
inline void read(T &x){
int f = 0, ch = 0; x = 0;
for(; !isdigit(ch); ch = getchar()) if(ch == '-') f = 1;
for(; isdigit(ch); ch = getchar()) x = x * 10 + ch - 48;
if(f) x = -x;
}
#define int ll
const int N = 700005;
int slope;
int a[N], b[N], nxt[N], head[N], cnt, n, k;
struct Node{
int ans; int cnt;
Node operator + (const int &A) const{ return (Node){ans + A, cnt}; }
Node operator + (const Node &A) const{ return (Node){ans + A.ans, cnt + A.cnt}; }
bool operator > (const Node &A) const{ return ans == A.ans ? cnt > A.cnt : ans > A.ans; }
}dp[N][3];
inline void addedge(int x, int y, int z){
a[++cnt] = y, b[cnt] = z, nxt[cnt] = head[x], head[x] = cnt;
}
inline void chkmax(Node &x, Node y){ if(y > x) x = y;}
inline void solve(int u, int fa){
Node Add = {-slope, 1};
for(int p = head[u]; p; p = nxt[p]) if(a[p] != fa){
int v = a[p], w = b[p]; solve(v, u);
chkmax(dp[u][2], Max(dp[u][2] + dp[v][0], dp[u][1] + dp[v][1] + w + Add));
chkmax(dp[u][1], Max(dp[u][1] + dp[v][0], dp[u][0] + dp[v][1] + w));
dp[u][0] = dp[u][0] + dp[v][0];
}
chkmax(dp[u][0], Max(dp[u][2], dp[u][1] + Add));
}
inline bool check(int mid){
slope = mid, memset(dp, 0, sizeof(dp));
solve(1, 0);
if(dp[1][0].cnt == k){
printf("%lld", (ll)(dp[1][0].ans + k * mid)); exit(0);
}
return dp[1][0].cnt > k;
}
signed main(){
read(n), read(k), k = Min(k + 1, n);
for(int i = 1, x, y, z; i < n; i++){
read(x), read(y), read(z);
addedge(x, y, z), addedge(y, x, z);
}
int l = (ll) -1e12, r = (ll) 1e12, ls;
while(l <= r){
int mid = (l + r) / 2;
if(check(mid)) l = mid + 1, ls = mid; else r = mid - 1;
}
check(ls);
printf("%lld\n", dp[1][0].ans + k * ls);
return 0;
}
## 一些参考资料和补充: