Solution -「AGC 019E」「AT 2704」Shuffle and Swap
\(\mathcal{Description}\)
Link.
给定 \(01\) 序列 \(\{A_n\}\) 和 \(\{B_n\}\),其中 \(1\) 的个数均为 \(k\)。记 \(A\) 中 \(1\) 的位置为 \(\{a_k\}\),\(B\) 中的为 \(\{b_k\}\)。现任意排列 \(\{a_k\}\) 和 \(\{b_k\}\),然后依次交换 \(A_{a_i}\) 和 \(A_{b_i}\),\(i=1,2,\dots,k\)。求使操作完成后 \(A=B\) 的排列方案数对 \(998244353\) 取模的结果。
\(n\le10^4\)。
\(\mathcal{Solution}\)
首先转化题意——每次操作用 \(A\) 的一个 \(1\) 和 \(B\) 中一个需要 \(1\) 的位置交换,共进行 \(k\) 次操作,求满足条件的方案数。
我们定义 \(A_i=1\land B_i=0\) 的位置为“富余点”——它们需要为其它位置提供 \(1\);\(A_i=B_i=1\) 的位置为“公共点”——它们已经满足条件,但可以作为传递 \(1\) 的载体;\(A_i=0\land B_i=1\) 的位置为“缺失点”——它们是“富余点”需要提供到的位置。
那么,只需要每个“富余点”的 \(1\) 都传递给“缺失点”,\(A\) 就与 \(B\) 相等。考虑一次从一个“富余”到“缺失”的传递,如图:
其中三角形为“缺失点”,圆形为“公共点”(任意多个,亦可不存在),方形为“富余点”。注意若存在“公共点”,传递是不能反向的,这样会导致与三角形相邻的公共点的 \(1\) 变为 \(0\)。(某课件的笔误,注意一下 www。)
考虑 DP,记 \(f(i,j)\) 表示在传递链中用了 \(i\) 个“公共点”,用了 \(j\) 个“富余点”(即有 \(j\) 条传递链)时,传递链中的方案数。边界为 \(f(0,0)=1\),转移:
前一项,将一个公共点加入一条传递链的末尾。有 \(j\) 条链,新的点可以与已有的 \(i-1\) 个点交换(不是交换操作顺序,而是直接交换链中位置),故有系数 \(ij\)。
后一项,新建一条“缺失”-“富余”链。首先拿出新的一对“缺失点”和“富余点”,仍考虑到原来的“缺失点”或“富余点”可以和新点交换位置,故有系数 \(j^2\)。
最后统计答案,发现“公共点”没有必要在链中用完,设在链中用 \(a\) 个“公共点”,还剩下 \(b\) 个。那么在所有“公共点”中选出 \(b\) 个,方案数 \(\binom{a+b}{b}\);公共点内部方案 \((b!)^2\);把这 \(b\) 次操作安排进总共 \(k\) 次操作里,方案数 \(\binom{k}{b}\)。所以设“公共点”有 \(s\) 个,“富余点”有 \(t\) 个,答案为(\(i\) 枚举的是不在链中的“公共点”个数):
复杂度 \(\mathcal O(n^2)\)。
\(\mathcal{Code}\)
#include <cstdio>
#include <cstring>
const int MAXL = 1e4, MOD = 998244353;
int n, sur, bal, f[MAXL + 5][MAXL + 5];
int fac[MAXL + 5], ifac[MAXL + 5];
char A[MAXL + 5], B[MAXL + 5];
inline void addeq ( int& a, const int b ) { if ( ( a += b ) >= MOD ) a -= MOD; }
inline int mul ( long long a, const int b ) { return ( a *= b ) < MOD ? a : a % MOD; }
inline int qkpow ( int a, int b ) {
int ret = 1;
for ( ; b; a = mul ( a, a ), b >>= 1 ) ret = mul ( ret, b & 1 ? a : 1 );
return ret;
}
inline void init () {
fac[0] = 1;
for ( int i = 1; i <= n; ++ i ) fac[i] = mul ( i, fac[i - 1] );
ifac[n] = qkpow ( fac[n], MOD - 2 );
for ( int i = n - 1; ~ i; -- i ) ifac[i] = mul ( i + 1, ifac[i + 1] );
}
inline int comb ( const int n, const int m ) {
return n < m ? 0 : mul ( fac[n], mul ( ifac[m], ifac[n - m] ) );
}
int main () {
scanf ( "%s %s", A + 1, B + 1 );
n = strlen ( A + 1 ), init ();
for ( int i = 1; i <= n; ++ i ) {
if ( A[i] ^ '0' && B[i] ^ '0' ) ++ bal;
else if ( A[i] > B[i] ) ++ sur;
}
f[0][0] = 1;
for ( int i = 0; i <= bal; ++ i ) {
for ( int j = 0, cur; j <= sur; ++ j ) {
if ( ! ( cur = f[i][j] ) ) continue;
addeq ( f[i + 1][j], mul ( cur, mul ( i + 1, j ) ) );
addeq ( f[i][j + 1], mul ( cur, mul ( j + 1, j + 1 ) ) );
}
}
int ans = 0;
for ( int i = 0; i <= bal; ++ i ) {
int fre = bal - i;
int self = mul ( mul ( fac[fre], fac[fre] ),
mul ( comb ( bal + sur, fre ), comb ( bal, fre ) ) );
addeq ( ans, mul ( f[i][sur], self ) );
}
printf ( "%d\n", ans );
return 0;
}