Emiya家的饭
\(Emiya\)家的饭
给你一个矩阵,要求每行只能选一个节点,每列选的节点不能超过所有选的节点的一半,不能不选,给出每个节点选择方案数,求总方案数
大暴力
int n,m,a[maxn][maxm],cnt[maxm],ans = 0;
void dfs(int x,int sum){//x烹饪方法编号 sum总方案数
if(x > n){
long long num = 0;
for(re int i = 1;i <= m;i++) num += cnt[i];
num /= 2;
for(re int i = 1;i <= m;i++)
if(cnt[i] > num) return;
ans += sum; ans %= mod;
return;
}
dfs(x + 1,sum);
for(re int i = 1;i <= m;i++){
if(a[x][i] == 0) continue;
cnt[i] ++;
dfs(x + 1,sum * a[x][i] % mod);
cnt[i] --;
}
}
int main(){
n = read(); m = read();
for(re int i = 1;i <= n;i++)
for(re int j = 1;j <= m;j++)
a[i][j] = read();
dfs(1,1);
printf("%d",ans-1);
}
考虑列的限制,若有不合法的列,则必然有且只有一列不合法,因为不可能有不同的两列数量都超过总数一半
列的限制需要容斥计算:每行选不超过一个的方案数 - 每行选不超过一个,且某一列选了超过一半的方案数
那么考虑枚举不合法的一列。假设我们已经枚举了不合法的列为 \(col\),接下来会发现我们只关心一个数的位置是否在当前列;如果属于在其他列的情况,那么它具体在哪一列对当前列的合法性并无影响,我们并不需要考虑。
\[\large f_{i,j,k}=f_{i-1,j}+a_{i,col}\times f_{i-1,j-1,k}+(s_i-a_{i,col}\times f_{i-1,j,k-1})
\]
状态\(O(n^3)\),\(O(1)\)转移,枚举\(col\)复杂度,\(O(mn^3)\),有\(84\)分好成绩
统计不合法的方案数\(\large \sum\limits_{j>k}f_{n,j,k}\)
总方案数,设\(\large g_{i,j}\),前\(i\)行共选了\(j\)个数,有转移
\[\large g_{i,j}=g_{i-1,j}+s_i*g_{i-1,j-1}
\]
\(\sum\limits_{i=1}^ng_{n,i}\)就是总方案数,\(O(n^2)\)
但是我们转移时不用关注\(j,k\)具体值,只需要相对大小,所以把状态改为\(f_{i,j}\),表示前\(i\)行,当前列的数比其他列多了\(j\)个,转移
\[\large f_{i,j}=f_{i-1,j}+a_{i,col}\times f_{i-1,j-1}+(s_i-a_{i,col}\times f_{i-1,j+1})
\]
\(O(mn^2)\)
#define ll long long
int n,m,a[N][M],sum[N][M];
ll f[N][N<<1],g[N][N];
int main(){
scanf("%d%d",&n,&m);
for(int i = 1;i <= n;++i)
for(int j = 1;j <= m;++j){
a[i][j] = read();
sum[i][0] = (sum[i][0] + a[i][j]) % mod;
}
for(int i = 1;i <= n;++i)
for(int j = 1;j <= m;++j)
sum[i][j] = (sum[i][0] - a[i][j] + mod) % mod;
ll ans = 0;
for(int col = 1;col <= m;++col){
memset(f,0,sizeof(f));f[0][n] = 1;
for(int i = 1;i <= n;++i)
for(int j = n-i;j <= n+i;++j)
f[i][j] = (f[i-1][j]+f[i-1][j-1]*a[i][col]%mod+f[i-1][j+1]*sum[i][col]%mod)%mod;
for(int j = 1;j <= n;++j)
ans = (ans + f[n][n+j]) % mod;
}
g[0][0] = 1;
for(int i = 1;i <= n;++i)
for(int j = 0;j <= n;++j)
g[i][j] = (g[i-1][j] + (j > 0 ? g[i-1][j-1]*sum[i][0]%mod : 0)) % mod;
for(int i = 1;i <= n;++i)
ans = (ans - g[n][i] + mod) % mod;
printf("%lld",ans*(mod-1)%mod);
}
注意值域