wqs 二分学习笔记
又称为带权二分
一种优化凸函数 dp 的方式,明显的标志是选 k 个。
一般这种玩意都是可以强套一个 wqs 二分上去,消一个 O(n) 加一个 \(O(\log)\),而且还是从状态数上消一个。
我们从 LCT 这道题来引入。
首先题目要求选 k+1 条不相交链的权值和最大。
设出 \(dp[i][j][0/1/2]\) 表示以 \(i\) 为根的子树在图上的度数为 \(0,1,2\),他的子树中含有 \(j\) 条链。并且当度数为 \(1\) 的时候这条链不计入 \(j\) 中。
分类讨论的有点繁琐就不写了,看看代码吧(讲真,这个 dp 还挺神仙的
我觉得 xtq 树形 dp 的方式还挺用的
inline void dfs(int x,int ff)
{
dp[0][x][0] = dp[1][x][0] = dp[2][x][1] = dp[3][x][0] = 0;
for(int i=head[x],v;i;i=nxt[i])
{
v = ver[i];
if(v == ff) continue;
dfs(v , x);
// for(int u = 0;u <= 2;u++)
// for(int j=0;j <= k;j++) aux[u][j] = -INF;
memset(aux[0],0xcf,sizeof(aux[0])),memset(aux[1],0xcf,sizeof(aux[1]));
memset(aux[2],0xcf,sizeof(aux[2]));
for(int u = 0;u <= k;u++)
for(int q = 0;q + u <= k;q ++)
aux[0][u + q] = max(aux[0][u + q],dp[0][x][u] + dp[3][v][q]);
for(int u = 0;u <= k;u ++)
for(int q = 0;q + u <= k;q ++)
aux[1][u + q] = max(aux[1][u + q],max(dp[1][x][u] + dp[3][v][q],dp[0][x][u] + dp[1][v][q] + Edge[i]));
for(int u = 0;u <= k ; u++)
for(int q = 0;q + u<= k;q ++)
{
aux[2][u + q] = max(aux[2][u + q],dp[2][x][u] + dp[3][v][q]);
// aux[2][u + q + 1] = max(aux[2][u + q + 1],dp[1][x][u] + dp[1][v][q] + Edge[i]);
}
for(int u = 0;u <= k ; u++)
for(int q = 0;q + u + 1<= k;q ++)
{
// aux[2][u + q + 1] = max(aux[2][u + q + 1],dp[2][x][u] + dp[3][v][q]);
aux[2][u + q + 1] = max(aux[2][u + q + 1],dp[1][x][u] + dp[1][v][q] + Edge[i]);
}
// printf("%d : \n",x);
for(int u =0 ;u <=2;u++)
for(int q = 0;q <= k;q ++ )
{
// printf("aux[%d][%d] = %d\n",u,q,aux[u][q]);
dp[u][x][q] = aux[u][q];
}
dp[1][x][0] = max(dp[1][x][0] , dp[1][v][0] + Edge[i]);
}
// dp[0][x][1] = max(0,dp[0][x][1]);
for(int i=1;i<=k;i++)
dp[3][x][i] = max(dp[0][x][i],max(dp[1][x][i - 1],dp[2][x][i]));
// for(int i=1;i<=)
}
我们现在的 \(dp\) 是 \(O(nk^2)\) 的。这个复杂度大的离谱。
这时候请出我们的带权二分来。这里默认我们的 dp 函数是一个凸函数
我们将原来的 \(dp\) 的 \(j\) 那维限制去掉,这样就可以将复杂度降到 \(O(n)\)。但是这样不能保证我们恰好选了 k+1 条链,所以我们要对原函数做一些魔改。
设原函数为 \(ans(x)\),当 \(ans'(x)=0\) 时,ans(x) 取得最大值。我们不加修改的 dp 求出来的就是这个东西。
现在我们设一个新函数 \(g(x) = ans(x) +val\times x\),这个函数一阶导为减函数,二阶导为一个上凸函数。所以我们可以通过调节 val (就是斜率)来调节 \(g'(x)\) 的零点,这样就能调节出当 \(g(x)\) 取得最值得时候\(ans(x)\) 恰好取得 \(k+1\) 条链,这样皆大欢喜。
关于恰好选 k 个是一个凸函数,你要想如果恰好选 \(1\) 个,那我们肯定选择最大的那个,选两个我会把次大的选上,这样每次的增量都不如上一个大,就会形成一个凸函数。更仔细一点,如果有那么一点点限制,要求恰好选 \(k\) 个,那我后面选的东西可能会影响到前面选的,并且这时候还要求数量达到我们要求的,就被迫舍弃权值最优,来追求数量,这就导致了凸函数的后半段产生。
复杂度 \(O(b\log k)\)
关于 wqs 的实际操作来说,有一点点细节。
对最大值来说:你考虑二分一个惩罚值,当你选的少了的时候,我们想让它下次选得再多一点,就会把惩罚值下调,反之就会上调。
对最小值来说:选得少了的时候,我们想让它下次再多选一点,惩罚值就会下调,反之上调。
对于多点共线的情况,我们优先选物品最少的或者最多的,二分的时候只要物品在 k 的我们指定的一侧时就去更新答案。
P4383 [八省联考2018]林克卡特树
#include<bits/stdc++.h>
using namespace std;
#define int long long
#define pii pair<int,int>
template<typename _T>
inline void read(_T &x)
{
x=0;char s=getchar();int f=1;
while(s<'0'||s>'9') {f=1;if(s=='-')f=-1;s=getchar();}
while('0'<=s&&s<='9'){x=(x<<3)+(x<<1)+s-'0';s=getchar();}
x*=f;
}
const int np = 3e5 + 5;
int head[np],ver[np * 2],nxt[np * 2],Edge[np * 2];
int tit;
inline void add(int x,int y,int w)
{
ver[++tit] = y;
Edge[tit] = w;
nxt[tit] = head[x];
head[x] = tit;
}
struct qwq
{
int f,fanga;
friend qwq operator+(qwq a,qwq b)
{
return (qwq){a.f + b.f , a.fanga + b.fanga};
}
inline friend qwq Max(qwq a,qwq b)
{
if(a.f == b.f)
{
if(a.fanga > b.fanga) return a;
else return b;
}
if(a.f > b.f) return a;
else return b;
}
}dp[5][np];
// dp[d][i] 表示前以 i 为根的树,度数为 d 的
int n,k,sakura;
inline void dfs(int x,int ff)
{
dp[0][x] = (qwq){0,0},dp[1][x] = (qwq){0,0},dp[2][x] = (qwq){sakura,1};
for(int i=head[x],v;i;i=nxt[i])
{
v = ver[i];
if(v == ff) continue;
dfs(v,x);
dp[2][x] = Max(dp[2][x] + dp[3][v],dp[1][x] + dp[1][v] + (qwq){Edge[i] + sakura,1});
dp[1][x] = Max(dp[1][x] + dp[3][v],dp[0][x] + dp[1][v] + (qwq){Edge[i],0});
dp[0][x] = dp[0][x] + dp[3][v];
}
dp[3][x] = Max(dp[0][x],Max(dp[1][x] + (qwq){sakura,1},dp[2][x]));
}
inline void judging(int x)
{
sakura = x;
dfs(1,0);
}
signed main()
{
read(n),read(k);
k++;
for(int i=1,a,b,w;i<=n - 1;i ++ )
{
read(a),read(b),read(w);
add(a,b,w);
add(b,a,w);
}
int l = -1e8,r = 1e8,Ans=0;
while(l <= r)
{
int mid = l + r >> 1;
judging(mid);
if(dp[3][1].fanga >= k)
{
Ans = dp[3][1].f - k * mid;
// printf("%lld %lld\n",dp[3][1].fanga,Ans);
r = mid - 1;
}
else l = mid + 1;
}
printf("%lld",Ans);
}
P6246 [IOI2000] 邮局 加强版
#include<bits/stdc++.h>
using namespace std;
#define int long long
#define pii pair<int,int>
template<typename _T>
inline void read(_T &x)
{
x=0;char s=getchar();int f=1;
while(s<'0'||s>'9') {f=1;if(s=='-')f=-1;s=getchar();}
while('0'<=s&&s<='9'){x=(x<<3)+(x<<1)+s-'0';s=getchar();}
x*=f;
}
const int np = 1e6 + 5;
int a[np];
int sum[np],n,k;
inline int Abs(int x)
{
return x < 0?-x:x;
}
inline int calc(int l,int r)
{
int op = l + r >> 1;
return a[op] * (op - l + 1) - sum[op]+sum[l - 1] + Abs(a[op] * (r-op+1) - sum[r] + sum[op - 1]);
}
struct qwq
{
int f,fanga;
friend qwq operator+(qwq a,qwq b)
{
return (qwq){a.f + b.f,a.fanga + b.fanga};
}
friend bool operator<(qwq a,qwq b)
{
if(a.f == b.f) return a.fanga < b.fanga;
else return a.f < b.f;
}
}dp[np];
int sakura;
// int l_[2333],r_[2333],juec[2333];
struct qaq
{
int l_,r_,juec;
// int nx
}que[np * 2];
int top = 0;
inline int binary(qaq u,int op)
{
int l = u.l_,r = u.r_,opt = u.juec,ans = u.r_ + 1;//l <= op?op:0;
while(l <= r)
{
int mid = l + r >> 1;
if(dp[op] + (qwq){calc(op + 1,mid) + sakura,1} < dp[opt] + (qwq){calc(opt + 1,mid) + sakura,1}) ans = mid,r = mid - 1;
else l = mid + 1;
}
return ans;
}
inline void solve()
{
int head = 1,tail = 1;
dp[0] = (qwq){0,0};
que[head] = (qaq){1,n,0};
for(int i=1;i<=n;i++)
{
while(head < tail && que[head].r_ < i) head++;
int j = que[head].juec;
dp[i] = dp[j] + (qwq){calc(j + 1,i) + sakura,1};// + sakura;
int spilt = 0;
while(head < tail && que[tail].l_ == binary(que[tail],i)) spilt = que[tail].l_,tail--;
spilt = binary(que[tail],i);
if(spilt)
{
que[tail].r_ = spilt - 1;
// printf("%lld ",spilt);
que[++tail] = (qaq){spilt,n,i};
}
}
// if(sakura == 0)
// for(int i=1;i<=n;i++)
// {
// printf("%lld ",dp[i].f);
// }
// printf("\n");
}
namespace subtask{
int fp[500][4333];
inline int bbinary(int c,qaq u,int op)
{
int l = u.l_,r = u.r_,opt = u.juec,ans = u.r_ + 1;//l <= op?op:0;
while(l <= r)
{
int mid = l + r >> 1;
if(fp[c - 1][op] + calc(op + 1,mid) < fp[c - 1][opt] + calc(opt + 1,mid)) ans = mid,r = mid - 1;
else l = mid + 1;
}
return ans;
}
inline void solve1(int c)
{
int head = 1,tail = 1;
fp[c][0] = 0;
que[head] = (qaq){1,n,0};
for(int i=1;i<=n;i++)
{
while(head < tail && que[head].r_ < i) head++;
int j = que[head].juec;
fp[c][i] = fp[c - 1][j] + calc(j + 1,i);// + sakura;
int spilt = 0;
while(head < tail && que[tail].l_ == bbinary(c,que[tail],i)) spilt = que[tail].l_,tail--;
// int spilt = 0;
spilt = bbinary(c,que[tail],i);
if(spilt)
{
que[tail].r_ = spilt - 1;
// printf("%lld ",spilt);
que[++tail] = (qaq){spilt,n,i};
}
}
}
inline void Main()
{
memset(fp,0x3f,sizeof(fp));
fp[1][0] = 0;
for(int i=1;i<=n;i++) fp[1][i] = calc(1,i);
for(int i=2;i <= k;i++)
{
solve1(i);
// for(int j=1;j<=n;j++)
// printf("%lld ",fp[i][j]);
// printf("\n");
}
printf("%lld",fp[k][n]);
}
}
inline void judging(int x)
{
sakura = x;
solve();
}
signed main()
{
read(n),read(k);
for(int i=1;i<=n;i++)
{
read(a[i]);
sum[i] = sum[i - 1] + a[i];
}
int l = -1e8,r = 1e8,Ans(0);
while(l <= r)
{
int mid = l + r >> 1;
judging(mid);
// printf("%lld %lld\n",dp[n].f,dp[n].fanga);
if(dp[n].fanga <= k)
{
Ans = dp[n].f - k * sakura;
// printf("%lld\n",Ans);
r = mid - 1;
}
else l = mid + 1;
}
printf("%lld",Ans);
}
据说这个东西还能用来优化个模拟费用流啥的,笑死,根本不会费用流。