石子合并-平行四边形优化

上一篇博客讲了石子合并的基本做法,n^3复杂度的dp,今天无意间看到这个优化方法,觉得有必要学习一下。

平行四边形优化是一种可以将三维DP复杂度降到n^2方的方法,但是并不是所有的dp都适用,需要满足一定条件,如下:

当决策代价函数w[i][j]满足w[ i ][ j ]+w[ i’ ][ j’ ]<=w[ i; ][ j ]+w[ i ][ j’ ](i<=i’<=j<=j’)时,称w满足四边形不等式.

当函数w[ i ][ j ]满足w[ i’ ][ j ]<=w[ i ][ j’ ] i<=i’<=j<=j’)时,称w关于区间包含关系单调.

如果满足以上两点,可利用四边形不等式推出最优决策s的单调函数性,从而减少每个状态的状态数,将算法的时间复杂度降低一维。

具体的实施是通过记录子区间的最优决策来减少当前的决策量.令:

s[ i ][ j ]=max{k | ma[ i ] [ j ] = m[ i ][ k-1 ] + m[ k ] [ j ] + w[ i ][ j ] }

即s[ i ] [ j ]就记录了合并第i到第j堆石子时的最优合并,记录是为了限制后面的循环范围,如上所说。

证明如下(转载 http://www.cnblogs.com/jiu0821/p/4493497.html):

m[i,j]表示动态规划的状态量。

m[i,j]有类似如下的状态转移方程:

m[i,j]=opt{m[i,k]+m[k,j]}(ikj)

如果对于任意的abcd,有m[a,c]+m[b,d]m[a,d]+m[b,c],那么m[i,j]满足四边形不等式。

以上是适用这种优化方法的必要条件

对于一道具体的题目,我们首先要证明它满足这个条件,一般来说用数学归纳法证明,根据题目的不同而不同。

通常的动态规划的复杂度是O(n3),我们可以优化到O(n2)

s[i,j]m[i,j]的决策量,即m[i,j]=m[i,s[i,j]]+m[s[i,j]+j]

我们可以证明,s[i,j-1]s[i,j]s[i+1,j]  (证明过程见下)

那么改变状态转移方程为:

m[i,j]=opt{m[i,k]+m[k,j]}      (s[i,j-1]ks[i+1,j])

复杂度分析:不难看出,复杂度决定于s的值,以求m[i,i+L]为例,

(s[2,L+1]-s[1,L])+(s[3,L+2]-s[2,L+1])…+(s[n-L+1,n]-s[n-L,n-1])=s[n-L+1,n]-s[1,L]n

所以总复杂度是O(n2)

s[i,j-1]s[i,j]s[i+1,j]的证明:

mk[i,j]=m[i,k]+m[k,j]s[i,j]=d

对于任意k<d,有mk[i,j]md[i,j](这里以m[i,j]=min{m[i,k]+m[k,j]}为例,max的类似),接下来只要证明mk[i+1,j]md[i+1,j],那么只有当s[i+1,j]s[i,j]时才有可能有ms[i+1,j][i+1,j]md[i+1,j]

(mk[i+1,j]-md[i+1,j]) - (mk[i,j]-md[i,j])

=(mk[i+1,j]+md[i,j]) - (md[i+1,j]+mk[i,j])

=(m[i+1,k]+m[k,j]+m[i,d]+m[d,j]) - (m[i+1,d]+m[d,j]+m[i,k]+m[k,j])

=(m[i+1,k]+m[i,d]) - (m[i+1,d]+m[i,k])

m满足四边形不等式,∴对于i<i+1k<dm[i+1,k]+m[i,d]m[i+1,d]+m[i,k]

(mk[i+1,j]-md[i+1,j])(mk[i,j]-md[i,j])0

s[i,j]s[i+1,j],同理可证s[i,j-1]s[i,j]

证毕

解决这类dp平行四边形优化问题的大概步骤是:

1.证明w满足四边形不等式,这里wm的附属量,如m[i,j]=opt{m[i,k]+m[k,j]+w[i,j]},此时大多要先证明w满足条件才能进一步证明m满足条件

2.证明m满足四边形不等式

3.证明s[i,j-1]s[i,j]s[i+1,j]

更新后的代码如下:

#include <cstdio>  
#include <queue>  
#include <cstring>  
#include <algorithm>  
using namespace std;   
  
int n,x;  
int sum[205];  
int dp[205][205];  
int s[205][205];  
  
int main()  
{  
    while(~scanf("%d",&n))  
    {  
        sum[0]=0;  
        memset(dp ,0,sizeof dp); 
        for(int i=1;i<=n;i++)  
        {  
            scanf("%d",&x);  
            sum[i]=sum[i-1]+x;  
            dp[i][i]=0;  
            s[i][i]=i;  
        }  
        for(int len=2;len<=n;len++)  
        for(int i=1;i<=n;i++)  
        {  
            int j=i+len-1;  
            if(j>n) continue;  
            for(int k=s[i][j-1];k<=s[i+1][j];k++)  
            {  
                if(dp[i][k]+dp[k+1][j]+sum[j]-sum[i-1]<dp[i][j])  
                {  
                    dp[i][j]=dp[i][k]+dp[k+1][j]+sum[j]-sum[i-1];  
                    s[i][j]=k;  
                }  
            }  
        }  
        printf("%d\n",dp[1][n]);  
    }  
    return 0;  
}

 

posted @ 2018-05-02 01:59  fantastic123  阅读(431)  评论(0编辑  收藏  举报