CSP-2019 Emiya家今天的饭 笔记

CSP-2019 Day2 T1 几乎爆零
谁叫自己本来就笨还不怎么会DP

形式化描述这题:
给定一个 \(N\times M\) 的矩阵 \(m\),每一个元素都会有 \(m_{i,j}\) 种选法,现在添加三个限制

  1. 最少选择一个元素
  2. 每一行只能选择一个元素
  3. 每一列选择的元素个数不能超过选择总数的一半(下取整)

问,在这三个限制同时满足的情况之下有多少种方案,答案对 \(998244353\) 取模

先不管别的,总方案数我们是知道的。
因为每一行只能选 \(1\) 个,设 \(sum[i]\) 为第 \(i\) 行元素值的和,那么总方案数就是 \(\prod\limits_{i=1}^{N} (sum[i]+1) -1\) (加一是因为这一行可以不选,最后减一是因为至少要选取 \(1\) 个)
然后考虑比较难一点的第三个限制。

假设我们在其他列选取了 \(k\) 个元素,当前列选取了 \(j\) 个,那么必须满足 \(j\le k\)

那么暴力就出来了

\(F[i][j][k]\) 表示 前 \(i\) 行中 第 \(l\) 列选取了 \(j\) 个元素,非第 \(l\) 列选取了 \(k\) 个元素时的方案数

不太好转移,因为要同时保证第 \(l\) 列和其他的列选取元素个数均不超过选取元素的一半.

正难反易,数学老师的话还是有用的.

\(F[i][j][k]\) 表示 前 \(i\) 行中 第 \(l\) 列选取了 \(j\) 个元素,非第 \(l\) 列选取了 \(k\) 个元素时的不合法方案数

因为最多只能有 \(1\) 个列不合法,所以我们枚举这一个不合法列保证这个列选取元素个数大于其他列的总选取数.

我们知道怎么计算总方案数,计算出来减去不合法的即可.

状态转移方程为 \(F[i][j][k]=F[i-1][j][k-1]\times(sum[i]-m[i][l])+F[i-1][j][k]+F[i-1][j-1][k]\times m[i][l]\)
(需要满足 \(j>k\))

枚举第 \(l\) 列,嵌套枚举 \(i\),嵌套枚举 \(j\),嵌套枚举 \(k\),时间复杂度高达 \(O(MN^3)\),要优化.

怎么优化呢?

其实在发现暴力的方程中转移的条件是 \(j>k\) 的时候就能打开思路.

我们设 \(d=j-k\) 那么只要 \(d\le 0\) 的时候就是合法的,\(d>0\) 的时候就是不合法的.

然后我们可以把暴力中 \(j\)\(k\) 的两维合并为一维.

新状态: \(F[i][d]\) 表示前 \(i\) 行中第 \(l\) 列选取的元素和其他列选取的元素个数之差是 \(d\) 个时的不合法方案数.

状态转移方程:
\(F[i][d]=F[i-1][d](\texttt{第 }i\texttt{ 列完全不选})+F[i-1][d-1]\times m[i][l](\texttt{选取第 }i\texttt{ 行的第 }l\texttt{ 列})+F[i-1][d+1]\times(sum[i]-m[i][l])(\texttt{选取第 }i\texttt{ 行的其他列})\)

\(d=j-k\) 可以知道,\(d\in\left[-i,i\right]\),C++没有负数下标我们就直接把方程中的 \(d\) 加上一个 \(N\).同时把数组开大一点就可以.

对于每一个 \(l\),使用上面的方法计算完了之后,我们需要统计不合法的方案,也就是 \(d>0\) 的部分,在我们加了 \(N\) 的情况下,不合法的 \(d\) 的范围为 \(\left(N,2\times N\right]\)

所以计算完每一列不合法的方案数之后,都求一次 \(\sum\limits_{q=N+1}^{2\times N} F[N][q]\) 累加起来即为所有的不合法方案数.

注意long long,注意负数,注意取膜.

#include<cstdio>
#include<algorithm>
#include<cstring>
using namespace std;
typedef long long LL;
const int SAC=998244353;
int N,M;
LL sum[107],m[107][2007];
LL F[107][207];
LL ans,non;
int main()
{
	scanf("%d%d",&N,&M);
	for(int i=1;i<=N;i++)
		for(int j=1;j<=M;j++)
			scanf("%lld",&m[i][j]),sum[i]+=m[i][j],sum[i]=(sum[i]%SAC+SAC)%SAC;
	ans=1ll;
	for(int i=1;i<=N;i++)
		ans*=(sum[i]+1),ans%=SAC;
	ans=((ans-1)%SAC+SAC)%SAC;
	for(int j=1;j<=M;j++)//超出限制的列
	{
		memset(F,0,sizeof(F));
		F[0][N]=1;
		for(int i=1;i<=N;i++)
		{
			for(int d=-i;d<=i;d++)//枚举差值 
			{
				F[i][d+N]=(F[i-1][d+N]//不选i行
				+F[i-1][d+N-1]*m[i][j]//选取第i行第j列
				+F[i-1][d+N+1]*(sum[i]-m[i][j]))%SAC;//选取第i行其他列 
			}
		}
		for(int d=N+1;d<=(N<<1);d++)
			non+=F[N][d],non%=SAC;
		non=(non+SAC)%SAC;
	}
	printf("%lld",((ans-non)%SAC+SAC)%SAC);
	return 0;
}
posted @ 2020-05-22 23:47  ShadderLeave  阅读(131)  评论(2编辑  收藏  举报