DP的优化

一、数据结构优化DP

P3287 [SCOI2014] 方伯伯的玉米田

首先容易分析出一个性质:拔高玉米时,拔高 [i,n] 区间的玉米一定是最优的。然后就有了一个暴力DP:

f[i][j] 表示对于前 i 个玉米(第 i 个玉米保留),拔高 j 次最多能保留多少玉米。

状态转移方程:

if(a[l]>a[i]&&a[l]a[i]<=j):f[i][j]=max(f[l][j(a[l]a[i])])+1

if(a[l]<a[i]):f[i][j]=max(f[l][j])+1

然后你就可以拿到10分的好成绩

考虑优化。我们发现,状态转移需要求二维前缀的最大值,所以考虑二维树状数组优化。然而, a[l]a[i]<=j 导致不能直接转移,所以有必要重新考虑状态,使得转移没有条件限制。

dp[i][j][k] 表示对于前 i 个玉米(第 i 个玉米保留),拔高 j 次,第 i 个玉米高度为 k 时,最多能保留多少玉米。

状态转移方程:dp[i][j][a[i]+j]=max(dp[i1][p][q]),1<=p<=j,1<=q<=a[i]+j

然后就可以愉快地使用二维树状数组了。

code:

int ask(int x,int y){
	int re=0;
	for(int i=x;i;i-=i&(-i))
		for(int j=y;j;j-=j&(-j))
			re=max(re,dp[i][j]);
	return re;
}
void add(int x,int y,int z){
	for(int i=x;i<=k+1;i+=i&(-i))
		for(int j=y;j<=k+maxn;j+=j&(-j))
			dp[i][j]=max(dp[i][j],z);
}
int main(){
	scanf("%d%d",&n,&k);
	for(int i=1;i<=n;++i)
		scanf("%d",&a[i]),maxn=max(maxn,a[i]);
	for(int i=1;i<=n;++i){
		for(int j=0;j<=k;++j){
			int t=ask(j+1,a[i]+j)+1;//这里写j+1是避免j==0时树状数组死循环 
			now[j]=t;
			ans=max(ans,t);
		}
		for(int j=0;j<=k;++j)
			add(j+1,a[i]+j,now[j]);
	}
	/*for(int i=1;i<=n;++i)//注释掉的部分是暴力DP 
		for(int j=0;j<=k;++j){
			for(int l=0;l<i;++l){
				if(a[l]>a[i]&&a[l]-a[i]<=j)
					dp[i][j]=max(dp[i][j],dp[l][j-(a[l]-a[i])]+1);
				else if(a[l]<=a[i])
					dp[i][j]=max(dp[i][j],dp[l][j]+1); 
			}*/
	printf("%d\n",ans);
	return 0;
}

P6773命运

DP 神仙题。

DP[u][i] 表示以 u 为根的子树中,下端点在子树内,并且没有被满足的所有限制的最深的上端点的深度为 i 的方案数。

状态转移方程:

j=0dep[u]f[u][i]×f[v][j]+j=0if[u][i]×f[v][j]+j=0i1f[u][j]×f[v][i]f[u][i]

第一个和式是 (u,v)=1 的情况,第二,三个和式是 (u,v)=0 的情况。

然后可以用前缀和表示一下。设 g[u][i]=j=0if[u][j] ,那么状态转移方程可以表示为 f[u][i]×(g[v][dep[u]]+g[v][i])+g[u][i1]×f[v][i]

然后用线段树合并维护它就好了。

code:

