题解:「PA 2019」Trzy kule
总方案为 \(2^n\) 种,首先有补集转化,求三个条件都不满足的个数。
为了方便,下文中 \(d(a,b) \gets \sum_{i=1}^n [a_i = b_i]\) , \(r_i \gets n - r_i - 1\) 。
首先,若设 \(s_{i,j} \gets [s_{1,j} = s_{i,j}]\) ,则三元组 \((s_{1, i}, s_{2, i}, s_{3, i})\) 只有 \((1,1,1)\ (1,1,0)\ (1,0,1)\ (1,0,0)\) 四种情况。
设这四种情况出现的个数分别为 \(k_1, k_2, k_3, k_4\) ,位置集合分别为 \(P_1, P_2, P_3, P_4\)。
对于情况 \(i\) ,设 \(t_i = \sum_{j\in P_i} [S_j = s_{1,j}]\) 。
显然有不等式组:
暴力枚举 \(t_i\) 可以获得一个 \(\mathcal{O}(n^4)\) 的算法。
在暴力算法中每次统计时,单次答案为 \({k_1 \choose t_1}{k_2 \choose t_2}\times{k_3 \choose t_3}{k_4 \choose t_4}\) 。
考虑 meet in the middle 的思想优化将原复杂度开根号。
设 \(F(i,j) = \sum_{t_1} \sum_{t_2} {k_1 \choose t_1}{k_2 \choose t_2}\) ,其中 \(t_1,t_2\) 满足 \(i = t_1 + t_2\) , \(j = t_1 + k_2 - t_2\) 。
将不等式移项,发现不等式中的 \(t_1+t_2 \leq r'_1\) , \(t_1+k_2-t_2 \leq r'_2\) ( \(r'_1,r'_2\) 为两个量),可知任意一对 \((t_3,t_4)\) 对于答案的贡献即为 \({k_3 \choose t_3}{k_4 \choose t_4}\sum_{i=0}^{r'_1}\sum_{j=0}^{r'_2} F(i,j)\) 。
接下来对于 \(F\) 做二维前缀和处理,即可在 \(\mathcal{O}(n^2)\) 的时间内统计出答案。
Code(C++):
#include<bits/stdc++.h>
#define forn(i,s,t) for(register int i=(s);i<=(t);++i)
#define form(i,s,t) for(register int i=(s);i>=(t);--i)
using namespace std;
const int N = 1e4+3, Mod = 1e9+7;
template<typename T>inline T Max(T A, T B) {return A>B?A:B;}
template<typename T>inline T Min(T A, T B) {return A<B?A:B;}
inline int q_pow(int p, int k) {
int Ans = 1;
while(k) {
if(k & 1) Ans = 1ll * Ans * p %Mod;
p = 1ll * p * p %Mod;
k >>= 1;
}
return Ans;
}
int fac[N], inv[N];
inline void Init(int n) {
fac[0] = 1;
forn(i,1,n) fac[i] = 1ll * i * fac[i-1] %Mod;
inv[n] = q_pow(fac[n], Mod - 2);
form(i,n-1,0) inv[i] = 1ll * inv[i+1] * (i+1ll) %Mod;
}
inline int C(int n, int k) {
return 1ll * fac[n] * inv[k] %Mod * inv[n - k] %Mod;
}
int n, r[3], k[4], sum[N][N];
char s[3][N];
int main() {
scanf("%d", &n);
Init(n);
forn(i,0,2) {
scanf("%d", r+i);
r[i] = n - r[i] - 1;
if(r[i] < 0) {
return printf("%d\n", q_pow(2, n)), 0;
}
scanf("%s", s[i] + 1);
}
forn(i,1,n) {
if(s[0][i] == s[1][i] && s[1][i] == s[2][i]) k[0] ++ ;
else if(s[0][i] == s[1][i]) k[1] ++ ;
else if(s[0][i] == s[2][i]) k[2] ++ ;
else k[3] ++ ;
}
int Tr = 0, Tc = 0;
forn(t0,0,k[0]) forn(t1,0,k[1]) {
int tx = t0 + t1;
if(tx > r[0] || tx > r[1]) continue ;
int ty = t0 + k[1] - t1;
if(ty > r[2]) continue ;
Tr = Max(Tr, tx), Tc = Max(Tc, ty);
sum[tx][ty] += 1ll * C(k[0], t0) * C(k[1], t1) %Mod;
sum[tx][ty] %= Mod;
}
forn(i,0,Tr) forn(j,0,Tc) {
i && (sum[i][j] += sum[i-1][j], sum[i][j] %= Mod);
j && (sum[i][j] += sum[i][j-1], sum[i][j] %= Mod);
(i && j) && (sum[i][j] += Mod - sum[i-1][j-1], sum[i][j] %= Mod);
}
int Ans = 0;
forn(t2,0,k[2]) forn(t3,0,k[3]) {
int tx = Min(r[0] - t2 - t3, r[1] - (k[2] + k[3]) + (t2 + t3));
if(tx < 0) continue ;
int ty = r[2] - t2 - k[3] + t3;
if(ty < 0) continue ;
tx = Min(tx, Tr), ty = Min(ty, Tc);
Ans += 1ll * sum[tx][ty] * C(k[2], t2) %Mod * C(k[3], t3) %Mod;
Ans %= Mod;
}
Ans = Mod - Ans, Ans += q_pow(2, n), Ans %= Mod;
printf("%d\n", Ans);
return 0;
}