[AGC019E]Shuffle and Swap

[AGC019E]Shuffle and Swap

题目大意:

给出两个长度为\(n(n\le10000)\)\(01\)\(A_{1\sim n}\)\(B_{1\sim n}\)。两个串均有\(k\)'1'。令\(a_{1\sim k}\)\(b_{1\sim k}\)分别表示\(A\)\(B\)中所有'1'出现的位置。将\(a\)\(b\)等概率随机排列,按\(1\sim k\)的顺序交换\(A_{a_i}\)\(A_{b_i}\)。令\(P\)表示操作完成后\(A\)\(B\)相等的概率,求\(P\times(k!)^2\)在模\(998244353\)意义下的值。

思路:

\(k\)'1'分为两类:\(S_1\)\(S_2\)\(S_1\)表示\(A\)\(B\)中均为'1',有\(s_1\)个;\(S_2\)表示\(A\)\(B\)中只有一个为'1',有\(s_2\)个。用\(f[i][j]\)表示\(S_1\)交换完\(i\)个,\(S_2\)交换完\(j\)个后的方案数。

显然,初始情况下,\(S_1\)可以交换\(2\)次,\(S_2\)可以交换\(1\)次。方便起见,假设我们的\(A\)\(B\)分别为:

110011
111100

对于\(i=0\)的初始情况,由于\(S_2\)之间任意配对都可以使得\(A=B\),因此有\((j!)^2\)种方案。

若我们交换一对\(S_2\),如\(4\)\(5\),那么交换后,可以交换\(1\)次的\(S_2\)减少了一对。有转移\(f[i][j]+=f[i][j-1]\times j^2\)

若我们交换一个\(S_1\)和一个\(S_2\),如\(2\)\(6\),那么交换后,原来那个\(S_2\)已经不能再交换了,而原来可以交换\(2\)次的\(S_1\)现在只能交换一次,可以算入\(S_2\)中,也就是相当于减少了一个\(S_1\)。有转移\(f[i][j]+=f[i-1][j]\times i\times j\)

最后统计答案时,对于\(S_1\)没有用完的情况,无论如何配对都能满足\(A=B\)。因此\(f[s_1-i][s_2]\)对答案的贡献为\(f[s_1-i][s_2]\times(i!)^2\times\binom{s_1}i\times\binom ki\)

时间复杂度\(\mathcal O(n^2)\)

源代码:

#include<cstdio>
#include<cstring>
using int64=long long;
constexpr int N=10001,mod=998244353;
char a[N],b[N];
int fact[N],factinv[N],f[N][N];
void exgcd(const int &a,const int &b,int &x,int &y) {
	if(!b) {
		x=1,y=0;
		return;
	}
	exgcd(b,a%b,y,x);
	y-=a/b*x;
}
inline int inv(const int &x) {
	int ret,tmp;
	exgcd(x,mod,ret,tmp);
	return (ret%mod+mod)%mod;
}
inline int C(const int &n,const int &m) {
	if(n<m||n<0||m<0) return 0;
	return (int64)fact[n]*factinv[m]%mod*factinv[n-m]%mod;
}
int main() {
	scanf("%s%s",a,b);
	const int n=strlen(a);
	for(register int i=fact[0]=1;i<=n;i++) {
		fact[i]=(int64)fact[i-1]*i%mod;
	}
	factinv[n]=inv(fact[n]);
	for(register int i=n;i;i--) {
		factinv[i-1]=(int64)factinv[i]*i%mod;
	}
	int s1=0,s2=0;
	for(register int i=0;i<n;i++) {
		if(a[i]=='1'&&b[i]=='1') s1++;
		if(a[i]=='1'&&b[i]=='0') s2++;
	}
	for(register int i=0;i<=s2;i++) {
		f[0][i]=(int64)fact[i]*fact[i]%mod;
	}
	for(register int i=1;i<=s1;i++) {
		for(register int j=1;j<=s2;j++) {
			f[i][j]=((int64)f[i-1][j]*i%mod*j%mod+(int64)f[i][j-1]*j%mod*j%mod)%mod;
		}
	}
	int ans=0;
	for(register int i=0;i<=s1;i++) {
		(ans+=(int64)f[s1-i][s2]*fact[i]%mod*fact[i]%mod*C(s1,i)%mod*C(s1+s2,i)%mod)%=mod;
	}
	printf("%d\n",ans);
	return 0;
}
posted @ 2018-06-11 14:38  skylee03  阅读(110)  评论(0编辑  收藏  举报