void add(int x,int y){
    nxt[++tot]=head[x];head[x]=tot;ver[tot]=y;
}
void push_down(int u){//mul是乘法懒标记
    p[p[u].l].sum=p[p[u].l].sum*p[u].mul%mod;
    p[p[u].l].mul=p[p[u].l].mul*p[u].mul%mod;
    p[p[u].r].sum=p[p[u].r].sum*p[u].mul%mod;
    p[p[u].r].mul=p[p[u].r].mul*p[u].mul%mod;
    p[u].mul=1;
}
long long ask(int u,int ul,int ur,int pos){
    if(!u)
        return 0;
    if(ur<=pos)
        return p[u].sum;
    int mid=(ul+ur)>>1;long long re=0;
    push_down(u);
    if(mid>=pos)
        re=(re+ask(p[u].l,ul,mid,pos))%mod;
    if(mid<pos){
        re=(re+ask(p[u].l,ul,mid,pos))%mod;
        re=(re+ask(p[u].r,mid+1,ur,pos))%mod;
    }
    return re;
}
int merge(int u,int v,int ul,int ur,long long &s1,long long &s2){//s1:f[v][1...d[u]]+f[v][1...i]; s2:f[u][1...i-1]
    if(!u&&!v)
        return 0;
    if(!u||!v){
        if(!u){//f[u][i]<-s2*f[v][i]
            s1=(s1+p[v].sum)%mod;
            p[v].mul=p[v].mul*s2%mod;
            p[v].sum=p[v].sum*s2%mod;
            return v;
        }
        else{//f[u][i]<-f[u][i]*s1
            s2=(s2+p[u].sum)%mod;
            p[u].sum=p[u].sum*s1%mod;
            p[u].mul=p[u].mul*s1%mod;
            return u;
        }
    }
    if(ul==ur){//f[u][i]<-s1*f[u][i]+s2*f[v][i]
        long long tmp=p[u].sum,tmp2=p[v].sum;
        s1=(s1+tmp2)%mod;
        p[u].sum=((p[u].sum*s1%mod)+(p[v].sum*s2%mod))%mod;
        s2=(s2+tmp)%mod;
        return u;
    }
    push_down(u);push_down(v);
    int mid=(ul+ur)>>1;
    p[u].l=merge(p[u].l,p[v].l,ul,mid,s1,s2);
    p[u].r=merge(p[u].r,p[v].r,mid+1,ur,s1,s2);
    p[u].sum=(p[p[u].l].sum+p[p[u].r].sum)%mod;
    return u;
}
void add(int &u,int ul,int ur,int pos){//新开一个线段树
    if(!u) u=++cnt;
    p[u].sum=p[u].mul=1;
    if(ul==ur)
        return ;
    int mid=(ul+ur)>>1;
    if(mid>=pos)
        add(p[u].l,ul,mid,pos);
    if(mid<pos)
        add(p[u].r,mid+1,ur,pos);
}
void dfs(int x,int fa){
    d[x]=d[fa]+1;
    int maxx=0;
    for(int i=0;i<v[x].size();++i)
        maxx=max(maxx,d[v[x][i]]);//找到最深的未被满足的上端点的深度
    add(rt[x],0,n,maxx);
    for(int i=head[x];i;i=nxt[i]){
        if(ver[i]!=fa){
            dfs(ver[i],x);
            long long s=ask(rt[ver[i]],0,n,d[x]),ss=0;//s求出f[ver[i]][1...d[x]]
            rt[x]=merge(rt[x],rt[ver[i]],0,n,s,ss);
        }
    }
}
int main(){
    scanf("%d",&n);
    for(int i=1;i<n;++i){
        scanf("%d%d",&x,&y);
        add(x,y);add(y,x);
    }
    scanf("%d",&m);
    for(int i=1;i<=m;++i){
        scanf("%d%d",&x,&y);//anc[y]=x
        v[y].push_back(x);
    }
    dfs(1,0);
    printf("%lld\n",ask(rt[1],0,n,0));
    return 0;
}

二、单调队列优化DP

对于形如f[i]=max{f[j]+val(i,j)}(j[li,ri])DP,若满足以下条件:

①:j 的值域的上下界变化具有单调性;

②:val(i,j) 的每一项仅与ij中的一个有关;

则可以使用单调队列优化。

