浅谈wqs二分
论文
算法讲解
例题
首先不考虑限制是一个很简单的斜率优化板子
加上
k
k
k之后再用斜率优化就是
O
(
n
k
)
O(nk)
O(nk)的
如果
k
,
n
k,n
k,n同阶显然做不了
考虑怎么优化这个问题
这个时候就要用wqs二分了
wqs二分
设
f
(
k
)
f(k)
f(k)表示分成
k
k
k段的答案
通过打表严格证明可以发现
(
x
,
f
(
x
)
)
(x,f(x))
(x,f(x))是个凸壳 (斜率单调)
先假设这是个上凸壳
二分一个
m
i
d
mid
mid,表示直线的斜率
然后用这条直线去切这个凸壳,假设交点为
(
x
,
f
(
x
)
)
(x,f(x))
(x,f(x))
显然这个点可以表示为
(
x
,
m
i
d
∗
x
+
g
(
m
i
d
)
)
(x,mid*x+g(mid))
(x,mid∗x+g(mid)),
g
(
m
i
d
)
为
截
距
g(mid)为截距
g(mid)为截距(带入直线)
那么
f
(
x
)
=
m
i
d
∗
x
+
g
(
m
i
d
)
f(x) = mid * x + g(mid)
f(x)=mid∗x+g(mid)
f
(
x
)
=
m
i
d
∗
x
+
g
(
m
i
d
)
f(x) = mid * x + g(mid)
f(x)=mid∗x+g(mid)
g
(
m
i
d
)
=
f
(
x
)
−
m
i
d
∗
x
g(mid) = f(x) - mid * x
g(mid)=f(x)−mid∗x
然后这个
g
(
m
i
d
)
g(mid)
g(mid)可以直接用斜率优化跑出来,因为要满足的也是g(x)最大/最小。
最后在通过
g
(
m
i
d
)
+
x
∗
m
i
d
g(mid)+x*mid
g(mid)+x∗mid算出
f
(
k
)
f(k)
f(k)
显然这个可以二分
m
i
d
mid
mid,看
g
(
m
i
d
)
g(mid)
g(mid)分成了几段,然后继续二分
m
i
d
mid
mid
计算
g
g
g的时候只需要转移的时候每次多减一个
m
i
d
mid
mid就行了
所以时间复杂度为
O
(
n
l
o
g
N
)
O(nlogN)
O(nlogN)
给个例题讲讲吧
luogu P4983 忘情
把这一坨柿子化简一下就是
(
1
+
∑
x
i
)
2
(1+\sum x_i)^2
(1+∑xi)2
f
i
=
max
(
f
j
+
(
s
i
−
s
j
+
1
)
2
)
f_i=\max(f_j+(s_i-s_j+1)^2)
fi=max(fj+(si−sj+1)2)
=
max
(
f
j
+
(
s
i
+
1
)
2
−
2
∗
(
s
i
+
1
)
∗
s
j
+
s
j
2
)
\ \ \ \ =\max(f_j+(s_i+1)^2-2*(s_i +1)*s_j+{s_j}^2)
=max(fj+(si+1)2−2∗(si+1)∗sj+sj2)
=
max
(
f
j
−
2
∗
(
s
i
+
1
)
∗
s
j
+
s
j
2
)
+
(
s
i
+
1
)
2
\ \ \ \ =\max(f_j-2*(s_i +1)*s_j+{s_j}^2)+(s_i+1)^2
=max(fj−2∗(si+1)∗sj+sj2)+(si+1)2
f
i
−
(
s
i
+
1
)
2
+
2
(
s
i
+
1
)
s
j
=
f
j
+
s
j
2
\ \ \ \ f_i -(s_i+1)^2 + 2(s_i +1)s_j=f_j+{s_j}^2
fi−(si+1)2+2(si+1)sj=fj+sj2
设
y
j
=
f
j
+
s
j
2
y_j=f_j+{s_j}^2
yj=fj+sj2
x
j
=
s
j
x_j=s_j
xj=sj
k
i
=
2
(
s
i
+
1
)
k_i=2(s_i+1)
ki=2(si+1)
c
i
=
f
i
−
(
s
i
+
1
)
2
c_i=f_i -(s_i+1)^2
ci=fi−(si+1)2
y
j
=
k
i
x
j
+
c
i
y_j=k_ix_j+c_i
yj=kixj+ci
要
c
i
c_i
ci最大,显然是个斜率优化
直接rush就好了
至于分成
k
k
k段
通过打表我不会的证明可以发现
f
(
n
,
k
)
f(n,k)
f(n,k)是凸的
然后直接wqs二分即可
写之前要想一下mid是加上还是减去
这一题想想可以发现,加得越多,分得段数越少
所以如果分得段数
>
k
>k
>k就要
l
=
m
i
d
l=mid
l=mid
code:
#include<bits/stdc++.h>
#define ll long long
#define N 400005
using namespace std;
ll f[N], s[N], cnt[N], q[N];
int n, m;
double y(int i) {return f[i] + s[i] * s[i]; }
double x(int i) {return s[i]; }
double k(int i) {return 2.0 * (s[i] + 1); }
double slope(int i, int j) {
return (y(i) - y(j)) / (x(i) - x(j));
}
int check(ll mid) {//斜率优化计算DP值 g(mid)
int l = 0, r = 0;
for(int i = 1; i <= n; i ++) {
while(l < r && slope(q[l], q[l + 1]) < k(i)) l ++;
int j = q[l];
f[i] = f[j] + (s[i] - s[j] + 1) * (s[i] - s[j] + 1) + mid;//加mid
cnt[i] = cnt[j] + 1;
while(l < r && slope(q[r - 1], q[r]) > slope(q[r], i)) r --;
q[++ r] = i;
}
return cnt[n] > m;//注意这里,如果分得段数>m说明加得不够
}
int main() {
scanf("%d%d", &n, &m);
for(int i = 1; i <= n; i ++) scanf("%lld", &s[i]), s[i] += s[i - 1];
ll l = 0, r = (1ll << 61);
while(l + 1 < r) {//二分斜率
ll mid = (l + r) >> 1;
if(check(mid)) l = mid;
else r = mid;
}
check(r);//最后算一下段数
printf("%lld", f[n] - r * m);//减去多加的
return 0;
}
先不考虑
k
k
k
考虑反过来DP
f
i
表
示
还
剩
下
i
个
人
的
最
大
奖
金
f_i表示还剩下i个人的最大奖金
fi表示还剩下i个人的最大奖金
显然
f
i
=
max
(
f
j
+
i
−
j
i
)
f_i=\max(f_j+\frac{i-j}{i})
fi=max(fj+ii−j)
f
i
−
1
i
(
i
−
j
)
=
f
j
f_i-\frac{1}{i}(i-j)=f_j
fi−i1(i−j)=fj
f
i
−
1
+
1
i
(
j
)
=
f
j
f_i-1+\frac{1}{i}(j)=f_j
fi−1+i1(j)=fj
x
j
=
j
x_j=j
xj=j
y
j
=
f
j
y_j=f_j
yj=fj
k
i
=
1
i
k_i=\frac{1}{i}
ki=i1
c
i
=
f
i
−
1
c_i=f_i-1
ci=fi−1
斜率优化即可
加上
k
k
k的限制就是一个wqs二分
对于这一题,减得越多,段数就会越少,二分的时候要注意一下
code:
#include<bits/stdc++.h>
#define N 200005
using namespace std;
double f[N];
int cnt[N], q[N], n, k;
double y(int i) {return f[i]; }
double x(int i) {return i; }
double slope(int i, int j) {
return (y(i) - y(j)) / (x(i) - x(j));
}
int check(double mid) {
int l = 1, r = 1; q[l] = 0;
for(int i = 1; i <= n; i ++) {
while(l < r && slope(q[l], q[l + 1]) > 1.0 / i) l ++;
int j = q[l];
f[i] = f[j] + (i - j) * 1.0 / i - mid;
cnt[i] = cnt[j] + 1;
while(l < r && slope(q[r - 1], q[r]) < slope(q[r], i)) r --;
q[++ r] = i;
}
return cnt[n] >= k;
}
int main() {
scanf("%d%d", &n, &k);
double l = 0, r = 1e6;
for(int i = 1; i <= 200; i ++) {
double mid = (l + r) * 1.0 / 2.0;
if(check(mid)) l = mid;
else r = mid;
}
check(l);
printf("%.9f", f[n] + l * k);
return 0;
}