LOJ3223 Trzy kule
Trzy kule
对于两个长度为 \(n\) 的 \(01\) 串 \(a_1, a_2, \dots, a_n\) 和 \(b_1, b_2, \dots, b_n\),定义它们的距离 \(d(a, b) = \sum_{i=1}^{n} |a_i - b_i|\)。
给定三个长度为 \(n\) 的 \(01\) 串 \(s_1, s_2, s_3\) 以及三个非负整数 \(r_1, r_2, r_3\),问有多少个长度为 \(n\) 的 \(01\) 串 \(S\) 满足 \(d(S, s_1) \le r_1, d(S, s_2) \le r_2, d(S, s_3) \le r_3\) 这三个不等式中至少有一个成立。
\(1 \le n \le 10000, 0 \le r_i \le n\)
题解
http://jklover.hs-blog.cf/2020/06/09/Loj-3223-Trzy-kule/#more
Meet in the middle + 二维前缀和.
首先可以补集转化一下,用 \(2^n\) 减掉三个不等式都不成立的方案数.
将每个 \(r\) 变为 \(n-1-r\) ,则我们需要求出与每个串相同的字符数都不超过对应的 \(r\) 的方案数.
首先可以将这三个串都异或上第一个串,答案不变,于是一个位置只会有 \(000,001,010,011\) 这四种情况.
记这四种情况分别有\(k_0,k_1,k_2,k_3\)个,再记串 \(S\) 中,这四种位置上 \(0\) 的数目分别为 \(c_0,c_1,c_2,c_3\).
显然要满足限制:
\(c_0+c_1+c_2+c_3\leq r_0\)。
\(c_0+c_1+k_2-c_2+k_3-c_3\leq r_1\)。
\(c_0+k_1-c_1+c_2+k_3-c_3\leq r_3\)。
考虑 Meet in the middle, 先枚举 \(c_0,c_1\) ,再枚举 \(c_2,c_3\) ,询问能凑成合法四元组的 \(c_0,c_1\) 的贡献总和.
看上去有 3 个维度的限制,但是 \(c_0,c_1\) 对应的点中有 2 维是一样的,于是可以缩成 2 维,变成单点加,最后矩形求和.
由于每个维度的坐标都不会超过 \(n\) ,所以直接用二维前缀和处理即可.
时间复杂度\(O(n^2)\)。
CO int N=1e4+10;
int fac[N],ifac[N],lim[3];
char str[3][N];
int cnt[4],iter[4];
int sum[N][N];
IN int C(int n,int m){
return mul(fac[n],mul(ifac[m],ifac[n-m]));
}
int main(){
int n=read<int>();
fac[0]=1;
for(int i=1;i<=n;++i) fac[i]=mul(fac[i-1],i);
ifac[n]=fpow(fac[n],mod-2);
for(int i=n-1;i>=0;--i) ifac[i]=mul(ifac[i+1],i+1);
for(int i=0;i<3;++i){
lim[i]=n-1-read<int>();
if(lim[i]<0){ // r=n
printf("%d\n",fpow(2,n));
return 0;
}
scanf("%s",str[i]+1);
}
fac[0]=1;
for(int i=1;i<=n;++i){
int x=str[0][i]-'0',y=str[1][i]-'0',z=str[2][i]-'0';
y^=x,z^=x;
cnt[y<<1|z]++; // notice
}
int w=0,h=0; // edit 1: TLE
for(iter[0]=0;iter[0]<=cnt[0];++iter[0])for(iter[1]=0;iter[1]<=cnt[1];++iter[1]){
int x=iter[0]+iter[1]; // matching for str[0] and str[1]
if(x>lim[0] or x>lim[1]) continue;
int y=iter[0]+cnt[1]-iter[1]; // matching for str[2]
if(y>lim[2]) continue;
w=max(w,x),h=max(h,y);
sum[x][y]=add(sum[x][y],mul(C(cnt[0],iter[0]),C(cnt[1],iter[1])));
}
for(int i=0;i<=w;++i)for(int j=0;j<=h;++j){
if(i) sum[i][j]=add(sum[i][j],sum[i-1][j]);
if(j) sum[i][j]=add(sum[i][j],sum[i][j-1]);
if(i and j) sum[i][j]=add(sum[i][j],mod-sum[i-1][j-1]);
}
int ans=0;
for(iter[2]=0;iter[2]<=cnt[2];++iter[2])for(iter[3]=0;iter[3]<=cnt[3];++iter[3]){
int x=min(lim[0]-iter[2]-iter[3],lim[1]-(cnt[2]-iter[2])-(cnt[3]-iter[3])); // limit for str[0] and str[1]
if(x<0) continue;
int y=lim[2]-iter[2]-(cnt[3]-iter[3]); // limit for str[2]
if(y<0) continue;
x=min(x,w),y=min(y,h);
ans=add(ans,mul(sum[x][y],mul(C(cnt[2],iter[2]),C(cnt[3],iter[3]))));
}
ans=add(fpow(2,n),mod-ans);
printf("%d\n",ans);
return 0;
}