题解:「PA 2019」Trzy kule

总方案为 \(2^n\) 种,首先有补集转化,求三个条件都不满足的个数。

为了方便,下文中 \(d(a,b) \gets \sum_{i=1}^n [a_i = b_i]\)\(r_i \gets n - r_i - 1\)

首先,若设 \(s_{i,j} \gets [s_{1,j} = s_{i,j}]\) ,则三元组 \((s_{1, i}, s_{2, i}, s_{3, i})\) 只有 \((1,1,1)\ (1,1,0)\ (1,0,1)\ (1,0,0)\) 四种情况。

设这四种情况出现的个数分别为 \(k_1, k_2, k_3, k_4\) ,位置集合分别为 \(P_1, P_2, P_3, P_4\)

对于情况 \(i\) ,设 \(t_i = \sum_{j\in P_i} [S_j = s_{1,j}]\)

显然有不等式组:

\[\begin{cases} \ t_1 + t_2 + t_3 + t_4 \leq r_1 \\ \ t_1 + t_2 + k_3 - t_3 + k_4 - t_4 \leq r_2 \\ \ t_1 + k_2 - t_2 + t_3 + k_4 - t_3 \leq r_3 \\ \end{cases}\]

暴力枚举 \(t_i\) 可以获得一个 \(\mathcal{O}(n^4)\) 的算法。

在暴力算法中每次统计时,单次答案为 \({k_1 \choose t_1}{k_2 \choose t_2}\times{k_3 \choose t_3}{k_4 \choose t_4}\)

考虑 meet in the middle 的思想优化将原复杂度开根号。

\(F(i,j) = \sum_{t_1} \sum_{t_2} {k_1 \choose t_1}{k_2 \choose t_2}\) ,其中 \(t_1,t_2\) 满足 \(i = t_1 + t_2\) , \(j = t_1 + k_2 - t_2\)

将不等式移项,发现不等式中的 \(t_1+t_2 \leq r'_1\)\(t_1+k_2-t_2 \leq r'_2\)\(r'_1,r'_2\) 为两个量),可知任意一对 \((t_3,t_4)\) 对于答案的贡献即为 \({k_3 \choose t_3}{k_4 \choose t_4}\sum_{i=0}^{r'_1}\sum_{j=0}^{r'_2} F(i,j)\)

接下来对于 \(F\) 做二维前缀和处理,即可在 \(\mathcal{O}(n^2)\) 的时间内统计出答案。

Code(C++):

#include<bits/stdc++.h>
#define forn(i,s,t) for(register int i=(s);i<=(t);++i)
#define form(i,s,t) for(register int i=(s);i>=(t);--i)
using namespace std;
const int N = 1e4+3, Mod = 1e9+7;
template<typename T>inline T Max(T A, T B) {return A>B?A:B;}
template<typename T>inline T Min(T A, T B) {return A<B?A:B;}
inline int q_pow(int p, int k) {
	int Ans = 1;
	while(k) {
		if(k & 1) Ans = 1ll * Ans * p %Mod;
		p = 1ll * p * p %Mod;
		k >>= 1;
	}
	return Ans;
}
int fac[N], inv[N];
inline void Init(int n) {
	fac[0] = 1;
	forn(i,1,n) fac[i] = 1ll * i * fac[i-1] %Mod;
	inv[n] = q_pow(fac[n], Mod - 2);
	form(i,n-1,0) inv[i] = 1ll * inv[i+1] * (i+1ll) %Mod;
}
inline int C(int n, int k) {
	return 1ll * fac[n] * inv[k] %Mod * inv[n - k] %Mod;
}
int n, r[3], k[4], sum[N][N];
char s[3][N];
int main() {
	scanf("%d", &n);
	Init(n);
	forn(i,0,2) {
		scanf("%d", r+i);
		r[i] = n - r[i] - 1;
		if(r[i] < 0) {
			return printf("%d\n", q_pow(2, n)), 0;
		}
		scanf("%s", s[i] + 1);
	}
	forn(i,1,n) {
		if(s[0][i] == s[1][i] && s[1][i] == s[2][i]) k[0] ++ ;
		else if(s[0][i] == s[1][i]) k[1] ++ ;
		else if(s[0][i] == s[2][i]) k[2] ++ ;
		else k[3] ++ ;
	}
	int Tr = 0, Tc = 0;
	forn(t0,0,k[0]) forn(t1,0,k[1]) {
		int tx = t0 + t1;
		if(tx > r[0] || tx > r[1]) continue ;
		int ty = t0 + k[1] - t1;
		if(ty > r[2]) continue ;
		Tr = Max(Tr, tx), Tc = Max(Tc, ty);
		sum[tx][ty] += 1ll * C(k[0], t0) * C(k[1], t1) %Mod;
		sum[tx][ty] %= Mod;
	}
	forn(i,0,Tr) forn(j,0,Tc) {
		i && (sum[i][j] += sum[i-1][j], sum[i][j] %= Mod);
		j && (sum[i][j] += sum[i][j-1], sum[i][j] %= Mod);
		(i && j) && (sum[i][j] += Mod - sum[i-1][j-1], sum[i][j] %= Mod);
	}
	int Ans = 0;
	forn(t2,0,k[2]) forn(t3,0,k[3]) {
		int tx = Min(r[0] - t2 - t3, r[1] - (k[2] + k[3]) + (t2 + t3));
		if(tx < 0) continue ;
		int ty = r[2] - t2 - k[3] + t3;
		if(ty < 0) continue ;
		tx = Min(tx, Tr), ty = Min(ty, Tc);
		Ans += 1ll * sum[tx][ty] * C(k[2], t2) %Mod * C(k[3], t3) %Mod;
		Ans %= Mod;
	}
	Ans = Mod - Ans, Ans += q_pow(2, n), Ans %= Mod;
	printf("%d\n", Ans);
	return 0;
}
posted @ 2021-02-25 16:15  AxDea  阅读(120)  评论(0编辑  收藏  举报