【题解】P5322 [BJOI2019] 排兵布阵(DP,背包)
【题解】P5322 [BJOI2019] 排兵布阵
挺开心的,毕竟这是我为数不多自己做出的蓝题之一。
题目链接
题意概述
这道题的题意很清楚,所以这里直接摘抄原题题面。
小 C 正在玩一款排兵布阵的游戏。在游戏中有 \(n\) 座城堡,每局对战由两名玩家来争夺这些城堡。每名玩家有 \(m\) 名士兵,可以向第 \(i\) 座城堡派遣 \(a_i\) 名士兵去争夺这个城堡,使得总士兵数不超过 \(m\)。
如果一名玩家向第 \(i\) 座城堡派遣的士兵数严格大于对手派遣士兵数的两倍,那么这名玩家就占领 了这座城堡,获得 \(i\) 分。
现在小 C 即将和其他 \(s\) 名玩家两两对战,这 \(s\) 场对决的派遣士兵方案必须相同。小 C 通过某些途径得知了其他 \(s\) 名玩家即将使用的策略,他想知道他应该使用什么策略来最大化自己的总分。
由于答案可能不唯一,你只需要输出小 C 总分的最大值。
思路分析
刚开始想了一个前缀和差分套 01 背包的做法。
定义 \(sum_{i,j}\) 表示对于第 \(i\) 座城堡,派遣士兵 \(j\) 个时,占领了几次城堡。
那么对于一个输入的 \(a_{i,j}\),令 \(a_{i,j}\times 2+1=k\),然后 \(sum_{i,k}++\)。
最后对于每一个 \(i\),求一下 \(\sum \limits_{j=1}^m sum_{i,j}\),得到新的前缀和数组。
定义 \(dp_{i,j}\) 表示的是考虑到了第 \(i\) 座城堡,当前派遣士兵总数是 \(j\) 时,得分最大是多少。
可以发现实际上这是一个 01 背包,枚举第 \(i\) 个城堡派遣的士兵个数 \(k\),那么有:
直接转移即可。
分析一下时间复杂度:枚举 \(i\):\(O(n)\),枚举 \(j,k\):都是 \(O(m)\)。
所以总复杂度:\(O(nm^2)\)。复杂度炸没边了。
考虑如何优化。
如果 \(dp\) 状态不变,那么瓶颈在于枚举 \(k\) 上,因为枚举 \(i,j\) 的复杂度都无法改变。
观察数据范围可以发现,\(m\) 的范围虽然很大,但 \(n\) 却很小,能否将枚举 \(k\) 的复杂度降低到 \(O(n)\) 级别呢。
其实是可以的。
既然我们直接枚举士兵个数不成,不妨换个角度。我们枚举可以占领几次城堡,也就是在这个城堡上战胜了多少个敌人。
那么我们只需要知道战胜了 \(k\) 个敌人时的士兵个数即可。
显然选择士兵个数要遵循的原则是:能小则小。即在满足条件的情况下,选最少的士兵。
这个是显然的,因为如果你选的更大就会使得多选的士兵浪费,从而不一定达到最优解。
所以对于战胜了 \(k\) 个敌人的士兵个数,假设第 \(i\) 个敌人派遣到这座城堡的士兵个数是 \(a_i\),那么答案就是 \(a\) 数组中第 \(k\) 小的元素 \(\times 2+1\)。
那么我们对每一座城堡上所有敌人派遣的士兵个数排序,\(a_k\) 即为第 \(k\) 小。
综上,对于 \(dp_{i,j}\):
\(a_{i,k}\) 表示的是第 \(i\) 个城堡上所有敌人派遣第 \(k\) 小的士兵数量。
时间复杂度:\(O(n^2m)\)。
易错点
-
在刚开始输入的时候,由于 \(a_{i,j}\) 表示的是表示的是第 \(i\) 个城堡上所有敌人派遣第 \(k\) 小的士兵数量。但输入相对于这个是反的,所以刚开始应该输入的是 \(a_{j,i}\)。
-
对于一个 \(dp_{i,j}\),若没有使用滚动数组,则刚开始要继承 \(dp_{i-1,j}\)。
代码实现
//luoguP5322
#include<cstdio>
#include<iostream>
#include<algorithm>
using namespace std;
const int maxn=105;
const int maxm=2e4+10;
int a[maxn][maxm<<1],dp[maxn][maxm<<1];
inline int read()
{
int x=0,f=1;char ch=getchar();
while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
while(ch>='0'&&ch<='9'){x=x*10+ch-48;ch=getchar();}
return x*f;
}
int main()
{
int s,n,m;
s=read();n=read();m=read();
for(int i=1;i<=s;i++)
{
for(int j=1;j<=n;j++)
{
int x=read();
a[j][i]=x;//bug:a[i][j]=x;
}
}
for(int i=1;i<=n;i++)sort(a[i]+1,a[i]+s+1);
for(int i=1;i<=n;i++)
{
for(int j=m;j>=0;j--)
{
dp[i][j]=dp[i-1][j];//bug:forget
//要在枚举 k 的循环外继承。
for(int k=1;k<=s;k++)
{
if(j>=a[i][k]*2+1)dp[i][j]=max(dp[i][j],dp[i-1][j-a[i][k]*2-1]+k*i);
}
}
}
int ans=0;
for(int i=0;i<=m;i++)ans=max(ans,dp[n][i]);
cout<<ans<<'\n';
return 0;
}