AGC057E RowCol/ColRow Sort 【观察,组合计数】
考虑排序网络的 \(\texttt{01}\) 原理,合法当且仅当对每个 \(k\in[0,8]\),对 \([A_{i,j}\le k]\) 做操作都得到 \([B_{i,j}\le k]\)。
现在就对 \(\texttt{01}\) 矩阵排序,注意到,考虑每行之和 \(r_i\),对行排序不改变,对列排序就是对 \(r_i\) 降序排序,每列之和 \(c_j\) 同理。且结果只与 \(r_i,c_j\) 有关,所以是充分的。
设排列 \(p_k\) 使得 \(r_{p_k(i)}\ge r_{p_k(i+1)}\),\(q_k\) 同理,则条件即为 \(A_{i,j}\le k\iff B_{p_k(i),q_k(j)}\le k\),那么这堆排列就唯一确定了 \(A_{i,j}\),当然这会算重,最后除以一堆阶乘即可。
不关注 \(A\),条件就是 \(B_{p_k(i),q_k(j)}\le k\implies B_{p_{k+1}(i),q_{k+1}(j)}\le k+1\),这只跟 \(p_{k+1}\circ p_k^{-1}\) 和 \(q_{k+1}\circ q_k^{-1}\) 有关,不如看成 \(B_{i,j}\le k\implies B_{p_k(i),q_k(j)}\le k+1\)。
考察 \([B_{i,j}\le k]\) 的杨图结构,设 \(a_i=\sum_j[B_{i,j}\le k]\),\(b_j=\sum_i[B_{i,j}\le k+1]\),条件即为 \(p_k(i)\le b_{\max\{q_k(1),\cdots,q_k(a_i)\}}\),不等式右侧递增,所以即为枚举 \(q_k\),计算 \(\prod_{i=1}^n(b_{\max\{q_k(1),\cdots,q_k(a_i)\}}-i+1)\) 之和,其中 \(b_0=n\)。从小到大枚举 \(a_i\),进行一个 dp 就可以了。
#include<bits/stdc++.h>
using namespace std;
typedef long long LL;
const int N = 1503, mod = 998244353;
int ksm(int a, int b){
int res = 1;
for(;b;b >>= 1, a = (LL)a * a % mod)
if(b & 1) res = (LL)res * a % mod;
return res;
}
int fac[N], rfac[N];
void init(int m){
*fac = 1;
for(int i = 1;i <= m;++ i) fac[i] = (LL)fac[i - 1] * i % mod;
rfac[m] = ksm(fac[m], mod - 2);
for(int i = m;i;-- i) rfac[i - 1] = (LL)rfac[i] * i % mod;
}
int n, m, ans = 1, a[10][N], b[10][N], cnt[N], f[N][N];
void solve(int k){
memset(f, 0, sizeof(f));
int t = n;
for(;t && !a[k][t];-- t) ans = ans * (n - t + 1ll) % mod;
for(int i = 1;i <= m;++ i) f[1][i] = 1;
for(;t && a[k][t] == 1;-- t)
for(int i = 1;i <= m;++ i)
f[1][i] = f[1][i] * max(b[k + 1][i] - t + 1ll, 0ll) % mod;
for(int i = 2;i <= m;++ i){
int tmp = f[i - 1][i - 1];
for(int j = i;j <= m;++ j){
tmp += f[i - 1][j]; if(tmp >= mod) tmp -= mod;
f[i][j] = (tmp + LL(j - i) * f[i - 1][j]) % mod;
}
for(;t && a[k][t] == i;-- t)
for(int j = i;j <= m;++ j)
f[i][j] = f[i][j] * max(b[k + 1][j] - t + 1ll, 0ll) % mod;
}
ans = (LL)ans * f[m][m] % mod;
memset(cnt, 0, (m + 1) << 2);
for(int i = 1;i <= n;++ i) ++ cnt[a[k][i]];
for(int i = 0;i <= m;++ i) ans = (LL)ans * rfac[cnt[i]] % mod;
memset(cnt, 0, (n + 1) << 2);
for(int i = 1;i <= m;++ i) ++ cnt[b[k][i]];
for(int i = 0;i <= n;++ i) ans = (LL)ans * rfac[cnt[i]] % mod;
}
int main(){
ios::sync_with_stdio(0);
cin >> n >> m; init(N - 1);
for(int i = 1;i <= n;++ i)
for(int j = 1, x;j <= m;++ j){
cin >> x; ++ a[x][i]; ++ b[x][j];
}
for(int i = 1;i <= 9;++ i){
for(int j = 1;j <= n;++ j) a[i][j] += a[i - 1][j];
for(int j = 1;j <= m;++ j) b[i][j] += b[i - 1][j];
}
for(int i = 0;i <= 8;++ i) solve(i);
printf("%d\n", ans);
}