[AGC019E]Shuffle and Swap
题目
点这里看题目。
分析
题目明显是要求我们求方案数。
显然这道题没有办法直接做。
考虑转化一下题目条件。可以发现我们应该让 \(A\) 中多余的 1 换到 \(A\) 中缺少 1 的位置去。为了使描述更加清晰,我们这样定义:
-
公共点(\(P\)):满足 \(A_i=1\land B_i=1\) 的 \(i\)。可以发现无论怎么交换,最终 \(P\) 上总是 1 。
-
起点(\(S\)):满足 \(A_i=0\land B_i=1\) 的 \(i\)。我们需要将 \(S\) 上的 0 转移走。
-
终点(\(E\)):满足 \(A_i=1\land B_i=0\) 的 \(i\)。我们需要将 \(S\) 上的 0 转移到 \(E\) 上来。
可以发现,最终可以使得 \(A=B\) 的操作序列必然满足:
连接边 \((a_i,b_i)\),则图的形态应该是一大堆由 \(S\) 和 \(E\) 作为端点,\(P\) 作为中间点的链。
注意这里的链应该是“有向”的,即我们不能倒着操作一条链。
好的,这样已经清晰多了。我们考虑写出状态和转移:
\(f(i,j)\):使用了 \(i\) 个 \(P\) ,组成了 \(j\) 条链的真实序列方案数。
不难考虑转移:
-
加入一个新的 \(P\) 。首先我们应该选取它所在的链(\(j\)),钦定它在末尾,再考虑它的标号(\(i\))。此时的贡献就是 \(f(i-1,j)\times i\times j\)。
-
加入一个新的链。我们继续钦定它放在末尾,并且考虑 \(S\) 和 \(E\) 的标号(\(j^2\))。此时的贡献就是 \(f(i,j-1)\times j^2\)。
需要注意的是,每次转移必然会导致真实序列(也就是 \(a\) 和 \(b\))长度加一。我们同样钦定每次新增后放在末尾。
真实情况下,一条链可能会有许多种对应的真实序列,而同一条链的不同的真实序列是由不同转移顺序来区分的。
于是就有转移:
考虑统计答案。注意我们不一定要所有的 \(P\) 都在 \(S-E\) 链上。因此我们需要枚举一下不在链上的 \(P\) 的数量。
设 \(P\) 点有 \(s\) 个,\(S\) 和 \(E\) 各有 \(t\) 个。
因此有答案为:
其中 \(\binom{s}{i}\times (i!)^2\) 是在计算不在链上的 \(P\) 的带标号形态, \(\binom{s+t}{i}\) 是在合并两个序列。
最后我们就得到了时间为 \(O(n^2)\) 的算法。
本题一些有价值的点:
-
考虑序列的相同与不同,于是就有了 \(P,S,E\) 三种点。
-
将交换看成边,发现链的性质。同时这也令人想起 树上的数 。
-
考虑序列的 DP 的时候,要么考虑标号,要么考虑位置,同时考虑会算重。
代码
#include <cstdio>
const int mod = 998244353;
const int MAXN = 10005;
template<typename _T>
void read( _T &x )
{
x = 0;char s = getchar();int f = 1;
while( s > '9' || s < '0' ){if( s == '-' ) f = -1; s = getchar();}
while( s >= '0' && s <= '9' ){x = ( x << 3 ) + ( x << 1 ) + ( s - '0' ), s = getchar();}
x *= f;
}
template<typename _T>
void write( _T x )
{
if( x < 0 ){ putchar( '-' ); x = ( ~ x ) + 1; }
if( 9 < x ){ write( x / 10 ); }
putchar( x % 10 + '0' );
}
int f[MAXN][MAXN];
int fac[MAXN], ifac[MAXN];
char A[MAXN], B[MAXN];
int N;
int qkpow( int base, int indx )
{
int ret = 1;
while( indx )
{
if( indx & 1 ) ret = 1ll * ret * base % mod;
base = 1ll * base * base % mod, indx >>= 1;
}
return ret;
}
void init( const int siz )
{
fac[0] = 1;
for( int i = 1 ; i <= siz ; i ++ ) fac[i] = 1ll * fac[i - 1] * i % mod;
ifac[siz] = qkpow( fac[siz], mod - 2 );
for( int i = siz - 1 ; ~ i ; i -- ) ifac[i] = 1ll * ifac[i + 1] * ( i + 1 ) % mod;
}
int C( const int n, const int m )
{
if( n < m || n < 0 || m < 0 ) return 0;
return 1ll * fac[n] * ifac[m] % mod * ifac[n - m] % mod;
}
void add( int &x, const int v ) { x = ( x + v >= mod ? x + v - mod : x + v ); }
int main()
{
int S = 0, T = 0;
scanf( "%s%s", A + 1, B + 1 );
for( N = 1 ; A[N] ; N ++ )
{
int a = A[N] - '0', b = B[N] - '0';
if( a && b ) S ++;
if( a && ! b ) T ++;
}
init( N );
f[0][0] = 1;
for( int i = 0 ; i <= S ; i ++ )
for( int j = 0 ; j <= T ; j ++ )
{
if( i ) add( f[i][j], 1ll * f[i - 1][j] * i % mod * j % mod );
if( j ) add( f[i][j], 1ll * f[i][j - 1] * j % mod * j % mod );
}
int ans = 0;
for( int i = 0 ; i <= S ; i ++ )
add( ans, 1ll * C( S + T, i ) * C( S, i ) % mod * fac[i] % mod * fac[i] % mod * f[S - i][T] % mod );
write( ans ), putchar( '\n' );
return 0;
}