P2300 合并神犇

一个不太常规的单调队列优化DP。

f[i] 表示将 [1,i] 合并完所需的最小次数。

状态转移方程: f[i]=min{f[j]+ij1}(sum[i]sum[j]>=pre[j]) 。其中 pre[j] 表示合并完 [1,j] 以后的最后一个数的大小。

接下来考虑优化。

引理:f[j+1]<=f[j]+1

证明:假设原序列为 [1,j] ,现在新加入一个数 a[j+1] 。它要么单独作为一个数,即 f[j+1]=f[j] ;要么与前面的数合并在一起,即 f[j+1]=f[j]+1

观察状态转移方程,再结合引理,可以发现当 j 增加时, f[j]j1 的数值是单调不增的。所以我们要找的就是满足 (sum[i]sum[j]>=pre[j]) 的最大的 j

j 的约束条件进行变形: sum[j]+pre[j]<=sum[i] 。因为 sum[i] 是单调递增的,所以 sum[j]+pre[j] 越小,就越有可能在更多的转移中被选择。

因此我们可以维护一个 j 单调递增,且 sum[j]+pre[j] 递增的单调队列。每次新进一个元素 k ,就把队尾 sum[tail]+pre[tail]>=sum[k]+pre[k] 的元素弹掉;状态转移时,从队首不断弹出元素,直到找到最靠后且满足 sum[i]sum[head]>=pre[head] 的元素。

code:

l=1;r=0;
for(int i=1;i<=n;++i){
	while(l<=r&&pre[q[l]]+sum[q[l]]<=sum[i])
		++l;
	f[i]=f[q[l-1]]+i-q[l-1]-1;
	pre[i]=sum[i]-sum[q[l-1]];
	while(l<=r&&sum[q[r]]+pre[q[r]]>=sum[i]+pre[i])
		--r;
	q[++r]=i;
}
printf("%lld\n",f[n]);

三、斜率优化DP

在上文中,如果把条件“每一项都只和ij其中一个有关”去掉,那么就可以使用斜率优化DP来做。

具体地,将min/max函数去掉,并将状态转移方程改写成:

f[j]+val(j)=F(i)F(j)+f[i]val(i)

y=f[j]+val(j)k=F(i)x=F(j)b=f[i]val(i),那么该柿子就成了一次函数表达式y=kx+b

接下来我们以min函数为例。注意到 k 是固定的,所以为了最小化f[i],就要最小化截距。

对于任意三个决策点j1,j2,j3(j1<j2<j3),不难发现只有当j1,j2,j3三个决策点构成下凸壳(斜率单调递增)时,j2才有可能成为最优决策点。因此我们需要维护一个下凸壳。那么具体怎么维护呢?

①:当 x,k 都单调递增时,可以用单调队列维护凸包。查询时,可以直接从队首开始查询,如果当前不是最优决策点,就弹出;直到找到最优决策点为止。

②:当 x 单调递增,但 k 不递增时,仍然可以用单调队列维护。查询时,则需要在凸包上二分查找最优决策点。

③:如果二者都不单调递增,则需要用平衡树维护凸包(动态凸包),或者李超树。

三、四边形不等式优化DP

四边形不等式:

对于函数 w(a,b) ,如果满足对于任意的 a<b<c<d ,有w(a,d)+w(b,c)>=w(a,c)+w(b,d),则称该函数满足四边形不等式。

形象点说,就是“包含大于等于交叉”。

一维DP

前置知识(决策单调性的定义):对于形如f[i]=min0j<i{f[j]+val(i,j)}的状态转移方程,记p[i]为使f[i]取到最小值的j,即 f[i] 的决策点。如果当 i 递增时,p[i]单调不减,那么称 f[i] 具有决策单调性。

定理一:在上述状态转移方程中,如果val(i,j)满足四边形不等式,则 f[i] 具有决策单调性。

技巧:二分队列

