Emiya家的饭

\(Emiya\)家的饭

给你一个矩阵,要求每行只能选一个节点,每列选的节点不能超过所有选的节点的一半,不能不选,给出每个节点选择方案数,求总方案数

大暴力

int n,m,a[maxn][maxm],cnt[maxm],ans = 0;
void dfs(int x,int sum){//x烹饪方法编号 sum总方案数 
	if(x > n){
		long long num = 0;
		for(re int i = 1;i <= m;i++) num += cnt[i];
		num /= 2;
		for(re int i = 1;i <= m;i++)
		if(cnt[i] > num) return;
		ans += sum; ans %= mod;
		return;
	}
	dfs(x + 1,sum);
	for(re int i = 1;i <= m;i++){
		if(a[x][i] == 0) continue;
		cnt[i] ++;
		dfs(x + 1,sum * a[x][i] % mod);
		cnt[i] --;
	}
}
int main(){
	n = read(); m = read();
	for(re int i = 1;i <= n;i++)
		for(re int j = 1;j <= m;j++)
			a[i][j] = read();
	dfs(1,1);
	printf("%d",ans-1);
}

考虑列的限制,若有不合法的列,则必然有且只有一列不合法,因为不可能有不同的两列数量都超过总数一半

列的限制需要容斥计算:每行选不超过一个的方案数 - 每行选不超过一个,且某一列选了超过一半的方案数

那么考虑枚举不合法的一列。假设我们已经枚举了不合法的列为 \(col\),接下来会发现我们只关心一个数的位置是否在当前列;如果属于在其他列的情况,那么它具体在哪一列对当前列的合法性并无影响,我们并不需要考虑。

\[\large f_{i,j,k}=f_{i-1,j}+a_{i,col}\times f_{i-1,j-1,k}+(s_i-a_{i,col}\times f_{i-1,j,k-1}) \]

状态\(O(n^3)\)\(O(1)\)转移,枚举\(col\)复杂度,\(O(mn^3)\),有\(84\)分好成绩

统计不合法的方案数\(\large \sum\limits_{j>k}f_{n,j,k}\)

总方案数,设\(\large g_{i,j}\),前\(i\)行共选了\(j\)个数,有转移

\[\large g_{i,j}=g_{i-1,j}+s_i*g_{i-1,j-1} \]

\(\sum\limits_{i=1}^ng_{n,i}\)就是总方案数,\(O(n^2)\)

但是我们转移时不用关注\(j,k\)具体值,只需要相对大小,所以把状态改为\(f_{i,j}\),表示前\(i\)行,当前列的数比其他列多了\(j\)个,转移

\[\large f_{i,j}=f_{i-1,j}+a_{i,col}\times f_{i-1,j-1}+(s_i-a_{i,col}\times f_{i-1,j+1}) \]

\(O(mn^2)\)

#define ll long long
int n,m,a[N][M],sum[N][M];
ll f[N][N<<1],g[N][N];
int main(){
	scanf("%d%d",&n,&m);
	for(int i = 1;i <= n;++i)
		for(int j = 1;j <= m;++j){
			a[i][j] = read();
			sum[i][0] = (sum[i][0] + a[i][j]) % mod;
		}
	for(int i = 1;i <= n;++i)
		for(int j = 1;j <= m;++j)
			sum[i][j] = (sum[i][0] - a[i][j] + mod) % mod;
	ll ans = 0;
	for(int col = 1;col <= m;++col){
		memset(f,0,sizeof(f));f[0][n] = 1;
		for(int i = 1;i <= n;++i)
			for(int j = n-i;j <= n+i;++j)
				f[i][j] = (f[i-1][j]+f[i-1][j-1]*a[i][col]%mod+f[i-1][j+1]*sum[i][col]%mod)%mod;
		for(int j = 1;j <= n;++j)
			ans = (ans + f[n][n+j]) % mod;
	}
	g[0][0] = 1;
	for(int i = 1;i <= n;++i)
		for(int j = 0;j <= n;++j)
			g[i][j] = (g[i-1][j] + (j > 0 ? g[i-1][j-1]*sum[i][0]%mod : 0)) % mod;
	for(int i = 1;i <= n;++i)
		ans = (ans - g[n][i] + mod) % mod;
	printf("%lld",ans*(mod-1)%mod);
}

注意值域

posted @ 2020-10-29 21:18  INFP  阅读(55)  评论(0编辑  收藏  举报