[CSAcademy]Squared Ends
[CSAcademy]Squared Ends
题目大意:
给你一个长度为\(n(n\le10^4)\)的数列\(\{A_i\}(A_i\le10^6)\)。定义区间\(A_{[l,r]}\)的代价为\((A_l-A_r)^2\)。求将\(\{A_i\}\)划分成\(k(k\le100)\)个区间的最小代价。
思路:
不难想到一种动态规划,用\(f[i][j]\)表示已经划分了\(i\)个区间,结尾是\(j\)的最小代价。转移方程为:
\[f[i][j]=\min\{f[i-1][k-1]+(A_j-A_k)^2\}
\]
时间复杂度是\(\mathcal O(n^2k)\)。
变形得:
\[f[i][j]=A_j^2+\min\{-2A_jA_k+f[i-1][k-1]+A_k^2\}
\]
其中\(\min\)中的东西可以看做是关于\(A_j\)的一次函数。而寻找\(\min\)值的过程就相当于在一堆一次函数中找最小值,用李超树维护凸壳即可。
时间复杂度\(\mathcal O(nk\log\operatorname{range}(A_i))\)。
源代码:
#include<cstdio>
#include<cctype>
#include<climits>
#include<algorithm>
inline int getint() {
register char ch;
while(!isdigit(ch=getchar()));
register int x=ch^'0';
while(isdigit(ch=getchar())) x=(((x<<2)+x)<<1)+(ch^'0');
return x;
}
const int N=1e4+1,K=101,LIM=1e6;
typedef long long int64;
class SegmentTree {
#define _left <<1
#define _right <<1|1
#define mid ((b+e)>>1)
private:
struct Node {
int64 a,b,time;
};
Node node[LIM<<2];
public:
void reset(const int &p,const int &b,const int &e) {
node[p]=(Node){0,0};
if(b==e) return;
reset(p _left,b,mid);
reset(p _right,mid+1,e);
}
void insert(const int &p,const int &b,const int &e,const int64 &i,const int64 &j,const int &t) {
if(node[p].time!=t) {
node[p].a=i;
node[p].b=j;
node[p].time=t;
return;
}
const int64 lval1=node[p].a*b+node[p].b;
const int64 rval1=node[p].a*e+node[p].b;
const int64 lval2=i*b+j,rval2=i*e+j;
if(lval1<=lval2&&rval1<=rval2) return;
if(lval2<=lval1&&rval2<=rval1) {
node[p].a=i;
node[p].b=j;
return;
}
if(b==e) return;
const long double c=1.*(node[p].b-j)/(i-node[p].a);
if(lval1<=lval2&&c<=mid) {
insert(p _left,b,mid,node[p].a,node[p].b,t);
node[p].a=i;
node[p].b=j;
return;
}
if(lval1<=lval2&&c>=mid) {
insert(p _right,mid+1,e,i,j,t);
return;
}
if(lval1>=lval2&&c<=mid) {
insert(p _left,b,mid,i,j,t);
return;
}
if(lval1>=lval2&&c>=mid) {
insert(p _right,mid+1,e,node[p].a,node[p].b,t);
node[p].a=i;
node[p].b=j;
return;
}
}
int64 query(const int &p,const int &b,const int &e,const int &x,const int &t) const {
if(node[p].time!=t) return LLONG_MAX;
int64 ret=node[p].a*x+node[p].b;
if(b==e) return ret;
if(x<=mid) ret=std::min(ret,query(p _left,b,mid,x,t));
if(x>mid) ret=std::min(ret,query(p _right,mid+1,e,x,t));
return ret;
}
#undef _left
#undef _right
#undef mid
};
SegmentTree t;
int a[N];
int64 f[K][N];
int main() {
const int n=getint(),k=getint();
for(register int i=1;i<=n;i++) {
a[i]=getint();
f[0][i]=INT_MAX*500ll;
}
t.reset(1,1,n);
for(register int i=1;i<=k;i++) {
for(register int j=i;j<=n+i-k;j++) {
t.insert(1,1,LIM,-2*a[j],f[i-1][j-1]+(int64)a[j]*a[j],i);
f[i][j]=(int64)a[j]*a[j]+t.query(1,1,LIM,a[j],i);
}
}
printf("%lld\n",f[k][n]);
return 0;
}