首先,我们用三元组将决策点数组进行替换。具体而言,就是说如果f[l]f[r]的决策点都是j,那么我们将其用用三元组 (j,l,r) 表示。然后我们将三元组放入队列,每一个元素的 l 值是上一个元素的 r+1 。初始时队列只有一个元素,为(0,1,n)

对于每个i[1,n],执行这样的操作:

①:更新f[i]。设队首为(j0,l0,r0),如果r0<i,弹出队首,直到找到l0ir0的三元组为止,用它的j0更新f[i],并令l0=i

②:考虑f[i]能作为哪些f[i]的决策点。取出队尾,记为(j1,l1,r1)

(1):如果对于f[l1]来说,i是比j1更优的决策,说明当前元素代表的区间的决策点都是i,记pos=l1,并将队尾弹出。

(2):如果对于f[r1]来说,i是比j1更劣的决策,说明当前元素代表的区间存在决策点i,在[l1,r1]上二分查找,求出位置pos[l1,pos1]的决策点不变,[pos,r1]的决策点为i。再令r1=pos1

(3):将(i,pos,n)加入队尾。

P1912 [NOI2009] 诗人小G

f[i]为前i行的最小代价,有f[i]=min0j<i{f[j]+|sum[i]sum[j]+ij1len|p}

因为代价函数存在高次乘积,所以不宜使用单调队列或者斜率优化。

打表可知,代价函数满足四边形不等式,因此f满足决策单调性。然后运用二分队列即可。

code:

long long _get(long long i,long long l,long long r){
    long long ans;
    while(l<=r){
        int mid=(l+r)>>1;
        if(q[mid].l<=i&&q[mid].r>=i){
            ans=mid;
            break;
        }
        if(i>=q[mid].l)
            l=mid+1;
        else
            r=mid-1;
    }
    return q[ans].x;
}
long double val(long long i,long long j){
    long double ans=1,w=abs((long double)(sum[i]-sum[j]+i-j-1-len));
    for(int i=1;i<=p;++i)
        ans*=w;
    return ans+f[j];
}
void insert(long long i,long long &l,long long &r){
    int pos=-1;
    while(l<=r){//从队尾开始往前查找,决策点单调不增
        if(val(q[r].l,i)<=val(q[r].l,q[r].x))
            pos=q[r].l,--r;//该元素代表的区间的全部决策点都为i
        else{
            if(val(q[r].r,q[r].x)>val(q[r].r,i)){//该元素代表的区间存在决策点i
                int l2=q[r].l,r2=q[r].r;
                while(l2<r2){//二分查找,[q[r].l,l2-1]的决策点不变,[l2,n]的决策点为i
                    int mid=(l2+r2)>>1;
                    if(val(mid,i)>val(mid,q[r].x))
                        l2=mid+1;
                    else
                        r2=mid;
                }
                q[r].r=l2-1;
                pos=l2;
            }
            break;//该元素代表的区间的决策点都不是i,直接停止循环
        }
    }
    if(pos!=-1){
        q[++r].l=pos;
        q[r].r=n;
        q[r].x=i;
    }
}
void print(int now){
    if(!now)
        return ;
    print(pre[now]);
    for(int i=pre[now]+1;i<=now;++i){
        cout<<s[i];
        if(i!=now) printf(" ");
    }
    cout<<endl;
}
void work(){
    scanf("%lld%lld%lld",&n,&len,&p);
    for(int i=1;i<=n;++i){
        cin>>s[i];
        sum[i]=sum[i-1]+s[i].size();
    }
    q[l=r=1].l=1;q[1].r=n;q[1].x=0;
    for(int i=1;i<=n;++i){
        long long j=_get(i,l,r);
        f[i]=val(i,j);
        pre[i]=j;
        while(l<=r&&q[l].r<=i)
            ++l;
        q[l].l=i+1;
        insert(i,l,r);
    }
    if(f[n]>1e18)
        puts("Too hard to arrange");
    else{
        cout<<(long long)f[n]<<endl;
        print(n);
    }
    for(int i=1;i<=20;++i)
        printf("-");
    printf("\n");
}
int main(){
    scanf("%lld",&t);
    while(t--)
        work();
    return 0;
}

