浅谈wqs二分

论文

浅析一类二分方法

算法讲解

例题

在这里插入图片描述
首先不考虑限制是一个很简单的斜率优化板子
加上 k k k之后再用斜率优化就是 O ( n k ) O(nk) O(nk)
如果 k , n k,n k,n同阶显然做不了
考虑怎么优化这个问题
这个时候就要用wqs二分了

wqs二分

f ( k ) f(k) f(k)表示分成 k k k段的答案
通过打表严格证明可以发现 ( x , f ( x ) ) (x,f(x)) (x,f(x))是个凸壳 (斜率单调)

先假设这是个上凸壳
二分一个 m i d mid mid,表示直线的斜率
然后用这条直线去切这个凸壳,假设交点为 ( x , f ( x ) ) (x,f(x)) (x,f(x))
显然这个点可以表示为 ( x , m i d ∗ x + g ( m i d ) ) (x,mid*x+g(mid)) (x,midx+g(mid)), g ( m i d ) 为 截 距 g(mid)为截距 g(mid)(带入直线)
那么 f ( x ) = m i d ∗ x + g ( m i d ) f(x) = mid * x + g(mid) f(x)=midx+g(mid)
f ( x ) = m i d ∗ x + g ( m i d ) f(x) = mid * x + g(mid) f(x)=midx+g(mid)
g ( m i d ) = f ( x ) − m i d ∗ x g(mid) = f(x) - mid * x g(mid)=f(x)midx

然后这个 g ( m i d ) g(mid) g(mid)可以直接用斜率优化跑出来,因为要满足的也是g(x)最大/最小。
最后在通过 g ( m i d ) + x ∗ m i d g(mid)+x*mid g(mid)+xmid算出 f ( k ) f(k) f(k)
显然这个可以二分 m i d mid mid,看 g ( m i d ) g(mid) g(mid)分成了几段,然后继续二分 m i d mid mid
计算 g g g的时候只需要转移的时候每次多减一个 m i d mid mid就行了
所以时间复杂度为 O ( n l o g N ) O(nlogN) O(nlogN)

给个例题讲讲吧
luogu P4983 忘情

把这一坨柿子化简一下就是
( 1 + ∑ x i ) 2 (1+\sum x_i)^2 (1+xi)2
f i = max ⁡ ( f j + ( s i − s j + 1 ) 2 ) f_i=\max(f_j+(s_i-s_j+1)^2) fi=max(fj+(sisj+1)2)
     = max ⁡ ( f j + ( s i + 1 ) 2 − 2 ∗ ( s i + 1 ) ∗ s j + s j 2 ) \ \ \ \ =\max(f_j+(s_i+1)^2-2*(s_i +1)*s_j+{s_j}^2)     =max(fj+(si+1)22(si+1)sj+sj2)
     = max ⁡ ( f j − 2 ∗ ( s i + 1 ) ∗ s j + s j 2 ) + ( s i + 1 ) 2 \ \ \ \ =\max(f_j-2*(s_i +1)*s_j+{s_j}^2)+(s_i+1)^2     =max(fj2(si+1)sj+sj2)+(si+1)2
     f i − ( s i + 1 ) 2 + 2 ( s i + 1 ) s j = f j + s j 2 \ \ \ \ f_i -(s_i+1)^2 + 2(s_i +1)s_j=f_j+{s_j}^2     fi(si+1)2+2(si+1)sj=fj+sj2

y j = f j + s j 2 y_j=f_j+{s_j}^2 yj=fj+sj2
x j = s j x_j=s_j xj=sj
k i = 2 ( s i + 1 ) k_i=2(s_i+1) ki=2(si+1)
c i = f i − ( s i + 1 ) 2 c_i=f_i -(s_i+1)^2 ci=fi(si+1)2
y j = k i x j + c i y_j=k_ix_j+c_i yj=kixj+ci
c i c_i ci最大,显然是个斜率优化
直接rush就好了

