【题解】【THUSC 2016】成绩单 LOJ 2292 区间dp

Prelude

快THUWC了,所以补一下以前的题。
真的是一道神题啊,网上的题解没几篇,而且还都看不懂,我做了一天才做出来。

传送到LOJ:(>人<;)


Solution

直接切入正题。
我们考虑区间dp,第一件事是离散化。
然后用\(g(i,j)\)表示消除完闭区间\([i,j]\)的最小费用。
然后呢?怎么转移?exm???
这时候会有一个非常自然的想法。
计算\(g(i,j)\)的时候,我们枚举两个数\(l,r\),然后保留下值在闭区间\([l,r]\)之内的所有数,先消除掉其他的数字,就只剩\([l,r]\)之内的数字了,再一次性消除掉她们。
时间复杂度\(O(n^5)\),但是显然是错的。
错在哪里呢?大概是错在下面这种情况,我懒得构造具体的反例了。
对于一组数字\(abcabca\),我们可以先消除掉中间的\(a\),再消除掉\(bcbc\),最后再消除掉\(aa\),在我们的dp里面似乎并没有考虑到这种情况。
因为\(aa\)是最后消除掉的,因此如果我们选择保留\(a\)的话,会保留下来所有的\(a\)
我们太仁慈了,保留下来了\([l,r]\)之间的所有的数字,其实不一定要保留所有数字。
怎么办呢?
脑洞大开!
我们用\(f(i, j, l, r)\)表示,消除完在闭区间\([i,j]\)之内的,除了值在\([l,r]\)之间的所有数字。
注意,在\([l,r]\)之间的数字,可以消除,也可以不消除。
然后显然有这个东西:

$\Large g(i, j) = \min f(i, j, l, r)$
实际上就是枚举$l,r$嘛。 然后我们考虑$f(i, j, l, r)$如何转移。 当闭区间$[i,j]$内元素全部在$[l,r]$之间的时候,显然$f(i, j, l, r)=0$。 当闭区间$[i,j]$内元素全部不在$[l,r]$之间的时候,显然$f(i, j, l, r)=g(i, j)$。 $f(i, j, l, r)=g(i, j)$似乎构成了循环依赖? 那么,我们枚举$l,r$的时候,必须保证区间$[i,j]$内存在至少一个数字在$[l,r]$内,这样就不会有循环依赖了。 解决了$f(i, j, l, r)$的边界问题,接下来看如何转移。 像普通的区间dp一样,我们枚举区间的分裂点$k$,然后把区间$[i,j]$分裂成$[i,k]$和$[k+1,j]$两部分,递归做下去。 有式子:
$\Large f(i, j, l, r) = \min f(i, k, l, r) + f(k+1, j, l, r)$
感受一下,感觉似乎是能处理各种情况的? 但是实际上和刚刚的做法没有任何区别。 因为对于状态$f(i, j, l, r)$,我们仍然保留了$[l,r]$之间的所有数字,仍然是那么的仁慈。 我们需要加一种暴力斩掉所有数字的情况。 有式子:
$\Large f(i, j, l, r) = \min g(i, k) + f(k+1, j, l, r)$
仔细感受一下,这两个$f(i, j, l, r)$的转移式结合起来之后,就可以处理掉所有情况了! 时间复杂度仍然是$O(n^5)$。 实现采用记忆化搜索,效果棒棒哒~ 真是一道神题啊。。。

Code

#include <cstring>
#include <algorithm>
#include <cstdio>
#include <iostream>

using namespace std;
const int N = 52;
const int W = 1010;
const int INF = 0x3f3f3f3f;
int _w;

int bmin( int &a, int b ) {
	return a = b < a ? b : a;
}

int n, a, b, w[N];
int vis[W], num[N], m;
int f[N][N][N][N], g[N][N];
int F( int, int, int, int );
int G( int, int );

void discrete() {
	for( int i = 1; i <= n; ++i )
		vis[w[i]] = 1;
	m = 1;
	for( int i = 1; i < W; ++i )
		if( vis[i] )
			vis[i] = m, num[m++] = i;
	--m;
	for( int i = 1; i <= n; ++i )
		w[i] = vis[w[i]];
}

bool contain( int i, int j, int l, int r ) {
	for( int p = i; p <= j; ++p )
		if( w[p] >= l && w[p] <= r )
			return true;
	return false;
}

bool all( int i, int j, int l, int r ) {
	for( int p = i; p <= j; ++p )
		if( w[p] < l || w[p] > r )
			return false;
	return true;
}

int F( int i, int j, int l, int r ) {
	int &now = f[i][j][l][r];
	if( now != -1 ) return now;
	if( all(i, j, l, r) ) return now = 0;
	if( !contain(i, j, l, r) ) return now = G(i, j);
	now = INF;
	for( int k = i; k < j; ++k ) {
		bmin( now, F(i, k, l, r) + F(k+1, j, l, r) );
		bmin( now, G(i, k) + F(k+1, j, l, r) );
	}
	// printf( "f[%d][%d][%d][%d] = %d\n", i, j, l, r, now );
	return now;
}

int G( int i, int j ) {
	int &now = g[i][j];
	if( now != -1 ) return now;
	now = INF;
	for( int l = 1; l <= m; ++l )
		for( int r = l; r <= m; ++r )
			if( contain(i, j, l, r) ) {
				int u = num[l], v = num[r];
				bmin( now, F(i, j, l, r) + (v-u)*(v-u)*b + a );
			}
	// printf( "g[%d][%d] = %d\n", i, j, now );
	return now;
}

int main() {
	cin >> n >> a >> b;
	for( int i = 1; i <= n; ++i )
		cin >> w[i];
	discrete();
	memset(f, -1, sizeof f);
	memset(g, -1, sizeof g);
	printf( "%d\n", G(1, n) );
	return 0;
}
posted @ 2018-01-19 12:43  mlystdcall  阅读(965)  评论(0编辑  收藏  举报