二维DP

定理二:在形如 f[i][j]=minik<j{f[i][k]+f[k+1][j]+val(i,j)} 的状态转移方程中,如果val(i,j)满足:

①:四边形不等式;

②:区间包含单调性,即对于任意的 llrr ,有val(l,r)val(l,r)

那么:

①:f也满足四边形不等式,

②:对于 f 的决策点 p ,有 p[i][j1]p[i][j]p[i+1][j](i<j)

因此对于 f[l][r] ,在枚举决策点的时候,只需要在[p[l][r1],p[l+1][r]]范围内枚举即可。复杂度由O(n3)降到O(n2)

P4767 邮局

模版题。首先将村庄按照坐标从小到大排序。设 f[j][i] 表示前 i 个村庄,放了 j 的邮局的最小代价,有 f[j][i]=min{f[j1][k]+val(k+1,i)} ,其中 val(k+1,i) 表示 [k+1,i] 的村庄用一个邮局的最小代价。打表可知,代价函数满足四边形不等式和区间包含点调性。

如何计算val(i,j)?根据初中学过的绝对值的知识,我们可以知道邮局建在村庄坐标的中位数位置。然后预处理一下前缀和,则 val(i,j) 可以 O(1) 算出。

code:

int w(int l,int r){
    int mid=(l+r+1)>>1;
    return a[mid]*(mid-l+1)-(sum[mid]-sum[l-1])+(sum[r]-sum[mid])-a[mid]*(r-mid);
}
int main(){
    scanf("%d%d",&n,&m);
    for(int i=1;i<=n;++i)
        scanf("%d",&a[i]);
    sort(a+1,a+n+1);
    for(int i=1;i<=n;++i)
        sum[i]=sum[i-1]+a[i];
    memset(f,0x3f,sizeof(f));
    f[0][0]=0;
    for(int j=1;j<=m;++j){
        opt[n+1][j]=n;
        for(int i=n;i>=1;--i){
            int pos=0;
            for(int k=opt[i][j-1];k<=opt[i+1][j];++k){
                if(f[k][j-1]+w(k+1,i)<f[i][j]){
                    f[i][j]=f[k][j-1]+w(k+1,i);
                    pos=k;
                }
            }
            opt[i][j]=pos;
        }
    }
    printf("%d\n",f[n][m]);
    return 0;
}

技巧:分治法处理 val 函数难以计算的情况

如果 val 函数难以计算,但该函数能快速地从[l,r]扩展到[l±1,r±1],那么此时可以运用分治法。

solve(l,r,l2,r2,j) 表示用 f[l2...r2][j1] 来更新 f[mid][j](mid=l+r2)

每次更新时,运用双指针思想暴力移动 val(i,j)i,j,每移动一次,就计算新加入或新删除的数的贡献。

找到 f[mid][j] 的决策点 p 并更新完 f[mid][j] 以后,递归地处理 solve(l,mid1,l2,p,j),solve(mid+1,r,p,r2,j)

P5574 [CmdOI2019] 任务分配问题

f[j][i] 表示前 i 个任务,用了 j 个CPU的最小代价,有 f[j][i]=min{f[j1][k]+val(k+1,i)} ,其中 val(k+1,i) 表示 [k+1,i] 的顺序对的个数。打表可知,代价函数满足四边形不等式和区间包含点调性。

然后就可以直接运用上述技巧了。

