题目链接

https://atcoder.jp/contests/agc019/tasks/agc019_e

题解

tourist的神仙E题啊做不来做不来……这题我好像想歪了啊= =……

首先我们可以考虑,什么样的操作序列才是合法的?
有用的位置只有两种,一种是两个序列在这个位置上都是1, 称作11型,另一种是一个0一个1, 称作01型。设两种位置分别有\(A\)个和\(2B\)个。
考虑一个操作序列,交换两个11型相当于没交换,每个11型只会被交换两次,每个01型只会被交换一次。这也就是说,如果我们从shuffle之后的\(a_i\)\(b_i\)连边,那么形成的图一定是若干个环加上\(m\)条链,链的开头结尾都是01型,中间是11型。对于链来说,上面操作的顺序必须固定;对于环来说,上面操作的顺序可以任意。

下面有两种处理方式。

做法一

\(dp[i][j]\)表示把\(j\)个无标号的11型放到\(i\)条链中,可得DP式: \(dp[i][j]=\sum_{k\ge 0}\frac{dp[i-1][j-k]}{(k+1)!}\), 其中分母的含义是链上\((k+1)\)个点顺序固定,最后的答案是\(A!B!(A+B)!\sum^A_{i=0}dp[B][i]\). \((A+B)!\)表示将边随意排序,\(A!B!\)表示11型和01型点之间是有标号的。
时间复杂度\(O(n^3)\).

但是我们发现这个DP就相当于在给多项式\(\sum_{n\ge 0}\frac{1}{(n+1)!}x^n\)进行幂运算,于是用多项式快速幂加速即可,时间复杂度\(O(n\log^2n)\)\(O(n\log n)\).

做法二

有没有聪明一点的做法?有!
\(dp[i][j]\)表示目前一共放了\(i\)个11型和\(j\)个01型链(考虑已经放了的元素的标号,但是每次仅仅是往右添加),我们强行转移!
\(dp[i][j]=j^2\times dp[i][j-1]+ij\times dp[i-1][j]\)
前一个式子是要加一个新的01型链,选两个01型;后一个式子是要选一个链,再把这条链的结尾端点任意扩展一个位置。
答案就是\(\sum_{k}{A\choose k}\times f[k][B]\times ((A-k)!)^2\times {A+B\choose A-k}\), 其中\(A+B\choose A-k\)是选出位置,\(A\choose k\)是选出编号,\(((A-k)!)^2\)是求出组成环的方案数。
时间复杂度\(O(n^2)\), 可以通过。

但是我们发现这个DP还可以用多项式优化!
\(g[i][j]=\frac{dp[i][j]}{(j!)^2i!}\), 显然有\(g[i][j]=g[i][j-1]+j\times g[i-1][j]\)
然后这个使用NE Lattice Path的方式来理解,就是从\((0,0)\)走到\((i,j)\),每往上走一次路径权值乘上横坐标,求所有路径权值和。
考虑另一种DP,枚举在第\(i\)列走几步,那么发现第\(i\)列的生成函数就是\(\frac{1}{1-ix}\), 然后答案就是所有列生成函数之积
于是可以分治NTT+多项式求逆计算,时间复杂度\(O(n\log^2n)\).

代码

做法二\(O(n^2)\)

#include<cstdio>
#include<cstdlib>
#include<iostream>
#include<cassert>
#include<cstring>
#define llong long long
using namespace std;

const int N = 2e4;
const int P = 998244353;
const llong INV2 = 499122177;
llong fact[N+3],finv[N+3];

llong quickpow(llong x,llong y)
{
	llong cur = x,ret = 1ll;
	for(int i=0; y; i++)
	{
		if(y&(1ll<<i)) {y-=(1ll<<i); ret = ret*cur%P;}
		cur = cur*cur%P;
	}
	return ret;
}
llong mulinv(llong x) {return quickpow(x,P-2);}
llong comb(llong x,llong y) {return x<0||y<0||x<y ? 0ll : fact[x]*finv[y]%P*finv[x-y]%P;}

int n,a,b;
char s[N+3],t[N+3];
llong dp[2][N+3];

int main()
{
	fact[0] = 1ll; for(int i=1; i<=N; i++) fact[i] = fact[i-1]*i%P;
	finv[N] = quickpow(fact[N],P-2); for(int i=N-1; i>=0; i--) finv[i] = finv[i+1]*(i+1)%P;
	scanf("%s%s",s+1,t+1); n = strlen(s+1);
	for(int i=1; i<=n; i++)
	{
		if(s[i]=='1' && t[i]=='1') {a++;}
		else if(s[i]^t[i]) {b++;}
	}
	b>>=1;
	int cur = 0,prv = 1;
	dp[0][0] = 1ll;
	for(int j=1; j<=b; j++)
	{
		cur^=1; prv^=1;
		dp[cur][0] = dp[prv][0]*j*j%P;
		for(int i=1; i<=a; i++)
		{
			dp[cur][i] = (dp[prv][i]*j*j+dp[cur][i-1]*i*j)%P;
//			printf("dp[%d][%d]=%lld\n",i,j,dp[cur][i]);
		}
	}
	llong ans = 0ll;
	for(int k=0; k<=a; k++)
	{
		ans = (ans+dp[cur][k]*comb(a,k)%P*fact[a-k]%P*fact[a-k]%P*comb(a+b,a-k))%P;
//		printf("k%d ans%lld\n",k,ans);
	}
	printf("%lld\n",ans);
	return 0;
}