P5336 [THUSC2016]成绩单 题解
原题链接:P5336 [THUSC2016]成绩单
在机房大佬的帮助下终于写出了这题。再一看题解区,没有类似做法,故蒟蒻来分享一下。
题意
给定序列\(W_{1\cdots n}\)与常数\(A\)、\(B\),每次从中选取连续的一段,定义每一段代价为\(V_k= A+B\times(x_{\max}-x_{\min})^2\),\(x\in[l,r]\)。每次进行选取后,剩下的部分自动变为连续(\(x\in [1,l)\cup(r,n]\)的部分变为连续)。
求:\(\min\{ \sum V_k\}\).
\(n\leq50\).
题解
此题是区间dp。
设\(ans(l,r)\)表示将\([l,r]\)区间全部取出的最小代价,但是发现仅有这两维无法转移。
因为每一段的代价仅与\(x_{\max}\)和\(x_{\min}\)有关,所以可以考虑:再开两维为\(dp(l,r,a,b)\),表示\([l,r]\)区间内,经过多次选取后,剩余的极值分别为,\(a\)、\(b\)的代价。
那么如果要把\(dp(l,r,a,b)\)中权值为\([a,b]\)的这一段全部选空,其代价为\(A+B\times(b-a)^2\)。
则可以得到\(ans\)的转移式:\(ans[l][r] = \min\{dp[l][r][a][b]+A+B\times(b-a)^2\}\),\(1\le a < b \le W_{\max}\).
接下来考虑如何转移\(dp\).
按照区间dp的思路,我们需要在\([l,r]\)中枚举断点\(k\)。
则有如下转移:
\(dp(l,r,a,b)=\min\{ans(l,k) + dp(k+1,r,a,b)\}\).(左半区间全部选空,右半区间剩余\([a,b]\))
\(dp(l,r,a,b)=\min\{dp(l,k,a,b)+ans(k+1,r)\}\).(左半区间剩余\([a,b]\),右半区间全部选空)
\(dp(l,r,a,b)=\min\{dp(l,k,a,b) + dp(k+1,r,a,b)\}\).(左右区间均都剩余\([a,b]\))
我们只需要对于每一个\([l,r]\),枚举其\(a,b\)即可求出某一段区间的所有答案。
关于\(dp\)和\(ans\)的初值,则显然有:\(ans(i,i) = A\),\(dp(i,i,1\cdots W_i,W_i\cdots W_{\max}) = 0\)。
最后,因为\(W_i\le1000\),离散化一下即可。
复杂度:状态数为\(n^4\),枚举断点为\(O(n)\),总复杂度\(O(n^5)\),可以通过此题。
代码
自认为码风还可以QwQ
把快读和宏定义去掉了。
int n,m,A,B;
int w[N],dsc[N];
ll ans[N][N],dp[N][N][N][N];
void dfs(int l,int r){
if(l > r)return;
if(ans[l][r] != ans[0][0])return;//这里应该是判断ans[l][r] == INF的,但是我不知道long long 下的0x3f3f...是多少qwq
for(int a = 1 ; a <= m ; a ++)
for(int b = a ; b <= m ; b ++){
for(int k = l ; k < r ; k ++)
dfs(l,k),dfs(k+1,r),
dp[l][r][a][b] = min(dp[l][r][a][b],ans[l][k] + dp[k+1][r][a][b]),
dp[l][r][a][b] = min(dp[l][r][a][b],dp[l][k][a][b] + ans[k+1][r]),
dp[l][r][a][b] = min(dp[l][r][a][b],dp[l][k][a][b] + dp[k+1][r][a][b]),
ans[l][r] = min(ans[l][r],dp[l][r][a][b] + A + B *(dsc[b]-dsc[a])*(dsc[b]-dsc[a]));
}
}
signed main(){
memset(ans,0x3f,sizeof(ans)),memset(dp,0x3f,sizeof(dp));
n = read() , A = read() , B = read();
for(int i = 1 ; i <= n ; i ++)
dsc[i] = w[i] = read();
sort(dsc+1,dsc+n+1);
m = unique(dsc+1,dsc+n+1)-dsc-1;
for(int i = 1 ; i <= n ; i ++)
w[i] = lower_bound(dsc+1,dsc+m+1,w[i]) - dsc;//离散化
for(int i = 1 ; i <= n ; i ++){//初值
ans[i][i] = A;
for(int j = 1 ; j <= w[i] ; j ++)
for(int k = w[i] ; k <= m ; k ++)
dp[i][i][j][k] = 0;
}
dfs(1,n);//这里我写成了记搜,感觉记搜更好写
printf("%lld",ans[1][n]);
return 0;
}