[AGC019E]Shuffle and Swap
[AGC019E]Shuffle and Swap
题目大意:
给出两个长度为\(n(n\le10000)\)的\(01\)串\(A_{1\sim n}\)和\(B_{1\sim n}\)。两个串均有\(k\)个'1'
。令\(a_{1\sim k}\)和\(b_{1\sim k}\)分别表示\(A\)和\(B\)中所有'1'
出现的位置。将\(a\)和\(b\)等概率随机排列,按\(1\sim k\)的顺序交换\(A_{a_i}\)和\(A_{b_i}\)。令\(P\)表示操作完成后\(A\)与\(B\)相等的概率,求\(P\times(k!)^2\)在模\(998244353\)意义下的值。
思路:
将\(k\)个'1'
分为两类:\(S_1\)和\(S_2\)。\(S_1\)表示\(A\)和\(B\)中均为'1'
,有\(s_1\)个;\(S_2\)表示\(A\)和\(B\)中只有一个为'1'
,有\(s_2\)个。用\(f[i][j]\)表示\(S_1\)交换完\(i\)个,\(S_2\)交换完\(j\)个后的方案数。
显然,初始情况下,\(S_1\)可以交换\(2\)次,\(S_2\)可以交换\(1\)次。方便起见,假设我们的\(A\)和\(B\)分别为:
110011
111100
对于\(i=0\)的初始情况,由于\(S_2\)之间任意配对都可以使得\(A=B\),因此有\((j!)^2\)种方案。
若我们交换一对\(S_2\),如\(4\)和\(5\),那么交换后,可以交换\(1\)次的\(S_2\)减少了一对。有转移\(f[i][j]+=f[i][j-1]\times j^2\)。
若我们交换一个\(S_1\)和一个\(S_2\),如\(2\)和\(6\),那么交换后,原来那个\(S_2\)已经不能再交换了,而原来可以交换\(2\)次的\(S_1\)现在只能交换一次,可以算入\(S_2\)中,也就是相当于减少了一个\(S_1\)。有转移\(f[i][j]+=f[i-1][j]\times i\times j\)。
最后统计答案时,对于\(S_1\)没有用完的情况,无论如何配对都能满足\(A=B\)。因此\(f[s_1-i][s_2]\)对答案的贡献为\(f[s_1-i][s_2]\times(i!)^2\times\binom{s_1}i\times\binom ki\)。
时间复杂度\(\mathcal O(n^2)\)。
源代码:
#include<cstdio>
#include<cstring>
using int64=long long;
constexpr int N=10001,mod=998244353;
char a[N],b[N];
int fact[N],factinv[N],f[N][N];
void exgcd(const int &a,const int &b,int &x,int &y) {
if(!b) {
x=1,y=0;
return;
}
exgcd(b,a%b,y,x);
y-=a/b*x;
}
inline int inv(const int &x) {
int ret,tmp;
exgcd(x,mod,ret,tmp);
return (ret%mod+mod)%mod;
}
inline int C(const int &n,const int &m) {
if(n<m||n<0||m<0) return 0;
return (int64)fact[n]*factinv[m]%mod*factinv[n-m]%mod;
}
int main() {
scanf("%s%s",a,b);
const int n=strlen(a);
for(register int i=fact[0]=1;i<=n;i++) {
fact[i]=(int64)fact[i-1]*i%mod;
}
factinv[n]=inv(fact[n]);
for(register int i=n;i;i--) {
factinv[i-1]=(int64)factinv[i]*i%mod;
}
int s1=0,s2=0;
for(register int i=0;i<n;i++) {
if(a[i]=='1'&&b[i]=='1') s1++;
if(a[i]=='1'&&b[i]=='0') s2++;
}
for(register int i=0;i<=s2;i++) {
f[0][i]=(int64)fact[i]*fact[i]%mod;
}
for(register int i=1;i<=s1;i++) {
for(register int j=1;j<=s2;j++) {
f[i][j]=((int64)f[i-1][j]*i%mod*j%mod+(int64)f[i][j-1]*j%mod*j%mod)%mod;
}
}
int ans=0;
for(register int i=0;i<=s1;i++) {
(ans+=(int64)f[s1-i][s2]*fact[i]%mod*fact[i]%mod*C(s1,i)%mod*C(s1+s2,i)%mod)%=mod;
}
printf("%d\n",ans);
return 0;
}