P4983 忘情
首先看一下每段的值:
\[\frac{\left((\sum\limits_{i=1}^{n}x_i×\bar x)+\bar x\right)^2}{\bar x^2}
\]
\[\frac{(\sum\limits_{i=1}^{n}x_i×\bar x)^2+2\bar x(\sum\limits_{i=1}^{n}x_i×\bar x)+\bar x^2}{\bar x^2}
\]
\[\sum\limits_{i=1}^{n}x_i^2+2\sum\limits_{i=1}^{n}x_i+1
\]
\[(\sum_{i=1}^nx_i+1)^2
\]
由 \((a+b)^2 \ge a^2+b^2\) 可以显然发现:分的段数越多,答案越小,并且每次分段减少的贡献单调不升。
说白了它就是一个凸函数,考虑 wqs 二分。
设当前二分的权重为 \(k\) , 那么有:
\[dp_i=\min_{0\le j < i} \{dp_j+(\text{Sum}_i-\text{Sum}_{j}+1)^2+k\}
\]
\[dp_i=\min_{0\le j < i} \{dp_j + \text{Sum}_j^2-2\text{Sum}_j(\text{Sum}_i+1)\}+(\text{Sum}_i+1)^2+k
\]
设有两个决策点 \(j,k (j < k)\) , 且 \(j\) 优于 \(k\)。
\[dp_j + \text{Sum}_j^2-2\text{Sum}_j(\text{Sum}_i+1) < dp_k + \text{Sum}_k^2-2\text{Sum}_k(\text{Sum}_i+1)
\]
\[(dp_j + \text{Sum}_j^2)-(dp_k+\text{Sum}_k^2)< 2(\text{Sum}_i+1)(\text{Sum}_j-\text{Sum}_k)
\]
可以斜率优化,每次 dp 的复杂度为 \(\mathcal O(n)\)。
总的时间复杂度为 \(\mathcal O(n\log w)\)。
还有,这道题卡精度。
#include <cstdio>
#include <cstring>
#define LL long long
#define Inf 1e18
#define eps 1e-10
const int MAXN = 100000;
int n , m , x[ MAXN + 5 ]; LL Sum[ MAXN + 5 ];
int cnt[ MAXN + 5 ]; LL dp[ MAXN + 5 ];
int Que[ MAXN + 5 ] , Head , Tail;
LL fx( int x ) { return Sum[ x ]; }
LL fy( int x ) { return dp[ x ] + Sum[ x ] * Sum[ x ]; }
LL dx( int x , int y ) { return fx( x ) - fx( y ); }
LL dy( int x , int y ) { return fy( x ) - fy( y ); }
double Slope( int x , int y ) { return dy( x , y ) * 1.0 / dx( x , y ); }
LL check( LL k ) {
Head = 1 , Tail = 0; Que[ ++ Tail ] = 0;
for( int i = 1 ; i <= n ; i ++ ) {
for( ; Head + 1 <= Tail && Slope( Que[ Head ] , Que[ Head + 1 ] ) + eps < 2 * ( Sum[ i ] + 1 ) ; Head ++ );
dp[ i ] = dp[ Que[ Head ] ] + ( Sum[ i ] - Sum[ Que[ Head ] ] + 1 ) * ( Sum[ i ] - Sum[ Que[ Head ] ] + 1 ) + k; cnt[ i ] = cnt[ Que[ Head ] ] + 1;
for( ; Head + 1 <= Tail && Slope( Que[ Tail - 1 ] , Que[ Tail ] ) + eps > Slope( Que[ Tail ] , i ) ; Tail -- );
Que[ ++ Tail ] = i;
}
return cnt[ n ] > m;
}
int main( ) {
scanf("%d %d",&n,&m);
for( int i = 1 ; i <= n ; i ++ ) scanf("%d",&x[ i ]) , Sum[ i ] = Sum[ i - 1 ] + x[ i ];
LL l = 0 , r = Inf , Ans;
for( ; l <= r ; ) {
LL Mid = ( l + r ) >> 1;
if( check( Mid ) ) l = Mid + 1;
else r = Mid - 1 , Ans = Mid;
}
check( Ans );
printf("%lld\n", dp[ n ] - Ans * m );
return 0;
}