[AGC019E] Shuffle and Swap
tag:组合计数,生成函数,多项式快速幂
蹲坑想出来的大体思路(雾
为了方便表述,下面这种形式
\[\begin{matrix}A_{x_1}&A_{x_2}&\cdots&A_{x_n}\\B_{y_1}&B_{y_2}&\cdots&B_{y_n}\end{matrix}
\]
表示一串交换操作 \((x_1,x_2)(x_2,x_3)\cdots(x_{n-1},x_n)\)
对于一次交换操作,显然可以把这两列单独拿出来(根据题目,一定满足其中一列的 \(B=1\),另外一列的 \(A=1\))
\[\begin{matrix}x&1\\1&y\end{matrix}
\]
考虑枚举另外两个位置的所有可能性,会发现:
- \(x=y=0\) 时,交换后 \(A\) 和 \(B\) 相同的位置+1
- 其余情况,相同位置数不变
定义一串操作为有效操作,当且仅当依次执行完以后相同位置数+1。
显然
\[\begin{matrix}0&1\\1&0\end{matrix}
\]
是一串有效操作
进一步有形如:
\[\begin{matrix}0&1&\cdots&1&1\\1&1&\cdots&1&0\end{matrix}
\]
都是有效操作(中间全是 \(1\))
设有 \(p\) 个位置满足 \(A_i=1,B_i=0\)
那么初始相同位置数为 \(k-p\),然后发现一个合法的操作序列一定可以分为若干个子序列(不一定连续是因为有效操作串之间互不影响),满足刚好有 \(p\) 串有效操作,和一堆
\(\begin{matrix}1\\1\end{matrix}\) 操作。
对于一个包含 \(x\) 个中间元素的有效操作串,贡献为 \(x!\) 因为中间元素可以是任意顺序。
然后就可以枚举被当作中间元素的 \(\begin{matrix}1\\1\end{matrix}\) 操作的个数(显然这部分贡献只与个数有关),然后再枚举每一种分配方案,求出贡献和。再乘上一些组合数就可以求出答案。
\[ans=p!\cdot\sum_{x=0}^{k-p}\sum_{\sum d_i=x}(\binom{k-p}{d_1\quad d_2\quad\cdots\quad d_k}\binom{k}{d_1+1\quad d_2+1\quad\cdots\quad d_k+1}\Pi(d_i)!)
\]
\[ans=p!\cdot(k-p)!\cdot k!\sum_{x=0}^{k-p}\sum_{\sum d_i=x}\frac1{(d_i+1)!}
\]
后面的和式可以写成生成函数的形式
\[\sum_{i=0}^{k-p}[x^i]f^p(x)
\]
\[f(x)=\sum \frac1{(d_i+1)!}x^i
\]
\(f^p(x)\) 可以用多项式快速幂求出。
复杂度 \(O(nlogn)\)
快速幂部分是复制的
#include<bits/stdc++.h>
using namespace std;
template<typename T>
inline void Read(T &n){
char ch; bool flag=0;
while(!isdigit(ch=getchar()))if(ch=='-')flag=1;
for(n=ch^48;isdigit(ch=getchar());n=((n<<1)+(n<<3)+(ch^48))%998244353);
if(flag)n=-n;
}
typedef long long ll;
const ll MOD = 998244353;
const ll G = 3;
const ll invG = (MOD+1)/G;
const int MAXN = 10005;
inline ll ksm(ll base, ll k=MOD-2){
ll res=1;
while(k){
if(k&1)
res=res*base%MOD;
base=base*base%MOD;
k>>=1ll;
}
return res;
}
ll invN;
int tr[MAXN<<2];
inline int Make(int len){
int L = 1<<((int)(log(len)/log(2))+1); invN = ksm(L);
for(register int i=0; i<L; i++) tr[i]=(tr[i>>1]>>1)|((i&1)?(L>>1):0);
return L;
}
ll jc[MAXN], invjc[MAXN], inv[MAXN];
inline void ntt(ll *f, int len, int flag){
for(register int i=0; i<len; i++) if(i<tr[i]) swap(f[i],f[tr[i]]);
for(register int k=2; k<=len; k<<=1){
int L=k>>1; ll base=ksm(flag==1?G:invG,(MOD-1)/k);
for(register int l=0; l<len; l+=k){
ll now=1;
for(register int p=l; p<l+L; p++){
ll tmp=f[p+L]*now%MOD;
f[p+L]=f[p]-tmp; if(f[p+L]<0) f[p+L]+=MOD;
f[p]=f[p]+tmp; if(f[p]>=MOD) f[p]-=MOD;
now=now*base%MOD;
}
}
}
if(flag==-1) for(register int i=0; i<len; i++) f[i]=f[i]*invN%MOD;
}
inline void poly_inv(ll *f, ll *res, int len){
static ll tmp[MAXN<<2];
if(len==1) return void(res[0]=ksm(f[0]));
poly_inv(f,res,(len+1)/2);
copy(f,f+len,tmp);
int L=Make(len+len);
ntt(tmp,L,1); ntt(res,L,1);
for(register int i=0; i<L; i++) res[i]=(res[i]+res[i]-tmp[i]*res[i]%MOD*res[i]%MOD+MOD)%MOD;
ntt(res,L,-1);
fill(res+len,res+L,0);
fill(tmp,tmp+L,0);
}
inline void poly_ln(ll *f, ll *res, int len){
static ll tmp[MAXN<<2];
for(register int i=1; i<len; i++) res[i-1]=f[i]*i%MOD;
poly_inv(f,tmp,len);
int L=Make(len+len);
ntt(res,L,1); ntt(tmp,L,1);
for(register int i=0; i<L; i++) res[i]=res[i]*tmp[i]%MOD;
ntt(res,L,-1);
for(register int i=len-1; i>=1; i--) res[i]=res[i-1]*inv[i]%MOD; res[0]=0;
fill(res+len,res+L,0);
fill(tmp,tmp+L,0);
}
inline void poly_exp(ll *f, ll *res, int len){
static ll tmp[MAXN<<2], Log[MAXN<<2];
if(len==1) return void(res[0]=1);
poly_exp(f,res,(len+1)/2);
copy(f,f+len,tmp);
poly_ln(res,Log,len);
int L=Make(len*1.5);
ntt(res,L,1); ntt(tmp,L,1); ntt(Log,L,1);
for(register int i=0; i<L; i++) res[i]=(1ll-Log[i]+tmp[i]+MOD)%MOD*res[i]%MOD;
ntt(res,L,-1);
fill(res+len,res+L,0);
fill(tmp,tmp+L,0);
fill(Log,Log+L,0);
}
inline void poly_qpow(ll *f, ll *res, ll k, int len){
static ll tmp[MAXN<<2];
poly_ln(f,tmp,len);
for(register int i=0; i<len; i++) tmp[i]=tmp[i]*k%MOD;
poly_exp(tmp,res,len);
}
ll f[MAXN<<2], g[MAXN<<2];
char a[MAXN], b[MAXN];
int main(){
// freopen("1.in","r",stdin);
scanf("%s%s",a+1,b+1);
int A=0, k=0, n=strlen(a+1);
for(register int i=1; i<=n; i++)
if(a[i]=='1') k++, A += (b[i]=='0');
jc[0]=1; for(register int i=1; i<=n; i++) jc[i]=jc[i-1]*(ll)i%MOD, inv[i]=ksm(i);
invjc[n]=ksm(jc[n]); invjc[0]=1;
for(register int i=n-1; i>=1; i--) invjc[i]=invjc[i+1]*(ll)(i+1)%MOD;
for(register int i=0; i<=k-A; i++) f[i] = 1ll*invjc[i+1]%MOD;
poly_qpow(f,g,A,k-A+1);
int ans=0;
for(register int i=0; i<=k-A; i++)
ans = (ans+1ll*jc[A]*jc[k-A]%MOD*g[i])%MOD;
cout<<1ll*ans*jc[k]%MOD<<endl;
return 0;
}