[AGC019E] Shuffle and Swap

tag:组合计数,生成函数,多项式快速幂


蹲坑想出来的大体思路(雾

为了方便表述,下面这种形式

\[\begin{matrix}A_{x_1}&A_{x_2}&\cdots&A_{x_n}\\B_{y_1}&B_{y_2}&\cdots&B_{y_n}\end{matrix} \]

表示一串交换操作 \((x_1,x_2)(x_2,x_3)\cdots(x_{n-1},x_n)\)


对于一次交换操作,显然可以把这两列单独拿出来(根据题目,一定满足其中一列的 \(B=1\),另外一列的 \(A=1\)

\[\begin{matrix}x&1\\1&y\end{matrix} \]

考虑枚举另外两个位置的所有可能性,会发现:

  • \(x=y=0\) 时,交换后 \(A\)\(B\) 相同的位置+1
  • 其余情况,相同位置数不变

定义一串操作为有效操作,当且仅当依次执行完以后相同位置数+1。

显然

\[\begin{matrix}0&1\\1&0\end{matrix} \]

是一串有效操作

进一步有形如:

\[\begin{matrix}0&1&\cdots&1&1\\1&1&\cdots&1&0\end{matrix} \]

都是有效操作(中间全是 \(1\)


设有 \(p\) 个位置满足 \(A_i=1,B_i=0\)

那么初始相同位置数为 \(k-p\),然后发现一个合法的操作序列一定可以分为若干个子序列(不一定连续是因为有效操作串之间互不影响),满足刚好有 \(p\)有效操作,和一堆
\(\begin{matrix}1\\1\end{matrix}\) 操作。


对于一个包含 \(x\) 个中间元素的有效操作串,贡献为 \(x!\) 因为中间元素可以是任意顺序。


然后就可以枚举被当作中间元素\(\begin{matrix}1\\1\end{matrix}\) 操作的个数(显然这部分贡献只与个数有关),然后再枚举每一种分配方案,求出贡献和。再乘上一些组合数就可以求出答案。

\[ans=p!\cdot\sum_{x=0}^{k-p}\sum_{\sum d_i=x}(\binom{k-p}{d_1\quad d_2\quad\cdots\quad d_k}\binom{k}{d_1+1\quad d_2+1\quad\cdots\quad d_k+1}\Pi(d_i)!) \]

\[ans=p!\cdot(k-p)!\cdot k!\sum_{x=0}^{k-p}\sum_{\sum d_i=x}\frac1{(d_i+1)!} \]

后面的和式可以写成生成函数的形式

\[\sum_{i=0}^{k-p}[x^i]f^p(x) \]

\[f(x)=\sum \frac1{(d_i+1)!}x^i \]

\(f^p(x)\) 可以用多项式快速幂求出。

复杂度 \(O(nlogn)\)


快速幂部分是复制的

#include<bits/stdc++.h>
using namespace std;

template<typename T>
inline void Read(T &n){
	char ch; bool flag=0;
	while(!isdigit(ch=getchar()))if(ch=='-')flag=1;
	for(n=ch^48;isdigit(ch=getchar());n=((n<<1)+(n<<3)+(ch^48))%998244353);
	if(flag)n=-n;
}

typedef long long ll;
const ll MOD = 998244353;
const ll G = 3;
const ll invG = (MOD+1)/G;
const int MAXN = 10005;

inline ll ksm(ll base, ll k=MOD-2){
    ll res=1;
    while(k){
        if(k&1)
            res=res*base%MOD;
        base=base*base%MOD;
        k>>=1ll;
    }
    return res;
}

ll invN;
int tr[MAXN<<2];
inline int Make(int len){
    int L = 1<<((int)(log(len)/log(2))+1); invN = ksm(L);
    for(register int i=0; i<L; i++) tr[i]=(tr[i>>1]>>1)|((i&1)?(L>>1):0);
    return L;
}

ll jc[MAXN], invjc[MAXN], inv[MAXN];

inline void ntt(ll *f, int len, int flag){
    for(register int i=0; i<len; i++) if(i<tr[i]) swap(f[i],f[tr[i]]);
    for(register int k=2; k<=len; k<<=1){
        int L=k>>1; ll base=ksm(flag==1?G:invG,(MOD-1)/k);
        for(register int l=0; l<len; l+=k){
            ll now=1;
            for(register int p=l; p<l+L; p++){
                ll tmp=f[p+L]*now%MOD;
                f[p+L]=f[p]-tmp; if(f[p+L]<0) f[p+L]+=MOD;
                f[p]=f[p]+tmp; if(f[p]>=MOD) f[p]-=MOD;
                now=now*base%MOD;
            }
        }
    }
    if(flag==-1) for(register int i=0; i<len; i++) f[i]=f[i]*invN%MOD;
}

inline void poly_inv(ll *f, ll *res, int len){
    static ll tmp[MAXN<<2];
    if(len==1) return void(res[0]=ksm(f[0]));
    poly_inv(f,res,(len+1)/2);
    copy(f,f+len,tmp);
    int L=Make(len+len);
    ntt(tmp,L,1); ntt(res,L,1);
    for(register int i=0; i<L; i++) res[i]=(res[i]+res[i]-tmp[i]*res[i]%MOD*res[i]%MOD+MOD)%MOD;
    ntt(res,L,-1);
    fill(res+len,res+L,0);
    fill(tmp,tmp+L,0);
}

inline void poly_ln(ll *f, ll *res, int len){
    static ll tmp[MAXN<<2];
    for(register int i=1; i<len; i++) res[i-1]=f[i]*i%MOD;
    poly_inv(f,tmp,len);
    int L=Make(len+len);
    ntt(res,L,1); ntt(tmp,L,1);
    for(register int i=0; i<L; i++) res[i]=res[i]*tmp[i]%MOD;
    ntt(res,L,-1);
    for(register int i=len-1; i>=1; i--) res[i]=res[i-1]*inv[i]%MOD; res[0]=0;
    fill(res+len,res+L,0);
    fill(tmp,tmp+L,0);
}

inline void poly_exp(ll *f, ll *res, int len){
    static ll tmp[MAXN<<2], Log[MAXN<<2];
    if(len==1) return void(res[0]=1);
    poly_exp(f,res,(len+1)/2);
    copy(f,f+len,tmp);
    poly_ln(res,Log,len);
    int L=Make(len*1.5);
    ntt(res,L,1); ntt(tmp,L,1); ntt(Log,L,1);
    for(register int i=0; i<L; i++) res[i]=(1ll-Log[i]+tmp[i]+MOD)%MOD*res[i]%MOD;
    ntt(res,L,-1);
    fill(res+len,res+L,0);
    fill(tmp,tmp+L,0);
    fill(Log,Log+L,0);
}

inline void poly_qpow(ll *f, ll *res, ll k, int len){
    static ll tmp[MAXN<<2];
    poly_ln(f,tmp,len);
    for(register int i=0; i<len; i++) tmp[i]=tmp[i]*k%MOD;
    poly_exp(tmp,res,len);
}

ll f[MAXN<<2], g[MAXN<<2];

char a[MAXN], b[MAXN];

int main(){
	// freopen("1.in","r",stdin);
	scanf("%s%s",a+1,b+1);
	int A=0, k=0, n=strlen(a+1);
	for(register int i=1; i<=n; i++)
		if(a[i]=='1') k++, A += (b[i]=='0');

    jc[0]=1; for(register int i=1; i<=n; i++) jc[i]=jc[i-1]*(ll)i%MOD, inv[i]=ksm(i);
    invjc[n]=ksm(jc[n]); invjc[0]=1;
    for(register int i=n-1; i>=1; i--) invjc[i]=invjc[i+1]*(ll)(i+1)%MOD;
	for(register int i=0; i<=k-A; i++) f[i] = 1ll*invjc[i+1]%MOD;
    poly_qpow(f,g,A,k-A+1);

	int ans=0;
	for(register int i=0; i<=k-A; i++)
		ans = (ans+1ll*jc[A]*jc[k-A]%MOD*g[i])%MOD;
	cout<<1ll*ans*jc[k]%MOD<<endl;
    return 0;
}
posted @ 2021-06-26 13:30  oisdoaiu  阅读(39)  评论(0编辑  收藏  举报