[AGC019E]Shuffle and Swap

题目

点这里看题目。

分析

题目明显是要求我们求方案数。

显然这道题没有办法直接做。

考虑转化一下题目条件。可以发现我们应该让 \(A\) 中多余的 1 换到 \(A\) 中缺少 1 的位置去。为了使描述更加清晰,我们这样定义:

  1. 公共点(\(P\)):满足 \(A_i=1\land B_i=1\)\(i\)。可以发现无论怎么交换,最终 \(P\) 上总是 1 。

  2. 起点(\(S\)):满足 \(A_i=0\land B_i=1\)\(i\)。我们需要将 \(S\) 上的 0 转移走。

  3. 终点(\(E\)):满足 \(A_i=1\land B_i=0\)\(i\)。我们需要将 \(S\) 上的 0 转移到 \(E\) 上来。

可以发现,最终可以使得 \(A=B\) 的操作序列必然满足:

连接边 \((a_i,b_i)\),则图的形态应该是一大堆\(S\)\(E\) 作为端点,\(P\) 作为中间点的链

注意这里的链应该是“有向”的,即我们不能倒着操作一条链。

好的,这样已经清晰多了。我们考虑写出状态和转移:

\(f(i,j)\):使用了 \(i\)\(P\) ,组成了 \(j\) 条链的真实序列方案数。

不难考虑转移:

  1. 加入一个新的 \(P\) 。首先我们应该选取它所在的链(\(j\)),钦定它在末尾,再考虑它的标号(\(i\))。此时的贡献就是 \(f(i-1,j)\times i\times j\)

  2. 加入一个新的链。我们继续钦定它放在末尾,并且考虑 \(S\)\(E\) 的标号(\(j^2\))。此时的贡献就是 \(f(i,j-1)\times j^2\)

需要注意的是,每次转移必然会导致真实序列(也就是 \(a\)\(b\))长度加一。我们同样钦定每次新增后放在末尾

真实情况下,一条链可能会有许多种对应的真实序列,而同一条链的不同的真实序列是由不同转移顺序来区分的。

于是就有转移:

\[f(i,j)=i\times j\times f(i-1,j)+j^2\times f(i,j-1) \]


考虑统计答案。注意我们不一定要所有的 \(P\) 都在 \(S-E\) 链上。因此我们需要枚举一下不在链上的 \(P\) 的数量。

\(P\) 点有 \(s\) 个,\(S\)\(E\) 各有 \(t\) 个。

因此有答案为:

\[\sum_{i=0}^s\binom{s}{i}\times (i!)^2\times f(s-i,t)\times \binom{s+t}{i} \]

其中 \(\binom{s}{i}\times (i!)^2\) 是在计算不在链上的 \(P\)带标号形态\(\binom{s+t}{i}\) 是在合并两个序列。

最后我们就得到了时间为 \(O(n^2)\) 的算法。

本题一些有价值的点:

  1. 考虑序列的相同与不同,于是就有了 \(P,S,E\) 三种点。

  2. 将交换看成边,发现链的性质。同时这也令人想起 树上的数

  3. 考虑序列的 DP 的时候,要么考虑标号,要么考虑位置,同时考虑会算重。

代码

AC 记录

#include <cstdio>

const int mod = 998244353;
const int MAXN = 10005;

template<typename _T>
void read( _T &x )
{
	x = 0;char s = getchar();int f = 1;
	while( s > '9' || s < '0' ){if( s == '-' ) f = -1; s = getchar();}
	while( s >= '0' && s <= '9' ){x = ( x << 3 ) + ( x << 1 ) + ( s - '0' ), s = getchar();}
	x *= f;
}

template<typename _T>
void write( _T x )
{
	if( x < 0 ){ putchar( '-' ); x = ( ~ x ) + 1; }
	if( 9 < x ){ write( x / 10 ); }
	putchar( x % 10 + '0' );
}

int f[MAXN][MAXN];
int fac[MAXN], ifac[MAXN];
char A[MAXN], B[MAXN];
int N;

int qkpow( int base, int indx )
{
	int ret = 1;
	while( indx )
	{
		if( indx & 1 ) ret = 1ll * ret * base % mod;
		base = 1ll * base * base % mod, indx >>= 1;
	}
	return ret;
}

void init( const int siz )
{
	fac[0] = 1;
	for( int i = 1 ; i <= siz ; i ++ ) fac[i] = 1ll * fac[i - 1] * i % mod;
	ifac[siz] = qkpow( fac[siz], mod - 2 );
	for( int i = siz - 1 ; ~ i ; i -- ) ifac[i] = 1ll * ifac[i + 1] * ( i + 1 ) % mod;
}

int C( const int n, const int m ) 
{ 
	if( n < m || n < 0 || m < 0 ) return 0;
	return 1ll * fac[n] * ifac[m] % mod * ifac[n - m] % mod;
}

void add( int &x, const int v ) { x = ( x + v >= mod ? x + v - mod : x + v ); }

int main()
{
	int S = 0, T = 0;
	scanf( "%s%s", A + 1, B + 1 );
	for( N = 1 ; A[N] ; N ++ )
	{
		int a = A[N] - '0', b = B[N] - '0';
		if( a && b ) S ++;
		if( a && ! b ) T ++;
	}
	init( N );
	f[0][0] = 1;
	for( int i = 0 ; i <= S ; i ++ )
		for( int j = 0 ; j <= T ; j ++ )
		{
			if( i ) add( f[i][j], 1ll * f[i - 1][j] * i % mod * j % mod );
			if( j ) add( f[i][j], 1ll * f[i][j - 1] * j % mod * j % mod );
		}
	int ans = 0;
	for( int i = 0 ; i <= S ; i ++ )
		add( ans, 1ll * C( S + T, i ) * C( S, i ) % mod * fac[i] % mod * fac[i] % mod * f[S - i][T] % mod );
	write( ans ), putchar( '\n' );
	return 0;
}
posted @ 2020-08-15 16:31  crashed  阅读(118)  评论(0编辑  收藏  举报