斜率优化
upd: 2021/07/20
学到一个简单点的方法。
还是以Print Article 为例。
给出 \(n\), \(m\) 和数列,将 \(n\) 个数分成任意段,每一段贡献是 \(sum^2+m\),求最小总贡献
\(f_i = min(f_j + m + {(s_i-s_j)}^2)\)
旧:devinwang的斜率优化入门题单简要题解
大张旗鼓开了个博客结果博客主要内容是“常规操作,过”TAT
板子
int h = 0, t = -1; q[++t] = 0;
for(int i = 1; i <= n; i++){
while(h < t && slope(q[h + 1], q[h]) <= i) h++;
f[i] = f[q[h]] + a[i] + i * (s[i] - s[q[h]]);
while(h < t && slope(i, q[t]) <= slope(q[t], q[t - 1])) t--;
q[++t] = i;
}
Print Article
题意
给出 \(n\), \(m\) 和数列
将 \(n\) 个数分成任意段,每一段贡献是 \(sum^2+m\),求最小总贡献
题解
当 \(k < j < i\) 时,如果 \(j\) 优于 \(k\),那么
不妨 \(y(i) = f_i + s^2_i, x(i) = 2 s_i\)
则 \(\frac {y(j) - y(k)} {x(j) - x(k)} \le s_i\)
即 斜率 \(k \le s_i\)
令 \(g(j, k) = \frac {(f_j+s^2_j) - (f_k + s^2_k)} {2(s_j - s_k)}\),
当 \(g(i, j) \le g(j, k)\) 时,
若 \(g(i, j) \le s_i\) 则 \(i\) 优于 \(j\) ,\(j\) 没有存在必要
若 \(g(i, j) > s_i\) 则 \(j\) 优于 \(i\) ,同样,\(g(j,k) > s_i\) , \(k\) 优于 \(j\) , \(j\) 没有存在必要
因此,剔除所有 \(g(i, j) \le g(j, k)\),维护一个类似凸包的东西。
代码
const int N = 500010;
int n, m, f[N], s[N], q[N];
int gety(int j, int k){
return f[j] + s[j] * s[j] - f[k] - s[k] * s[k];
}
int getx(int j, int k){
return 2 * (s[j] - s[k]);
}
int main(){
while(scanf("%d%d", &n, &m) == 2){
for(int i = 1; i <= n; i++)
scanf("%d", &s[i]), s[i] += s[i - 1];
int h = 0, t = -1; q[++t] = 0;
for(int i = 1; i <= n; i++){
while(h < t && gety(q[h + 1], q[h]) <= s[i] * getx(q[h + 1], q[h])) h++;
f[i] = f[q[h]] + (s[i] - s[q[h]]) * (s[i] - s[q[h]]) + m;
while(h < t && gety(i, q[t]) * getx(q[t], q[t - 1]) <= getx(i, q[t]) * gety(q[t], q[t - 1])) t--;
q[++t] = i;
}
printf("%d\n", f[n]);
}
return 0;
}
用slope的写法:
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
#define mkp make_pair
#define pb push_back
#define PII pair<int, int>
#define PLL pair<ll, ll>
#define ls(x) ((x) << 1)
#define rs(x) ((x) << 1 | 1)
#define fi first
#define se second
const int N = 500010;
int n, m, q[N];
ll f[N], s[N];
ll getx(int x, int y) {
return s[y] - s[x];
}
ll gety(int x, int y) {
return f[y] + s[y] * s[y] - f[x] - s[x] * s[x];
}
double slope(int x, int y) {
if(getx(x, y) == 0) {
return (gety(x, y) >= 0) ? 1 : -1;
}
return 1.0 * gety(x, y) / getx(x, y);
}
int main(){
while(scanf("%d%d", &n, &m) == 2) {
for(int i = 1; i <= n; i++)
scanf("%lld", &s[i]), s[i] += s[i - 1];
int h = 1, t = 0; q[++t] = 0;
for(int i = 1; i <= n; i++) {
while(h < t && slope(q[h], q[h + 1]) <= 2 * s[i]) h++;
f[i] = f[q[h]] + (s[i] - s[q[h]]) * (s[i] - s[q[h]]) + m;
while(h < t && slope(q[t - 1], q[t]) >= slope(q[t], i)) t--;
q[++t] = i;
}
printf("%lld\n", f[n]);
}
return 0;
}
/*
f[i]-s[i]^2-m=f[j]+s[j]^2-2*s[i]*s[j]
k=2*s[i],x=s[j]
y=f[j]+s[j]^2
b=f[i]-s[i]^2-m
b=y-kx
y=kx+b
k从1~n单调不降
*/
玩具装箱
同Print Article,
const int N = 500010;
int n, q[N];
ll f[N], s[N], L;
ll gety(int j, int k){
return f[j] + (j + s[j]) * (j + s[j]) - f[k] - (k + s[k]) * (k + s[k]);
}
ll getx(int j, int k){
return 2ll * (s[j] + j - s[k] - k);
}
int main(){
scanf("%d%lld", &n, &L);
for(int i = 1; i <= n; i++)
scanf("%lld", &s[i]), s[i] += s[i - 1];
int h = 0, t = -1; q[++t] = 0;
for(int i = 1; i <= n; i++){
while(h < t && gety(q[h + 1], q[h]) <= (i + s[i] - L - 1) * getx(q[h + 1], q[h])) h++;
f[i] = f[q[h]] + (i - q[h] + s[i] - s[q[h]] - L - 1) * (i - q[h] + s[i] - s[q[h]] - L - 1);
while(h < t && gety(i, q[t]) * getx(q[t], q[t - 1]) <= getx(i, q[t]) * gety(q[t], q[t - 1])) t--;
q[++t] = i;
}
printf("%lld\n", f[n]);
return 0;
}
用slope的写法
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
#define mkp make_pair
#define pb push_back
#define PII pair<int, int>
#define PLL pair<ll, ll>
#define ls(x) ((x) << 1)
#define rs(x) ((x) << 1 | 1)
#define fi first
#define se second
const int N = 500010;
int n, m, q[N];
ll f[N], s[N];
ll getx(int x, int y) {
return y + s[y] - (x + s[x]);
}
ll gety(int x, int y) {
return f[y] + (y + s[y] + m) * (y + s[y] + m) - f[x] - (x + s[x] + m) * (x + s[x] + m);
}
double slope(int x, int y) {
// if(getx(x, y) == 0) {
// return (gety(x, y) >= 0) ? 1 : -1;
// }
return 1.0 * gety(x, y) / getx(x, y);
}
int main(){
scanf("%d%d", &n, &m);
for(int i = 1; i <= n; i++)
scanf("%lld", &s[i]), s[i] += s[i - 1];
int h = 1, t = 0; q[++t] = 0;
for(int i = 1; i <= n; i++) {
while(h < t && slope(q[h], q[h + 1]) <= 2 * (i + s[i])) h++;
f[i] = f[q[h]] + (i - q[h] - 1 + s[i] - s[q[h]] - m) * (i - q[h] - 1 + s[i] - s[q[h]] - m);
while(h < t && slope(q[t - 1], q[t]) >= slope(q[t], i)) t--;
q[++t] = i;
}
printf("%lld\n", f[n]);
return 0;
}=
锯木厂选址
题意
从山顶上到山底下沿着一条直线种植了 \(n\) 棵老树。当地的政府决定把他们砍下来。为了不浪费任何一棵木材,树被砍倒后要运送到锯木厂。
木材只能朝山下运。山脚下有一个锯木厂。另外两个锯木厂将新修建在山路上。你必须决定在哪里修建这两个锯木厂,使得运输的费用总和最小。假定运输每公斤木材每米需要一分钱。
你的任务是编写一个程序,读入树的个数和他们的重量与位置,计算最小运输费用。
懒得翻译(((
题解
考虑总的贡献,山顶到山脚从小到大读入,假设 \(j < i\),
令 \(d_j\) 为到山脚的距离。
选 \(i\) 的时候,最小总贡献
令 \(s_i\) 为 \(w_i\) 的前缀和
当 \(k < j\) 且 \(j\) 优于 \(k\)
当 \(slope(i, j) > slope(j, k)\) 时 \(j\) 无用。
维护一个下凸包。
代码
const int N = 2e5 + 10;
int n, q[N];
ll s[N], w[N], d[N], sum;
ll gety(int j, int k){
return d[j] * s[j] - d[k] * s[k];
}
ll getx(int j, int k){
return s[j] - s[k];
}
double slope(int j, int k){
return 1.0 * gety(j, k) / getx(j, k);
}
ll calc(int j, int i){
return sum - d[j] * s[j] - d[i] * (s[i] - s[j]);
}
int main(){
scanf("%d", &n);
for(int i = 1; i <= n; i++)
scanf("%lld%lld", &w[i], &d[i]), s[i] = s[i - 1] + w[i];
for(int i = n; i >= 1; i--)
d[i] += d[i + 1], sum += d[i] * w[i];
int h = 0, t = -1; q[++t] = 0;
ll ans = 1e18;
for(int i = 1; i <= n; i++){
while(h < t && slope(q[h + 1], q[h]) > d[i]) h++;
ans = min(ans, calc(q[h], i));
while(h < t && slope(i, q[t]) > slope(q[t], q[t - 1])) t--;
q[++t] = i;
}
printf("%lld\n", ans);
return 0;
}
仓库建设
和前一题差不多吧(((
const int N = 1e6 + 10;
int n, q[N];
ll s[N], w[N], d[N], g[N], f[N], c[N];
ll gety(int j, int k){
return f[j] + g[j] - f[k] - g[k];
}
ll getx(int j, int k){
return s[j] - s[k];
}
double slope(int j, int k){
return 1.0 * gety(j, k) / getx(j, k);
}
int main(){
scanf("%d", &n);
for(int i = 1; i <= n; i++)
scanf("%lld%lld%lld", &d[i], &w[i], &c[i]),
s[i] = s[i - 1] + w[i], g[i] = g[i - 1] + d[i] * w[i];
int h = 0, t = -1; q[++t] = 0;
ll ans = 1e18;
for(int i = 1; i <= n; i++){
while(h < t && slope(q[h + 1], q[h]) <= d[i]) h++;
f[i] = f[q[h]] + c[i] + (s[i] - s[q[h]]) * d[i] - g[i] + g[q[h]];
while(h < t && slope(i, q[t]) <= slope(q[t], q[t - 1])) t--;
q[++t] = i;
}
printf("%lld\n", f[n]);
return 0;
}
土地购买
题解
按照从小到大排序,\(a\) 第一关键字,\(b\)第二关键字,去除包含。
现在 \(a\) 始终递增,\(b\) 始终递减。
显然取连续的一段才是最优的。
然后常规操作
代码
const int N = 1e6 + 10;
int n, q[N];
ll f[N];
struct node{
ll a, b;
bool operator < (const node x) const {
return (a == x.a) ? b < x.b : a < x.a;
}
}po[N];
ll gety(int j, int k){
return f[j] - f[k];
}
ll getx(int j, int k){
return po[k + 1].b - po[j + 1].b;
}
double slope(int j, int k){
return 1.0 * gety(j, k) / getx(j, k);
}
int main(){
scanf("%d", &n);
for(int i = 1; i <= n; i++)
scanf("%lld%lld", &po[i].a, &po[i].b);
sort(po + 1, po + n + 1);
int cnt = 0;
for(int i = 1; i <= n; i++){
while(cnt && po[cnt].b <= po[i].b) cnt--;
po[++cnt] = po[i];
}
int h = 0, t = -1; q[++t] = 0;
ll ans = 1e18;
for(int i = 1; i <= cnt; i++){
while(h < t && slope(q[h + 1], q[h]) <= po[i].a) h++;
f[i] = f[q[h]] + po[i].a * po[q[h] + 1].b;
while(h < t && slope(i, q[t]) <= slope(q[t], q[t - 1])) t--;
q[++t] = i;
}
printf("%lld\n", f[cnt]);
return 0;
}
特别行动队
啊这就常规操作
然后发现常数项最好拖到主函数那里面,否则会炸精度(((
const int N = 1e6 + 10;
int n, q[N];
ll f[N], s[N], a, b, c;
ll gety(int j, int k){
return f[j] + a * s[j] * s[j] - b * s[j] - f[k] - a * s[k] * s[k] + b * s[k];
}
ll getx(int j, int k){
return s[j] - s[k];
}
double slope(int j, int k){
return 1.0 * gety(j, k) / getx(j, k);
}
int main(){
scanf("%d", &n);
scanf("%lld%lld%lld", &a, &b, &c);
for(int i = 1; i <= n; i++)
scanf("%lld", &s[i]), s[i] += s[i - 1];
int h = 0, t = -1; q[++t] = 0;
for(int i = 1; i <= n; i++){
while(h < t && slope(q[h + 1], q[h]) > 2 * a * s[i]) h++;
ll x = s[i] - s[q[h]]; f[i] = f[q[h]] + a * x * x + b * x + c;
while(h < t && slope(i, q[t]) > slope(q[t], q[t - 1])) t--;
q[++t] = i;
}
printf("%lld\n", f[n]);
return 0;
}
序列分割
题解
发现切的顺序对答案无影响
然后常规操作
注意到 \(s_j = s_k\) 时,要特判,否则会除以0 RE
代码
const int N = 1e5 + 10;
int n, k, q[N];
ll f[N], s[N], g[N], pre[210][N];
ll gety(int j, int k){
return g[j] - s[j] * s[j] - (g[k] - s[k] * s[k]);
}
ll getx(int j, int k){
return s[k] - s[j];
}
double slope(int j, int k){
if(s[j] == s[k]) return -1e18;
return 1.0 * gety(j, k) / getx(j, k);
}
int main(){
scanf("%d%d", &n, &k);
for(int i = 1; i <= n; i++)
scanf("%lld", &s[i]), s[i] += s[i - 1];
for(int j = 1; j <= k; j++){
int h = 0, t = -1; q[++t] = 0;
for(int i = 1; i <= n; i++) g[i] = f[i], f[i] = 0;
for(int i = 1; i <= n; i++){
while(h < t && slope(q[h + 1], q[h]) <= s[i]) h++;
f[i] = g[q[h]] + s[q[h]] * (s[i] - s[q[h]]); pre[j][i] = q[h];
// cout<<i<<"*"<<q[h]<<endl;
while(h < t && slope(i, q[t]) <= slope(q[t], q[t - 1])) t--;
q[++t] = i;
}
}
printf("%lld\n", f[n]);
int j = k, i = n;
while(j) {
printf("%d ", pre[j][i]);
i = pre[j][i], j--;
}
return 0;
}
「SDOI2016」征途
题解
一通推柿子答案为
一通常规操作
注意代码中注释的那行,0是显然错误的
代码
const int N = 3010 + 10;
int n, q[N];
ll f[N], g[N], s[N] ,m;
ll gety(int j, int k){
return g[j] + s[j] * s[j] - (g[k] + s[k] * s[k]);
}
ll getx(int j, int k){
return s[j] - s[k];
}
double slope(int j, int k){
return 1.0 * gety(j, k) / getx(j, k);
}
int main(){
scanf("%d%lld", &n, &m);
for(int i = 1; i <= n; i++)
scanf("%lld", &s[i]), s[i] += s[i - 1], f[i] = s[i] * s[i];
for(int j = 1; j < m; j++){
int h = 0, t = -1; q[++t] = j;//q[++t] = 0;
for(int i = 1; i <= n; i++) g[i] = f[i], f[i] = 0;
for(int i = j + 1; i <= n; i++){
while(h < t && slope(q[h + 1], q[h]) <= 2 * s[i]) h++;
ll x = s[i] - s[q[h]]; f[i] = g[q[h]] + x * x;
while(h < t && slope(i, q[t]) <= slope(q[t], q[t - 1])) t--;
q[++t] = i;
}
}
printf("%lld\n", - s[n] * s[n] + m * f[n]);
return 0;
}
小P的牧场
常规操作。。。
const int N = 1e6 + 10;
int n, q[N];
ll a[N], b[N], f[N], s[N], sum;
ll gety(int j, int k){
return f[j] - f[k];
}
ll getx(int j, int k){
return s[j] - s[k];
}
double slope(int j, int k){
return 1.0 * gety(j, k) / getx(j, k);
}
int main(){
scanf("%d%", &n);
for(int i = 1; i <= n; i++)
scanf("%lld", &a[i]);
for(int i = 1; i <= n; i++)
scanf("%lld", &b[i]),
s[i] = s[i - 1] + b[i], sum += i * b[i];
int h = 0, t = -1; q[++t] = 0;
for(int i = 1; i <= n; i++){
while(h < t && slope(q[h + 1], q[h]) <= i) h++;
f[i] = f[q[h]] + a[i] + i * (s[i] - s[q[h]]);
while(h < t && slope(i, q[t]) <= slope(q[t], q[t - 1])) t--;
q[++t] = i;
}
printf("%lld\n", f[n] - sum);
return 0;
}
/*
4
2 4 2 4
3 1 4 2
*/