至于分成 k k k
通过打表我不会的证明可以发现 f ( n , k ) f(n,k) f(n,k)是凸的
然后直接wqs二分即可
写之前要想一下mid是加上还是减去
这一题想想可以发现,加得越多,分得段数越少
所以如果分得段数 > k >k >k就要 l = m i d l=mid l=mid

code:

#include<bits/stdc++.h>
#define ll long long
#define N 400005
using namespace std;
ll f[N], s[N], cnt[N], q[N];
int n, m;
double y(int i) {return f[i] + s[i] * s[i]; }
double x(int i) {return s[i]; }
double k(int i) {return 2.0 * (s[i] + 1); }
double slope(int i, int j) {
	return (y(i) - y(j)) / (x(i) - x(j));
}
int check(ll mid) {//斜率优化计算DP值  g(mid) 
	int l = 0, r = 0;
	for(int i = 1; i <= n; i ++) {
		while(l < r && slope(q[l], q[l + 1]) < k(i)) l ++;
		int j = q[l];
		f[i] = f[j] + (s[i] - s[j] + 1) * (s[i] - s[j] + 1) + mid;//加mid 
		cnt[i] = cnt[j] + 1;
		while(l < r && slope(q[r - 1], q[r]) > slope(q[r], i)) r --;
		q[++ r] = i;
	}
	return cnt[n] > m;//注意这里,如果分得段数>m说明加得不够 
}
int main() {
	scanf("%d%d", &n, &m);
	for(int i = 1; i <= n; i ++) scanf("%lld", &s[i]), s[i] += s[i - 1];
	ll l = 0, r = (1ll << 61);
	while(l + 1 < r) {//二分斜率 
		ll mid = (l + r) >> 1;
		if(check(mid)) l = mid;
		else r = mid;
	}	
	check(r);//最后算一下段数 
	printf("%lld", f[n] - r * m);//减去多加的 
	return 0;
} 

再来一题
P5308 [COCI2019] Quiz

先不考虑 k k k
考虑反过来DP
f i 表 示 还 剩 下 i 个 人 的 最 大 奖 金 f_i表示还剩下i个人的最大奖金 fii
显然
f i = max ⁡ ( f j + i − j i ) f_i=\max(f_j+\frac{i-j}{i}) fi=max(fj+iij)
f i − 1 i ( i − j ) = f j f_i-\frac{1}{i}(i-j)=f_j fii1(ij)=fj
f i − 1 + 1 i ( j ) = f j f_i-1+\frac{1}{i}(j)=f_j fi1+i1(j)=fj
x j = j x_j=j xj=j
y j = f j y_j=f_j yj=fj
k i = 1 i k_i=\frac{1}{i} ki=i1
c i = f i − 1 c_i=f_i-1 ci=fi1
斜率优化即可
加上 k k k的限制就是一个wqs二分
对于这一题,减得越多,段数就会越少,二分的时候要注意一下
code:

#include<bits/stdc++.h>
#define N 200005
using namespace std;
double f[N];
int cnt[N], q[N], n, k;
double y(int i) {return f[i]; }
double x(int i) {return i; }
double slope(int i, int j) {
	return (y(i) - y(j)) / (x(i) - x(j));
}
int check(double mid) {
	int l = 1, r = 1; q[l] = 0;
	for(int i = 1; i <= n; i ++) {
		while(l < r && slope(q[l], q[l + 1]) > 1.0 / i) l ++;
		int j = q[l];
		f[i] = f[j] + (i - j) * 1.0 / i - mid;
		cnt[i] = cnt[j] + 1;
		while(l < r && slope(q[r - 1], q[r]) < slope(q[r], i)) r --;
		q[++ r] = i; 
	}
	return cnt[n] >= k;
}
int main() {
	scanf("%d%d", &n, &k);
	double l = 0, r = 1e6;
	for(int i = 1; i <= 200; i ++) {
		double mid = (l + r) * 1.0 / 2.0;
		if(check(mid)) l = mid;
		else r = mid;
	}
	check(l);
	printf("%.9f", f[n] + l * k);
	return 0;
} 
posted @ 2020-09-16 19:15  lahlah  阅读(83)  评论(0编辑  收藏  举报