D1 - The Endspeaker (Easy Version)

Posted on 2024-10-28 23:04  Capterlliar  阅读(25)  评论(0编辑  收藏  举报

题意:给出长为n的数组a,长为m的数组b和数字k,k初始值为1。每次可以执行以下两种操作之一:

1. 当k<=m时,k++;

2. 删除a前缀和小于等于b[k]的部分,同时cost+=m-k;

求删完a的最小cost;如果不能删完a,输出-1.

解:首先a最大值大于b[1]时无解。一开始想的是贪心,对于每一段a[i...j],如果其max(a[j...n])>=b[k+1]且sum(a[j...n])<=b[k],那么算一次贡献,否则寻找最合适的k。然后WA test3了,想了一下是因为可以选择更大的b[k],使得一次删除的数更多,从而贡献更少。

那么只能dp了。令dp[i][j]为删到第i个数,k=j时贡献最小值,那么有:

dp[i][j]=min(dp[i][j],dp[p][k]+m-j)         where sum[p+1...i]<=b[j] && k<=j
可以看到p<i,那么是从p<i且k<=j的一个矩阵里转移。首先p应该选择尽可能靠左的数,即sum[p+1...i]这一段应尽可能大,这可以用二分解决。接下来p固定了,只需要从该行中选择最小的数,于是每次更新的时候同时维护每行最小值即可。
代码:
#include <bits/stdc++.h>
using namespace std;
#define ll long long
#define maxx 300005
#define inf 0x3f3f3f3f
int n,m;
int a[maxx]={0},b[maxx]={0};
ll sum[maxx]={0};
signed main(){
    int T;
    scanf("%d", &T);
    while (T--){
        scanf("%d%d",&n,&m);
        int maxa=0;
        for(int i=1;i<=n;i++){
            scanf("%d",&a[i]);
            maxa=max(maxa,a[i]);
            sum[i]=sum[i-1]+a[i];
        }
        for(int i=1;i<=m;i++){
            scanf("%d",&b[i]);
        }
        if(maxa>b[1]){
            printf("-1\n");
            continue;
        }
        vector<vector<ll>> dp(n+1,vector<ll>(m+1,inf));
        vector<vector<ll>> p(n+1,vector<ll>(m+1,inf));
        for(int i=1;i<=m;i++) dp[0][i]=0;
        for(int i=1;i<=m;i++) p[0][i]=0;

        for(int i=1;i<=n;i++) {
            for(int j=1;j<=m;j++){
                if(a[i]>b[j]) break;
                int l=1,r=i,mid,ans;
                while(l<=r){
                    mid=l+(r-l)/2;
                    if(sum[i]-sum[mid-1]<=b[j]){
                        ans=mid;
                        r=mid-1;
                    }
                    else
                        l=mid+1;
                }
                dp[i][j]=min(dp[i][j],p[ans-1][j]+m-j); 
            }
            for(int j=1;j<=m;j++){
                p[i][j]=min(dp[i][j],p[i][j-1]);
            }
        }
        // for(int i=1;i<=n;i++){
        //     for(int j=1;j<=m;j++){
        //         printf("%lld ",dp[i][j]);
        //     }
        //     printf("\n");
        // }
        ll ans=inf;
        for(int j=1;j<=m;j++) ans=min(ans,dp[n][j]);
        printf("%lld\n",ans);
        // printf("\n");
    }
}

// dp[i][j] to pos i, with k=j, min cost
// dp[i][j]=min(dp[i][j],dp[p][k]+m-j) where sum[p+1...i]<=b[j] && k<=j
// p越靠左越好,可二分
// 
View Code