CF794G Replace All【分析性质,计数】
给定两个 \(\texttt{AB?}\) 字符串 \(c,d\) 和正整数 \(n\),求在所有将 \(\texttt ?\) 替换为 \(\texttt{A/B}\) 的方案中,满足 \(1\le |S|,|T|\le n\),将 \(c,d\) 的 \(\texttt A\) 替换为 \(S\),将 \(\texttt B\) 替换为 \(T\) 使得 \(c=d\) 的 \(\texttt{01}\) 字符串对 \((S,T)\) 的个数之和\(\bmod(10^9+7)\)。
\(|c|,|d|,n\le 3\cdot 10^5\)
手玩一下,发现当 \(c=\texttt{AB}\),\(d=\texttt{BA}\) 时 \(S,T\) 都有长为 \(\gcd(|S|,|T|)\) 的整周期。
结论:当 \(c\ne d\) 时 \(S,T\) 都有长为 \(\gcd(|S|,|T|)\) 的整周期。
证明:考虑对 \(|S|+|T|\) 归纳,显然当 \(|S|=|T|=1\) 时结论成立。否则不妨设 \(|S|\le |T|\),若 \(|S|=|T|\) 则显然 \(S=T\),若 \(|S|<|T|\) 则把 \(c,d\) 的 lcp 次掉之后,不妨设 \(c_1=\texttt A,d_1=\texttt B\),则 \(S\) 是 \(T\) 的前缀,设 \(T=S+T'\),并将 \(c,d\) 中的 \(\texttt B\) 替换为 \(\texttt{AB}\),此时 \(|S|+|T|\) 更小了,且 \(c_2\ne d_2\) 所以 \(c\ne d\),由归纳假设可知 \(S,T'\) 都有长为 \(\gcd(|S|,|T'|)\),则结论显然成立,得证。
所以 \(c,d\) 中的顺序无关紧要,设 \(a\) 是 \(c\) 的 \(\texttt{A}\) 个数减去 \(d\) 的 \(\texttt A\) 个数,\(b\) 是 \(d\) 的 \(\texttt B\) 个数减去 \(a\) 的 \(\texttt B\) 个数。
显然有 \(a\cdot|S|=b\cdot|T|\),配合上结论即为充要条件。
当 \(a=b=0\) 时
后面的和式只与 \(\lfloor n/p\rfloor\) 有关,整除分块即可 \(O(n)\)。
当然有一种特殊情况:\(c=d\) 时答案为 \((\sum_{i=1}^n2^i)^2\)。
此外,只有当 \(ab>0\) 时才有解,答案为 \(\sum_{i=1}^r2^i\),其中 \(r=\lfloor\frac{n\gcd(a,b)}{\max(a,b)}\rfloor\)。
然后要考虑有问号的情况:设 \(x,y\) 表示 \(c,d\) 的 \(\texttt ?\) 数量,\(f_{a,b}\) 表示上述答案的值,则
那么就做完了,时间复杂度 \(O(n+|c|+|d|)\)。
#include<bits/stdc++.h>
using namespace std;
typedef long long LL;
const int N = 3e5+3, M = 26000, mod = 1e9+7;
int ksm(int a, int b){
int r = 1;
for(;b;b >>= 1, a = (LL)a * a % mod)
if(b & 1) r = (LL)r * a % mod;
return r;
}
void qmo(int &x){x += x >> 31 & mod;}
int n, tot, nc, nd, res, ans, a, b, x, y, fac[N<<1], inv[N<<1], pw[N], mu[N], pri[M];
char c[N], d[N];
bool notp[N];
int calc(int n){
int res = 0;
for(int l = 1, r, x;l <= n;l = r+1){
r = n / (x = n/l);
res = (res + ((LL)mu[r]-mu[l-1]+mod)*x%mod*x)%mod;
} return res;
}
int F(int a, int b){
if(!a && !b) return res;
if((LL)a*b <= 0) return 0;
if(a < 0){a = -a; b = -b;}
return pw[n/(max(a,b)/__gcd(a,b))];
}
int main(){
scanf("%s%s%d", c, d, &n);
nc = strlen(c); nd = strlen(d);
for(int i = 0;i < nc;++ i)
switch(c[i]){
case 'A': ++ a; break;
case 'B': -- b; break;
case '?': ++ x;
}
for(int i = 0;i < nd;++ i)
switch(d[i]){
case 'A': -- a; break;
case 'B': ++ b; break;
case '?': ++ y;
}
pw[1] = 2;
for(int i = 2;i <= n;++ i) qmo(pw[i] = (pw[i-1]<<1) - mod);
for(int i = 2;i <= n;++ i) qmo(pw[i] += pw[i-1] - mod);
fac[0] = mu[1] = 1;
for(int i = 1;i <= x+y;++ i) fac[i] = (LL)fac[i-1] * i % mod;
inv[x+y] = ksm(fac[x+y], mod-2);
for(int i = x+y;i;-- i) inv[i-1] = (LL)inv[i] * i % mod;
notp[0] = notp[1] = true;
for(int i = 2;i <= n;++ i){
if(!notp[i]) mu[pri[tot++] = i] = -1;
for(int j = 0;j < tot && i * pri[j] < N;++ j){
notp[i*pri[j]] = true;
if(i % pri[j]) mu[i*pri[j]] = -mu[i];
else break;
}
}
for(int i = 2;i <= n;++ i) qmo(mu[i] += mu[i-1]);
for(int l = 1, r, x;l <= n;l = r+1){
r = n / (x = n/l);
res = (res + ((LL)pw[r]-pw[l-1]+mod)*calc(x)) % mod;
}
for(int i = 0;i <= x+y;++ i)
ans = (ans + (LL)inv[i]*inv[x+y-i]%mod*F(a-y+i,b-x+i))%mod;
ans = (LL)ans * fac[x+y] % mod;
if(nc == nd){
int tmp = 1;
for(int i = 0;i < nc;++ i)
if(c[i] == '?' && d[i] == '?') qmo(tmp = (tmp<<1) - mod);
else if(c[i] != '?' && d[i] != '?' && c[i] != d[i]){tmp = 0; break;}
ans = (ans + ((LL)pw[n]*pw[n]+mod-res)%mod*tmp)%mod;
}
printf("%d\n", ans);
}