[折半分治]luogu P5982 [PA2019]Trzy kule
题面
https://www.luogu.com.cn/problem/P5982
给定一01串,问有多少个01串与他的不同位数满足给定的限制,这种01串总共有三个,满足其中一个限制即可
分析
因为只需要满足一个其中一个限制所以我们考虑计算补集,即三个都不满足
因为是01串,所以当一个01串按位取反(唯一对应)后不同位数是(原来的设为 $i$ 位不同) $n-i+1$ ,所以可以把限制改为 $\leq n-r+1 $便于计算
设三元组 $((s_{0,i}==s_{0,i}),(s_{1,i}==s_{0,i}),(s_{2,i}==s_{0,i}))$ 表示在第 i 位上三个串与第一个串同或的结果
那么该三元组显然只有 $(1,1,0),(1,1,1),(1,0,1),(1,0,0)$ 四种情况,我们设四种情况的总数分别为 $s_1,s_2,s_3,s_4$
那么当我们所枚举的01串的一位与 s0 异位时,就会变成 $(0,1,1),(0,0,0),(0,1,0),(0,1,1)$
设枚举同位三元组个数分别为 $p_1,p_2,p_3,p_4$ ,那么对应的异位三元组个数则为 $s_1-p_1,s_2-p_2,s_3-p_3,s_4-p_4$
所以限制可以变为
$p_1+p_2+p_3+p_4\leq r_0$
$p_1+p_2+s_3-p_3+s_4-p_4\leq r_1$
$s_1-p_1+p_2+p_3+s_4-p_4\leq r_2$
答案则为 $\sum_{p_1=0}^{s_1}\sum_{p_2=0}^{s_2}\sum_{p_3=0}^{s_3}\sum_{p_4=0}^{s_4} \binom{s_1}{p_1} \binom{s_2}{p_2} \binom{s_3}{p_3} \binom{s_4}{p_4}$(满足上述三个限制)
暴力枚举是 $O(n^4)$ 的,显然不行
采用折半的思想,只枚举 $p_1,p_2$
注意到此时 $p_1,p_2$ 需要满足限制
$p_1+p_2\leq min(r_0,r_1)$
$s_1-p_1+p_2\leq r_2$
所以将 $\binom{s_1}{p_1} \binom{s_2}{p_2}$ 记录在 $f[p_1+p_2][s_1-p_1+p_2]$ 中
再枚举 $p_3,p_4$ ,需要满足限制
$p_3+p_4\leq r_0$
$s_3-p_3+s_4-p_4\leq r_1$
$p_3+s_4-p_4\leq r_2$
答案则为 $\sum_{p_3=0}^{s_3}\sum_{p_4=0}^{s_4} \binom{s_3}{p_3} \binom{s_4}{p_4} \sum_{i}^{min(r_0,r_1)}\sum_{j}^{r_2} f[i][j]$ (满足上述限制)
对 $f[i][j]$ 做前缀和预处理即可优化到 $O(n^2)$
代码
#include <iostream> #include <cstdio> #include <cstring> using namespace std; typedef long long ll; const ll P=1e9+7; const int N=1e4+10; int n,r[3],n1,n2; int cnt[2][2]; ll fact[N],inv[N],f[N][N],ans; char s[3][N]; ll C(int n,int m) {return fact[n]*inv[n-m]%P*inv[m]%P;} ll Pow(ll x,ll y) {ll ans=1;for (;y;y>>=1,x=x*x%P) if (y&1) ans=ans*x%P;return ans;} int main() { scanf("%d",&n);ans=Pow(2,n); fact[0]=inv[0]=1;for (int i=1;i<=n;i++) fact[i]=fact[i-1]*i%P; inv[n]=Pow(fact[n],P-2);for (int i=n-1;i;i--) inv[i]=inv[i+1]*(i+1)%P; for (int i=0;i<3;i++) scanf("%d%s",&r[i],s[i]+1),r[i]=n-r[i]-1; for (int i=1;i<=n;i++) cnt[!(s[0][i]-'0')^(s[1][i]-'0')][!(s[0][i]-'0')^(s[2][i]-'0')]++; for (int i=0;i<=cnt[1][0];i++) for (int j=0;j<=cnt[1][1];j++) if (i+j<=min(r[0],r[1])&&cnt[1][0]-i+j<=r[2]) { n1=max(n1,i+j);n2=max(n2,cnt[1][0]-i+j); (f[i+j][cnt[1][0]-i+j]+=C(cnt[1][0],i)*C(cnt[1][1],j)%P)%=P; } for (int i=0;i<=n1;i++) for (int j=0;j<=n2;j++) (f[i][j]+=(i?f[i-1][j]:0)+(j?f[i][j-1]:0)-(i&&j?f[i-1][j-1]:0)+P)%=P; for (int i=0;i<=cnt[0][1];i++) for (int j=0;j<=cnt[0][0];j++) if (0<=min(r[0]-i-j,r[1]-(cnt[0][1]-i+cnt[0][0]-j))&&0<=r[2]-(cnt[0][0]-j+i)) (ans+=(P-f[min(min(r[0]-i-j,r[1]-(cnt[0][1]-i+cnt[0][0]-j)),n1)][min(r[2]-(cnt[0][0]-j+i),n2)]*C(cnt[0][1],i)%P*C(cnt[0][0],j)%P)%P)%=P; printf("%lld\n",ans); } //00+10+01+11<=r1 //10+11+s00-00+s01-01<=r2 //11+01+s00-00+s10-10<=r3