csp2019 Emiya家今天的饭题解
qwq
由于窝太菜了,实在是不会,所以在题解的帮助下过掉了这道题。
写此博客来整理一下思路
正文
传送
简化一下题意:现在有\(n\)行\(m\)列数,选\(k\)个数的合法方案需满足:
1.一行最多选一个
2.一列最多选\(\lfloor \frac{k}{2} \rfloor\)个数
当然,如果你在某一行里选了0,就相当于没有在这一行里选数
选一次对答案的贡献是你选的所有不为零的数的乘积。对于任意的\(k\),只要有合法方案,就能取。
(希望没有把题目变得更复杂叭)
根据上面的要求,我们发现\(k\)的取值范围是\([1,n]\)。而且根据要求2,如果某个方案在满足1的前提下,是不合法的,那么这个方案里面一定有且仅有1列选了超出\(\lfloor \frac{k}{2} \rfloor\)个数,因为不可能有两列选的数同时超过\(\lfloor \frac{k}{2} \rfloor\)个。我们现在知道了不合法方案的一个特征,那么我们不妨试试总方案数-不合法方案数这个思路。
因为 满足1情况的总方案数-满足1而且不合法的方案数=乱选方案数-不满足1或不满足2的方案数 ,所以我们接下来计算方案数都在满足1的条件下来计算。
计算总方案数:设\(all[i][j]\)表示前\(i\)行,每行至多选一个,一共选了\(j\)个的方案数,那么\(all[i][j]=all[i-1][j]+\sum_{l=1}^m{all[i-1][j-1] \times a[i][l]}\)。用\(sum[i]\)表示第\(i\)行所有数的和,那么\(all[i][j]=all[i-1][j]+all[i-1][j-1] \times sum[i]\)
我们再来看看不合法方案怎么算。上面说到一个不合法方案一定只有1行选的数超过了\(\lfloor \frac{k}{2} \rfloor\)个,所以我们可以枚举每一列。但是我们不知道\(k\)。那么我们可以设\(no[i][j][l]\)表示前\(i\)行,该列选了\(j\)个,其他列选了\(l\)个。\(no[i][j][l]=no[i-1][j][l]+no[i-1][j-1][l]\times a[i][j]+no[i-1][j][l-1] \times(sum[i]-a[i][j])\)这样就可以由\(j,l\)确定唯一的\(k\)。枚举列:\(O(m)\),枚举\(i\):\(O(n)\),因为选数的个数最多是\(n\),所以枚举\(j,l\)都是\(O(n)\),总复杂度\(O(mn^3)\)
显然是不够的,需要优化。发现我们其实并不需要具体的\(k\),只需要知道当前列和其他列选的数的差值即可。为什么呢?不妨设当前列选的数为\(x+j\)个,其他列选的数为\(x\)个,那么一共选的数就是\(2x+j\)个。这里\(x\)取值任意(只要合法就行),所以可以\(2x+j\)覆盖所有的\(k\)。所以设\(no[i][j]\)表示前\(i\)行,当前枚举的列比其他列多选了\(j\)个的方案数。
\(no[i][j]=no[i-1][j]+no[i-1][j-1] \times a[i][j]+no[i-1][j+1] \times (sum[i]-a[i][j])\)
注意这里有个坑:枚举到第\(i\)行的时候,当前列最多会比其他列少\(i\)个数,所以\(j\)应该从\(-i\)开始枚举,而不是0。考虑到不能出现负下标,所以在代码中将每个下标+n。
如果这个方案是不合法方案,那么对应的\(j\)一定大于0。
最终答案就是\(sum_{j=1}^n{all[n][j]}-\sum_{j=1}^n{no[n][j]}\)
#include<iostream>
#include<cstdio>
#include<cstring>
#include<cmath>
#include<algorithm>
#include<vector>
#include<map>
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
inline ll read()
{
char ch=getchar();
ll x=0;bool f=0;
while(ch<'0'||ch>'9')
{
if(ch=='-') f=1;
ch=getchar();
}
while(ch>='0'&&ch<='9')
{
x=(x<<3)+(x<<1)+(ch^48);
ch=getchar();
}
return f?-x:x;
}
const ll mod=998244353;
ll n,m,a[209][2009],sum[109],all[209][2009];
ll no[109][2109];
ll ans;
int main()
{
n=read();m=read();
for(int i=1;i<=n;i++)
for(int j=1;j<=m;j++)
a[i][j]=read(),sum[i]=(sum[i]+a[i][j])%mod;//保险起见随时随地模一下
all[0][0]=1;
for(int i=1;i<=n;i++)
for(int j=0;j<=n;j++)
all[i][j]=(all[i-1][j]+all[i-1][j-1]*sum[i]%mod+mod)%mod;
for(int j=1;j<=n;j++)
ans=(ans+all[n][j])%mod;
for(int lie=1;lie<=m;lie++)
{
memset(no,0,sizeof(no));
no[0][n]=1;
for(int i=1;i<=n;i++)
{
for(int j=n-i;j<=n+i;j++)
{
no[i][j]=(no[i-1][j]+no[i-1][j-1]*a[i][lie]%mod+no[i-1][j+1]*(sum[i]-a[i][lie])%mod+mod)%mod;
}
}
for(int j=n+1;j<=2*n;j++)
ans=(ans-no[n][j]+mod)%mod;
}
cout<<ans;
return 0;
}