void add(int x,int w){
    while(x<=n)
        c[x]+=w,x+=x&(-x);
}
int ask(int x){
    int re=0;
    while(x)
        re+=c[x],x-=x&(-x);
    return re;
}
void update(int l,int r){
    while(tr<r)
        ++tr,sum+=ask(a[tr]-1),add(a[tr],1);
    while(tl>l)
        --tl,sum+=ask(n)-ask(a[tl]),add(a[tl],1);
    while(tr>r)
        sum-=ask(a[tr]-1),add(a[tr],-1),--tr;
    while(tl<l)
        sum-=ask(n)-ask(a[tl]),add(a[tl],-1),++tl;
}
void solve(int l,int r,int l2,int r2,int j){
    if(l>r)
        return ;
    int mid=(l+r)>>1,p=mid;
    for(int i=min(mid-1,r2);i>=l2;--i){
        update(i+1,mid);
        if(f[mid][j]>=f[i][j-1]+sum)
            f[mid][j]=f[i][j-1]+sum,p=i;
    }
    solve(mid+1,r,p,r2,j);solve(l,mid-1,l2,p,j);
}
int main(){
    scanf("%d%d",&n,&m);
    for(int i=1;i<=n;++i)
        scanf("%d",&a[i]);
    memset(f,0x3f,sizeof(f));
    f[0][0]=0;
    tl=1;tr=0;
    for(int i=1;i<=m;++i)
        solve(1,n,0,n-1,i);
    printf("%d\n",f[n][m]);
    return 0;
}

其他技巧

Array Beauty

2022.7.21拷逝题。

首先将原序列排序,不会对答案造成影响。

f[i][j][k] 表示前 i 个数(第 i 个数选上),一共选择了 j 个数,其中任意两数差的绝对值的最小值为 k 。然而这样设状态,非常难写,所以可以考虑将“恰好”转为“至少”,然后统计答案时差分即可。

状态转移方程: f[i][j][k]=f[p][j1][k](|a[i]a[p]|>=k)

其中, k 这一维在转移中没有任何作用,可以省掉。

然而,这样写时间复杂度为 O(n2kv) ,爆炸,所以考虑优化。

设所有数的值域为 [1,v] ,选择了 k 个数。容易发现,其中任意两个数的差的绝对值的最小值最多是 v/k 。所以枚举任意两数的差的绝对值的最小值时只需枚举到 v/k 。时间复杂度为 O(n2k×v/k)=O(n2v) ,仍然过不了。

继续优化。如果我们先枚举 j ,再枚举 i ,可以发现 a[i]a[p] 是单调递增的。所以可以用双指针思想维护 ip 。时间复杂度为 O(nv) ,可以通过。

code:

void work(int v){
	f[0][0]=1;a[0]=-1e9;
	for(int i=1;i<=k;++i){
		int p=0,sum=0;
		for(int j=i;j<=n;++j){
			while(p<j&&a[j]-a[p]>=v)
				sum=(sum+f[i-1][p])%mod,++p;
			f[i][j]=sum;
		}
	}
	for(int i=k;i<=n;++i)
		ans[v]=(ans[v]+f[k][i])%mod;
}
signed main(){
	scanf("%lld%lld",&n,&k);
	for(int i=1;i<=n;++i)
		scanf("%lld",&a[i]),maxn=max(maxn,a[i]);
	sort(a+1,a+n+1);
	for(int v=1;v*(k-1)<=maxn;++v){
		work(v);
	}
	for(int i=1;i*(k-1)<=maxn;++i)
		sum=(sum+(ans[i]-ans[i+1]+mod)%mod*i%mod)%mod;
	printf("%lld\n",sum); 
	fclose(stdin);fclose(stdout);
	return 0;
}
posted @   andy_lz  阅读(8)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· 地球OL攻略 —— 某应届生求职总结
· 周边上新:园子的第一款马克杯温暖上架
· Open-Sora 2.0 重磅开源!
· 提示词工程——AI应用必不可少的技术
· .NET周刊【3月第1期 2025-03-02】
点击右上角即可分享
微信分享提示