非常简单地理解带权二分(wqs二分)
非常感性简单地理解带权二分(又名 wqs 二分),尽管不是很严谨,如有错误请指出
\(\Large\natural\) Gosha is hunting / 原题链接 / 更好阅读体验
解法
设我们有 \(A\) 个红球和 \(B\) 个蓝球,用红球抓 \(i\) 号胖可丁的概率是 \(a_i\),用蓝球抓 \(i\) 号胖可丁的概率是 \(b_i\)。
首先我们有一个暴力 \(\text{DP}\):设 \(f_{i,j,k}\) 为 \(1\sim i\) 这些胖可丁中,用了 \(j\) 个红球和 \(k\) 个蓝球的最大期望。
\(\max\) 中依次代表着对于第 \(i\) 个胖可丁:不用精灵球;只用红球;只用蓝球;两种球都用。
答案是 \(f_{n,A,B}\)。当然这是 \(O(n^3)\) 的,显然会超时。
我们发现,用了 \((x+1)\) 个精灵球肯定会比用 \(x\) 个精灵球得到的期望更大。所以它是一个单调的东西。事实上,如果 \(f_{i,j,k}\) 中 \(i,k\) 都是不变常数,即它是关于 \(j\) 的函数,那么它是一个凸函数。因为你第一个红球肯定会抓期望最大的胖可丁,第二个红球肯定会抓期望次大的胖可丁……这样增长的速度就会越来越慢。当然, 如果 \(f_{i,j,k}\) 中 \(i,j\) 都是不变常数,即它是关于 \(k\) 的函数也同理。所以这个凸函数的性质意味着我们可以用带权二分。
根据带权二分的思想,我们需要将 \(f_{i,j,k}\) 简化为 \(f_{i,j}\)。我们可以假设现在蓝球是免费的,想用多少就用多少(好耶!),那这样就会有转移方程:
时间复杂度 \(O(n^2)\)。
但是,毫无疑问,这样的话所有的转移肯定都会加上蓝球的贡献,毕竟免费的肯定越用多越好。所以这样想:我们给蓝球套上一个价格,即,你需要 \(k\) 元才能买到一个蓝球。这样的话, \(\text{DP}\) 方程就是这样的了:
这样肯定就会有一些贫穷的 \(f_{i,j}\) 不能加上蓝球的贡献。
于是,我们在 \(\text{DP}\) 过程中同时记录一下使用蓝球的个数 \(cnt\)。如果 \(cnt>B\) 那么就是蓝球供不应求了。我们肯定会让 \(k\) 更大一点,也就是让蓝球更昂贵,更奢侈,让更多可怜的 \(f_{i,j}\) 买不到蓝球,以减少 \(cnt\)。
相反地,如果 \(cnt<B\),那么就是蓝球生产过剩、供大于求。我们必须要让 \(k\) 更小一点,也就是让蓝球更便宜,以增加 \(cnt\)。
如果 \(cnt=B\) ,恭喜,我们恰好用了 \(B\) 个蓝球,这应该就是合适的价格了。
答案就是 \(f_{n,A}\)……不对,蓝球事实上是免费的。所以我们要把那些钱还回去,答案就是 \(f_{n,A}+cnt\times k\)。
于是我们可以二分这个价钱 \(k\)。一开始二分的边界是 \(l=0,r=1\),因为概率都是大于等于零、小于等于一的。
时间复杂度变为 \(O(n^2\log V)\)。其中 \(\log V\) 是二分中产生的。
所以你也可以继续给红球也依法炮制,给它安一个价格。而且 \(\text{DP}\) 也不复存在,变成了一个贪心。更令人喜出望外的是,这样时间复杂度甚至可以变成 \(O(n\log^2 V)\),可以非常充裕地通过本题!
注意最后答案是 \(f_{n,A,B}+A\times k_a +B\times k_b\),而不是\(f_{n,A,B}+cnt_a\times k_a +cnt_b\times k_b\)。原因的话需要去以斜率之类的方式去说明。由于本篇题解只是感性理解,而非用斜率去证明,所以大家就把这个东西记下来吧。
代码
注意精度……这个题目很在乎精度。我这里是 eps=1e-8
。
而且从这个题目我知道了:人傻常数大,人傻常数低……(同校大佬一次过,时间还少)
码风毒瘤见谅。
#include<bits/stdc++.h>
#define rep(i,x,y) for(int i=x;i<=y;++i)
#define lod long double
using namespace std;
const int n7=2021;const lod eps=1e-8;
int n,A,B,cnt1,cnt2;
lod a[n7],b[n7],c[n7],tot,val1,val2,ans;
void check(lod rmb1,lod rmb2){
cnt1=cnt2=0,tot=0;
rep(i,1,n){
bool flag1=0,flag2=0;lod tmp=0;
if(a[i]-rmb1>tmp+eps)tmp=a[i]-rmb1,flag1=1,flag2=0;
if(b[i]-rmb2>tmp+eps)tmp=b[i]-rmb2,flag1=0,flag2=1;
if(c[i]-rmb1-rmb2>tmp+eps)tmp=c[i]-rmb1-rmb2,flag1=1,flag2=1;
if(flag1)cnt1++;
if(flag2)cnt2++;
tot+=tmp;
}
}
int main(){
cin>>n>>A>>B;
rep(i,1,n)cin>>a[i];
rep(i,1,n)cin>>b[i];
rep(i,1,n)c[i]=1-(1-a[i])*(1-b[i]);
lod l1=0,r1=1;
while(l1+eps<r1){
lod mid1=(l1+r1)/2;
lod l2=0,r2=1;
while(l2+eps<r2){
lod mid2=(l2+r2)/2;
check(mid1,mid2);
val2=mid2;
if(cnt2==B)break;
if(cnt2>B)l2=mid2;
else r2=mid2;
}
val1=mid1;
if(cnt1==A)break;
if(cnt1>A)l1=mid1;
else r1=mid1;
}
ans=tot+val1*A+val2*B;
cout<<ans;
return 0;
}