算法笔记--斜率优化dp
斜率优化是单调队列优化的推广
用单调队列维护递增的斜率
参考:https://www.cnblogs.com/ka200812/archive/2012/08/03/2621345.html
以例1举例说明:
转移方程为:dp[i] = min(dp[j] + (sum[i] - sum[j])^2 + C), 其中0 <= j < i, sum为前缀和, C为常数(代码中的m)
假设k < j < i, 如果从j转移过来比从k转移过来更优
那么 dp[j] + (sum[i] - sum[j])^2 + C < dp[k] + (sum[i] - sum[k])^2 + C
dp[j] - dp[k] < (sum[i] - sum[k])^2 - (sum[i] - sum[j])^2
dp[j] - dp[k] < -2*sum[i]*sum[k] + sum[k]*sum[k] + 2*sum[i]*sum[j] - sum[j]*sum[j]
dp[j] - dp[k] + sum[j]*sum[j] - sum[k]*sum[k] < 2*sum[i]*(sum[j] - sum[k])
(dp[j] - dp[k] + sum[j]*sum[j] - sum[k]*sum[k]) / (sum[j] - sum[k]) < 2*sum[i]  (两边同除以sum[j] - sum[k]: 当a都为正时sum严格递增, 除数为正, 不等号方向不变; 代码中用交叉相乘代替除法)
我们观察不等式左边, 它是个斜率的形式, 自变量x为sum, 函数f(x)为dp + sum*sum
我们记这个斜率为g[j, k] = (dp[j] - dp[k] + sum[j]*sum[j] - sum[k]*sum[k]) / (sum[j] - sum[k])
说明1.如果g[j, k] < 2*sum[i] 表示对于dp[i], 从j转移过来比k更优, 反之k更优
说明2.下面考虑怎么从解集中去掉多余的元素: 可以证明, 某些元素无论对之后哪个dp[i]都不会是最优的, 可以把这些多余的元素去掉
假设k < j < i
结论:如果g[i, j] < g[j, k], 那么j可以去掉
证明:对于之后任意一个待求的dp[t](t > i), 如果g[i, j] < 2*sum[t], 那么i比j更优, 结论成立;
如果g[i, j] >= 2*sum[t], 那么g[j, k] > g[i, j] >= 2*sum[t], 那么k比j更优, 结论成立.
证毕.
所以如果把所有g[i, j] < g[j, k]的情况中(后面斜率比前面斜率小的情况)的j都去掉, 那么我们就得到相邻两个元素的斜率递增的状况
如下图(原图缺失): 保留下来的点构成一个下凸壳, 相邻两点连线的斜率依次递增
下面来说明怎么维护这个解集:
用双端队列维护这个解集, 每次从后面加入元素时, 按照说明2的方式去掉多余元素, 使得相邻元素之间构成的斜率保持单调递增
每次从队首找答案: 由于相邻元素的斜率单调递增, 设队列中最后一条斜率小于2*sum[i]的边为(j', j), 那么它的右端点j就是最优解(若第一条边的斜率就不小于2*sum[i], 则队首即最优). 因为j之前的每条边斜率都小于2*sum[i],
表示每条边的后一个点比前一个点更优; j之后的每条边斜率都不小于2*sum[i], 表示前一个点不比后一个点差, 所以j是极值点
又因为sum[i]也具有单调性, 所以下一个极值点的位置肯定大于等于当前极值点, 所以当前极值点之前的都可以从双端队列中移出
ps:所有说明中, k < j < i
例题1:HDU - 3507
思路:维护递增斜率g[i, j] = (dp[i] - dp[j] + sum[i]*sum[i] - sum[j]*sum[j]) / (sum[i] - sum[j])
代码:
#pragma GCC optimize(2)
#pragma GCC optimize(3)
#pragma GCC optimize(4)
#include <cstdio>
#include <deque>
using namespace std;
typedef long long LL;

// HDU 3507: dp[i] = min_{0<=j<i} dp[j] + (sum[i] - sum[j])^2 + m.
// Candidate j is the point (X, Y) = (sum[j], dp[j] + sum[j]^2); the query slope
// for dp[i] is 2*sum[i]. The deque holds a lower hull with increasing slopes.
const int N = 5e5 + 10;
int a[N], n, m;
LL sum[N], dp[N];
deque<int> q;

// Y(j) - Y(k) for candidates k < j.
static LL dy(int k, int j) { return dp[j] - dp[k] + sum[j] * sum[j] - sum[k] * sum[k]; }
// X(j) - X(k) for candidates k < j.
static LL dx(int k, int j) { return sum[j] - sum[k]; }

// slope(k, j) <= C (cross-multiplied): for query slope C, j is no worse than k.
bool g(int k, int j, LL C) { return dy(k, j) <= C * dx(k, j); }
// slope(j, i) <= slope(k, j): j is not on the lower hull and can be dropped.
bool gg(int k, int j, int i) { return dy(j, i) * dx(k, j) <= dy(k, j) * dx(j, i); }

int main() {
    while (~scanf("%d %d", &n, &m)) {
        for (int i = 1; i <= n; ++i) {
            scanf("%d", &a[i]);
            sum[i] = sum[i - 1] + a[i];
        }
        q.clear();
        q.push_back(0);
        for (int i = 1; i <= n; ++i) {
            // Query slopes only grow, so a beaten front candidate stays beaten.
            while (q.size() >= 2 && g(q[0], q[1], 2 * sum[i])) q.pop_front();
            const int j = q.front();
            const LL len = sum[i] - sum[j];
            dp[i] = dp[j] + len * len + m;
            // Restore increasing slopes before appending the new point i.
            while (q.size() >= 2 && gg(q[q.size() - 2], q.back(), i)) q.pop_back();
            q.push_back(i);
        }
        printf("%lld\n", dp[n]);
    }
    return 0;
}
例题2:HDU - 1300
思路:维护递增斜率g[i, j] = (dp[i] - dp[j]) / (sum[i] - sum[j])
代码:
#pragma GCC optimize(2)
#pragma GCC optimize(3)
#pragma GCC optimize(4)
#include <cstdio>
#include <deque>
using namespace std;
typedef long long LL;

// HDU 1300: dp[i] = min_{0<=j<i} dp[j] + (sum[i] - sum[j] + 10) * p[i].
// Candidate j is the point (sum[j], dp[j]); the query slope for dp[i] is p[i].
const int N = 100 + 10;
int a[N], p[N], n, T;
LL sum[N], dp[N];
deque<int> q;

// slope(k, j) <= C (cross-multiplied): for query slope C, j is no worse than k.
bool g(int k, int j, LL C) {
    const LL rise = dp[j] - dp[k];
    const LL run = sum[j] - sum[k];
    return rise <= C * run;
}

// slope(j, i) <= slope(k, j): j is not on the lower hull and can be dropped.
bool gg(int k, int j, int i) {
    const LL riseKJ = dp[j] - dp[k], runKJ = sum[j] - sum[k];
    const LL riseJI = dp[i] - dp[j], runJI = sum[i] - sum[j];
    return riseJI * runKJ <= riseKJ * runJI;
}

int main() {
    scanf("%d", &T);
    while (T--) {
        scanf("%d", &n);
        for (int i = 1; i <= n; ++i) {
            scanf("%d %d", &a[i], &p[i]);
            sum[i] = sum[i - 1] + a[i];
        }
        // Suffix minimum: afterwards p[1..n] is non-decreasing.
        for (int i = n - 1; i >= 1; --i)
            if (p[i + 1] < p[i]) p[i] = p[i + 1];
        q.clear();
        q.push_back(0);
        for (int i = 1; i <= n; ++i) {
            while (q.size() >= 2 && g(q[0], q[1], p[i])) q.pop_front();
            const int j = q.front();
            dp[i] = dp[j] + (sum[i] - sum[j] + 10) * p[i];
            while (q.size() >= 2 && gg(q[q.size() - 2], q.back(), i)) q.pop_back();
            q.push_back(i);
        }
        printf("%lld\n", dp[n]);
    }
    return 0;
}
例题3:HDU - 2993
思路:论文题,维护递增的斜率,居然卡读入,没意思
代码:
#pragma GCC optimize(2)
#pragma GCC optimize(3)
#pragma GCC optimize(4)
#include <cstdio>
#include <algorithm>
using namespace std;

// HDU 2993: maximum average of a contiguous run of length >= k.
// Points are (t, sum[t]); the answer for right end i is the steepest slope
// from some x <= i-k to i. q[head..tail) holds a lower convex hull of the
// admissible left ends, and the head pointer only moves forward.
const int N = 1e5 + 10;
int n, k, a[N], q[N], head, tail;
double sum[N];

// The input is large, so the whole of stdin is read in one fread call.
const int BUF = 25000000;
char Buf[BUF];
char *buf = Buf;     // read cursor
char *bufEnd = Buf;  // one past the last byte actually read

// Parses the next unsigned decimal integer into x.
// Returns false, leaving x untouched, once only non-digit bytes remain.
// Every access is bounds-checked against bufEnd, so trailing whitespace of any
// length (e.g. "\r\n" or blank lines) cannot make the scan run off the buffer.
inline bool nextInt(int &x) {
    while (buf < bufEnd && (*buf < '0' || *buf > '9')) ++buf;
    if (buf == bufEnd) return false;
    x = 0;
    while (buf < bufEnd && *buf >= '0' && *buf <= '9') x = x * 10 + (*buf++ - '0');
    return true;
}

int main() {
    bufEnd = Buf + fread(Buf, 1, BUF, stdin);
    while (nextInt(n) && nextInt(k)) {
        for (int i = 1; i <= n; ++i) {
            nextInt(a[i]);
            sum[i] = sum[i - 1] + a[i];
        }
        head = tail = 0;
        q[tail++] = 0;
        double ans = 0;
        for (int i = k; i <= n; ++i) {
            // Advance while the next hull point gives a steeper line to i.
            while (head + 1 < tail) {
                const int u = q[head], v = q[head + 1];
                if ((sum[i] - sum[u]) * (i - v) < (sum[i] - sum[v]) * (i - u)) ++head;
                else break;
            }
            int x = q[head];
            ans = max(ans, (sum[i] - sum[x]) / (i - x));
            // x becomes a legal left end for i+1 onward; keep the hull convex.
            x = i - k + 1;
            while (head + 1 < tail) {
                const int u = q[tail - 2], v = q[tail - 1];
                if ((sum[x] - sum[v]) * (x - u) < (sum[x] - sum[u]) * (x - v)) --tail;
                else break;
            }
            q[tail++] = x;
        }
        printf("%.2f\n", ans);
    }
    return 0;
}
例题4:UVALive - 5097
思路:先去掉被其他矩形完全覆盖(宽和高都不超过另一个)的矩形, 剩下的按宽度递增排序后, 高度严格递减
那么维护递增斜率:g[j, k] = (dp[j] - dp[k]) / (h[k+1] - h[j+1]) (从x转移时代价为w[i]*h[x+1], 与代码一致)
代码:
#pragma GCC optimize(2)
#pragma GCC optimize(3)
#pragma GCC optimize(4)
#include <cstdio>
#include <algorithm>
#include <deque>
#include <utility>
#include <vector>
using namespace std;
typedef long long LL;
typedef pair<int, int> pii;

// UVALive 5097. After removing dominated rectangles and sorting by width,
// w[] increases and h[] decreases. dp[j][i] is the minimum cost of covering the
// first i rectangles with exactly j groups, where a group (x, i] costs
// w[i] * h[x+1]. q[j] is the convex hull of the FINITE states of layer j.
const int N = 5e4 + 10;
const LL INF = 0x3f3f3f3f3f3f3f3fLL;
pii a[N];
vector<pii> vc;
int n, k, h[N], w[N];
LL dp[105][N];
deque<int> q[105];

// For query slope C, candidate j is no worse than k (k < j) in layer id.
bool g(int id, int k, int j, LL C) { return (dp[id][j] - dp[id][k]) <= C * (h[k + 1] - h[j + 1]); }
// j is not on layer id's lower hull once i is appended (k < j < i).
bool gg(int id, int k, int j, int i) {
    return (dp[id][i] - dp[id][j]) * (h[k + 1] - h[j + 1]) <= (dp[id][j] - dp[id][k]) * (h[j + 1] - h[i + 1]);
}

int main() {
    while (scanf("%d %d", &n, &k) == 2) {
        for (int i = 1; i <= n; ++i) scanf("%d %d", &a[i].first, &a[i].second);
        sort(a + 1, a + 1 + n);
        // Scan from widest to narrowest; keep a rectangle only if it is taller
        // than everything wider, i.e. drop rectangles dominated in both sides.
        vc.clear();
        for (int i = n; i >= 1; --i)
            if (i == n || a[i].second > vc.back().second) vc.push_back(a[i]);
        reverse(vc.begin(), vc.end());
        n = vc.size();
        for (int i = 0; i < n; ++i) w[i + 1] = vc[i].first, h[i + 1] = vc[i].second;
        h[n + 1] = 0;  // read by gg() when i == n; zero it instead of leaving stale data

        for (int j = 0; j <= k; ++j) {
            q[j].clear();
            for (int i = 0; i <= n; ++i) dp[j][i] = INF;
        }
        dp[0][0] = 0;
        q[0].push_back(0);

        for (int i = 1; i <= n; ++i) {
            for (int j = 0; j < k; ++j) {
                deque<int> &dq = q[j];
                // Layer j has no reachable state below i yet: dp[j+1][i] stays INF.
                // (The old code called front() on the empty deque here, which is UB.)
                if (!dq.empty()) {
                    while (dq.size() >= 2 && g(j, dq[0], dq[1], w[i])) dq.pop_front();
                    const int x = dq.front();
                    dp[j + 1][i] = min(dp[j + 1][i], dp[j][x] + w[i] * 1LL * h[x + 1]);
                }
                // dp[j][i] is final (computed in iteration j-1). Only finite states
                // join the hull; INF values would overflow the products in gg().
                if (dp[j][i] < INF) {
                    while (dq.size() >= 2 && gg(j, dq[dq.size() - 2], dq.back(), i)) dq.pop_back();
                    dq.push_back(i);
                }
            }
        }
        LL ans = 1LL << 60;
        for (int i = 1; i <= k; ++i) ans = min(ans, dp[i][n]);
        printf("%lld\n", ans);
    }
    return 0;
}
例题5:HDU - 3045
思路:维护递增斜率:g[j, k] = (dp[j]-dp[k]+sum[k]-sum[j]+a[j+1]*j-a[k+1]*k) / (a[j+1]-a[k+1])
代码:
#pragma GCC optimize(2)
#pragma GCC optimize(3)
#pragma GCC optimize(4)
#include <cstdio>
#include <algorithm>
#include <deque>
using namespace std;
typedef long long LL;

// HDU 3045 on the sorted array a[1..n]:
// dp[i] = min over j in {0} U [k, i-k] of dp[j] + sum[i] - sum[j] - a[j+1]*(i-j).
// Candidate j is the point (a[j+1], dp[j] - sum[j] + a[j+1]*j); query slope is i.
const int N = 4e5 + 5;
int n, k;
LL a[N], sum[N], dp[N];
deque<int> q;

// Y(j) - Y(k) for candidates k < j.
static LL dy(int k, int j) { return dp[j] - dp[k] + sum[k] - sum[j] + a[j + 1] * j - a[k + 1] * k; }
// X(j) - X(k) for candidates k < j.
static LL dx(int k, int j) { return a[j + 1] - a[k + 1]; }

// slope(k, j) <= C: for query slope C, j is no worse than k.
bool g(int k, int j, LL C) { return dy(k, j) <= C * dx(k, j); }
// slope(j, i) <= slope(k, j): j is not on the lower hull.
bool gg(int k, int j, int i) { return dy(j, i) * dx(k, j) <= dy(k, j) * dx(j, i); }

int main() {
    while (~scanf("%d %d", &n, &k)) {
        for (int i = 1; i <= n; ++i) scanf("%lld", &a[i]);
        sort(a + 1, a + 1 + n);
        for (int i = 1; i <= n; ++i) sum[i] = sum[i - 1] + a[i];
        q.clear();
        dp[0] = 0;
        q.push_back(0);
        for (int i = k; i <= n; ++i) {
            while (q.size() >= 2 && g(q[0], q[1], i)) q.pop_front();
            const int j = q.front();
            dp[i] = dp[j] + sum[i] - sum[j] - a[j + 1] * 1LL * (i - j);
            // i-k+1 becomes a legal split point for i+1, but only if the group
            // ending there has at least k elements.
            const int nxt = i - k + 1;
            if (nxt < k) continue;
            while (q.size() >= 2 && gg(q[q.size() - 2], q.back(), nxt)) q.pop_back();
            q.push_back(nxt);
        }
        printf("%lld\n", dp[n]);
    }
    return 0;
}
例题6:POJ - 1180
思路:要单独算s的影响, 因为有s存在时完成时间不好直接用前缀和表示; 每开启一个新批次(从第j+1个任务开始), 启动时间s会累加到之后所有任务的完成时间上, 贡献为s*suf[j+1]
那么就是维护递增斜率:g[j, k] = (dp[j]-dp[k]+s*(suf[j+1]-suf[k+1])) / (sum[j] - sum[k])
代码:
#pragma GCC optimize(2)
#pragma GCC optimize(3)
#pragma GCC optimize(4)
#include <cstdio>
#include <deque>
using namespace std;
typedef long long LL;

// POJ 1180: dp[i] = min_{0<=j<i} dp[j] + T[i]*(sum[i]-sum[j]) + s*suf[j+1],
// with T the prefix sum of times, sum the prefix sum of F and suf the suffix
// sum of F. Candidate j is the point (sum[j], dp[j] + s*suf[j+1]); query slope T[i].
const int N = 1e4 + 5;
int T[N], F[N], n, s;
LL sum[N], suf[N], dp[N];
deque<int> q;

// Y(j) - Y(k) for candidates k < j.
static LL dy(int k, int j) { return dp[j] - dp[k] + s * (suf[j + 1] - suf[k + 1]); }
// X(j) - X(k) for candidates k < j.
static LL dx(int k, int j) { return sum[j] - sum[k]; }

// slope(k, j) <= C: for query slope C, j is no worse than k.
bool g(int k, int j, LL C) { return dy(k, j) <= C * dx(k, j); }
// slope(j, i) <= slope(k, j): j is not on the lower hull.
bool gg(int k, int j, int i) { return dy(j, i) * dx(k, j) <= dy(k, j) * dx(j, i); }

int main() {
    scanf("%d", &n);
    scanf("%d", &s);
    for (int i = 1; i <= n; ++i) {
        scanf("%d %d", &T[i], &F[i]);
        sum[i] = sum[i - 1] + F[i];
        T[i] += T[i - 1];
    }
    for (int i = n; i >= 1; --i) suf[i] = suf[i + 1] + F[i];
    q.push_back(0);
    for (int i = 1; i <= n; ++i) {
        while (q.size() >= 2 && g(q[0], q[1], T[i])) q.pop_front();
        const int j = q.front();
        dp[i] = dp[j] + T[i] * (sum[i] - sum[j]) + s * suf[j + 1];
        while (q.size() >= 2 && gg(q[q.size() - 2], q.back(), i)) q.pop_back();
        q.push_back(i);
    }
    printf("%lld\n", dp[n]);
    return 0;
}
例题7:POJ - 2018
思路:同HDU-2993
代码:
#pragma GCC optimize(2)
#pragma GCC optimize(3)
#pragma GCC optimize(4)
#include <cstdio>
#include <deque>
using namespace std;
typedef long long LL;

// POJ 2018: maximum of floor(1000 * average) over runs of length >= f.
// Points are (t, sum[t]); for right end i the best left end x maximises the
// slope from x to i. The deque is a lower convex hull of admissible left ends.
const int N = 1e5 + 10;
int n, f, a[N];
LL sum[N];
deque<int> q;

// slope(k, j) <= slope(j, i) for k < j < i (cross-multiplied).
bool g(int k, int j, int i) { return (sum[j] - sum[k]) * (i - j) <= (sum[i] - sum[j]) * (j - k); }

int main() {
    scanf("%d %d", &n, &f);
    for (int i = 1; i <= n; ++i) {
        scanf("%d", &a[i]);
        sum[i] = sum[i - 1] + a[i];
    }
    q.push_back(0);
    LL ans = 0;
    for (int i = f; i <= n; ++i) {
        // Advance while the next hull point gives an at-least-as-steep line to i.
        while (q.size() >= 2 && g(q[0], q[1], i)) q.pop_front();
        const int x = q.front();
        const LL cur = (sum[i] - sum[x]) * 1000 / (i - x);
        if (cur > ans) ans = cur;
        // i+1-f becomes a legal left end for i+1; drop points above the hull.
        const int nxt = i + 1 - f;
        while (q.size() >= 2 && !g(q[q.size() - 2], q.back(), nxt)) q.pop_back();
        q.push_back(nxt);
    }
    printf("%lld\n", ans);
    return 0;
}
例题8:POJ - 3709
思路:维护递增斜率:g[j, k] = (dp[j]-dp[k]+sum[k]-sum[j]+a[j+1]*j-a[k+1]*k) / (a[j+1]-a[k+1])
代码:
#pragma GCC optimize(2)
#pragma GCC optimize(3)
#pragma GCC optimize(4)
#include <cstdio>
#include <deque>
using namespace std;
typedef long long LL;

// POJ 3709 on the non-decreasing input a[1..n]:
// dp[i] = min over x in {0} U [k, i-k] of dp[x] + sum[i] - sum[x] - a[x+1]*(i-x).
// Candidate x is the point (a[x+1], dp[x] - sum[x] + a[x+1]*x); query slope i.
const int N = 5e5 + 10;
int a[N], n, k, T;
LL sum[N], dp[N];
deque<int> q;

// X(j) - X(k) for candidates k < j.
LL dw(int k, int j) { return a[j + 1] - a[k + 1]; }
// Y(j) - Y(k) for candidates k < j.
LL up(int k, int j) { return dp[j] - dp[k] + sum[k] - sum[j] + a[j + 1] * 1LL * j - a[k + 1] * 1LL * k; }
// slope(k, j) <= C: for query slope C, j is no worse than k.
// (Previously declared as returning LL; it is a predicate, like in the other examples.)
bool g(int k, int j, LL C) { return up(k, j) <= C * dw(k, j); }
// slope(j, i) <= slope(k, j): j is not on the lower hull.
bool gg(int k, int j, int i) { return up(j, i) * dw(k, j) <= up(k, j) * dw(j, i); }

int main() {
    scanf("%d", &T);
    while (T--) {
        scanf("%d %d", &n, &k);
        for (int i = 1; i <= n; ++i) {
            scanf("%d", &a[i]);
            sum[i] = sum[i - 1] + a[i];
        }
        q.clear();
        q.push_back(0);
        for (int i = k; i <= n; ++i) {
            while (q.size() >= 2 && g(q[0], q[1], i)) q.pop_front();
            int x = q.front();
            dp[i] = dp[x] + sum[i] - sum[x] - a[x + 1] * 1LL * (i - x);
            // i-k+1 becomes a legal split point for i+1 only if the group ending
            // there has at least k elements.
            x = i - k + 1;
            if (x >= k) {
                while (q.size() >= 2 && gg(q[q.size() - 2], q.back(), x)) q.pop_back();
                q.push_back(x);
            }
        }
        printf("%lld\n", dp[n]);
    }
    return 0;
}
例题9:UVA - 12594
思路:维护递增斜率:g[j, k] = (dp[j]-dp[k]+sum[k]-sum[j]-k*s[k]+j*s[j]) / (j-k), 其中p[t]为名字第t个字符在pn中的下标, sum[i] = ∑_{t=1..i}(t-1-p[t])*p[t], s[i] = ∑_{t=1..i}p[t]
#pragma GCC optimize(2)
#pragma GCC optimize(3)
#pragma GCC optimize(4)
#include <cstdio>
#include <cstring>
#include <deque>
using namespace std;
typedef long long LL;

// UVA 12594. p[t] is the index of name character t in the keyboard string pn;
// s[i] = sum p[t] and sum[i] = sum (t-1-p[t])*p[t] over t <= i.
// dp[j][i] is the minimum cost of cutting the first i characters into exactly j
// non-empty pieces; a piece (x, i] costs sum[i]-sum[x] - x*(s[i]-s[x]).
// q[j] is the convex hull of the FINITE states of layer j (query slope s[i]).
const int N = 2e4 + 10, M = 505;
const LL INF = 0x3f3f3f3f3f3f3f3fLL;
int T, n, k, pos[26];
LL sum[N], s[N], dp[M][N];
char nm[N], pn[N];
deque<int> q[M];

// Y(j) - Y(k) in layer id for candidates k < j.
LL up(int id, int k, int j) { return dp[id][j] - dp[id][k] + sum[k] - sum[j] - k * s[k] + j * s[j]; }
// X(j) - X(k) for candidates k < j.
LL dw(int k, int j) { return j - k; }
// slope(k, j) <= C: for query slope C, j is no worse than k.
bool g(int id, int k, int j, LL C) { return up(id, k, j) <= C * dw(k, j); }
// slope(j, i) <= slope(k, j): j is not on the lower hull.
bool gg(int id, int k, int j, int i) { return up(id, j, i) * dw(k, j) <= up(id, k, j) * dw(j, i); }

int main() {
    scanf("%d", &T);
    for (int cs = 1; cs <= T; ++cs) {
        scanf("%s %d", pn, &k);
        scanf("%s", nm + 1);
        n = strlen(nm + 1);
        for (int i = 0; i < 26; ++i) pos[pn[i] - 'a'] = i;
        for (int i = 1; i <= n; ++i) {
            const int p = pos[nm[i] - 'a'];
            s[i] = s[i - 1] + p;
            sum[i] = sum[i - 1] + (i - 1 - p) * 1LL * p;
        }
        // Layers 1..k start unreachable. A state is written below only when it
        // has a real predecessor.
        for (int j = 0; j <= k; ++j) q[j].clear();
        for (int j = 1; j <= k; ++j)
            for (int i = 0; i <= n; ++i) dp[j][i] = INF;
        dp[0][0] = 0;
        q[0].push_back(0);
        for (int i = 1; i <= n; ++i) {
            for (int j = 0; j < k; ++j) {
                deque<int> &dq = q[j];
                // No reachable state in layer j below i: dp[j+1][i] stays INF.
                // (The old code called front() on the empty deque here, which is UB.)
                if (dq.empty()) continue;
                while (dq.size() >= 2 && g(j, dq[0], dq[1], s[i])) dq.pop_front();
                const int x = dq.front();
                dp[j + 1][i] = dp[j][x] + sum[i] - sum[x] - x * (s[i] - s[x]);
            }
            // Layer k is never queried, so only layers 1..k-1 need the new point.
            // INF states are kept out: they are unreachable, and their
            // differences would overflow the products in gg().
            for (int j = 1; j < k; ++j) {
                if (dp[j][i] >= INF) continue;
                deque<int> &dq = q[j];
                while (dq.size() >= 2 && gg(j, dq[dq.size() - 2], dq.back(), i)) dq.pop_back();
                dq.push_back(i);
            }
        }
        // The format string was previously split by a raw newline (invalid literal).
        printf("Case %d: %lld\n", cs, dp[k][n]);
    }
    return 0;
}
例题10:luogu P4983 忘情
思路:wqs二分+斜率优化dp
#pragma GCC optimize(3)
#include <cstdio>
#include <deque>
#include <utility>
using namespace std;
typedef long long LL;

// Luogu P4983: split a[1..n] into exactly m segments, a segment with sum S
// costing (S + 1)^2, minimising the total. This uses WQS (Aliens) binary search
// on a per-segment bonus k. check(k) solves
//   dp[i] = min_{j<i} dp[j] + (s[i] - s[j] + 1)^2 - k
// with the convex hull trick and reports the segment count of its optimum.
const int N = 1e5 + 5;
LL dp[N], cnt[N], s[N];
int a[N], n, m;

// Candidate j is the point (s[j], f(j)); the query slope for dp[i] is 2*s[i].
inline LL f(int x) { return dp[x] + s[x] * s[x] - 2 * s[x]; }
inline LL up(int a, int b) { return f(b) - f(a); }
inline LL dw(int a, int b) { return s[b] - s[a]; }

// Returns (dp[n], number of segments used) for bonus k.
inline pair<LL, int> check(LL k) {
    deque<int> q;
    q.push_back(0);
    for (int i = 1; i <= n; ++i) {
        while (q.size() >= 2 && up(q[0], q[1]) < 2 * s[i] * dw(q[0], q[1])) q.pop_front();
        const int j = q.front();
        dp[i] = dp[j] + (s[i] - s[j] + 1) * (s[i] - s[j] + 1) - k;
        cnt[i] = cnt[j] + 1;
        while (q.size() >= 2) {
            const int b = q.back(), c = q[q.size() - 2];
            // up() times dw() can exceed the long long range, so the cross
            // products are compared in 128-bit arithmetic.
            if ((__int128)up(b, i) * dw(c, b) <= (__int128)up(c, b) * dw(b, i)) q.pop_back();
            else break;
        }
        q.push_back(i);
    }
    return make_pair(dp[n], (int)cnt[n]);
}

int main() {
    scanf("%d %d", &n, &m);
    for (int i = 1; i <= n; ++i) {
        scanf("%d", &a[i]);
        s[i] = s[i - 1] + a[i];
    }
    // Search for the largest k with cnt(k) <= m.
    // Upper bound: splitting a segment (x + y) changes its cost by 1 - 2xy <= 1,
    // so for any k >= 2 every split is profitable and cnt(k) == n >= m.
    // The former bound of 1e16 made dp (about -n*k) overflow when m == n.
    // Lower bound: must exceed the largest single-split gain (<= s[n]^2 / 2),
    // assuming s[n] stays around 1e8 as the original bound did.
    LL l = -10000000000000000LL, r = 2, mid = (l + r + 1) >> 1;
    while (l < r) {
        if (check(mid).second > m) r = mid - 1;
        else l = mid;
        mid = (l + r + 1) >> 1;
    }
    pair<LL, int> p = check(mid);
    printf("%lld\n", p.first + m * mid);
    return 0;
}