[CF794G] Replace All【组合数学】【数论】
哎,我是不是有场 CF 没写题解,咕了咕了(
先考虑对于没有 ?
的情况,即已知了一种将 \(c, d\) 中的 ?
替换为字母的方案后,如何求合法的 \(01\) 串二元组 \((s, t)\) 数量。
设 \(a_x, b_x\) 分别表示字符串 \(x\) 中 A
,B
的数量,考虑以下两种情况:
- \(a_c< a_d\) 且 \(b_c> b_d\)。
- \(a_c>a_d\) 且 \(b_c<b_d\)。
- \(a_c=a_d\) 且 \(b_c=b_d\)。
注意到 \(|s|, |t|\geq 1\),因此对于这三种以外的情况,无论 \(|s|, |t|\) 如何安排,都不可能让替换完的 \(01\) 序列长度相等。而前两种情况实际上是对称的,因此这里只考虑第一种。
为了让替换完的 \(01\) 序列长度相等,可以得到 \((a_d-a_c)|s|=(b_c-b_d)|t|\)。设 \(g=\gcd(a_d-a_c, b_c-b_d)\),不难发现,\(|s|\) 的最小值 \(m_s=\frac{b_c-b_d}{g}\),对应的 \(|t|\) 的最小值 \(m_t=\frac{a_d-a_c}{g}\),不妨设 \(m_t\leq m_s\)。
由于 \(\gcd(m_s, m_t)=1\),不难发现,任意的 \(x\in [0, m_s)\),都有唯一的 \(y\in [0, m_s)\),使得 \(y\cdot m_t\equiv x\pmod {m_s}\)。换句话说就是,对于一个连续的 \(m_t\) 个 \(s\) 组成的序列,我们不断地从最前面截取长度为 \(m_t\) 的一段,截取下来的恰好是 \(s\) 的每一个循环(即,将 \(s\) 的某个前缀取下来,原封不动地接到 \(s\) 的后面)。不难发现,在这种情况下,\(s, t\) 只有两种方案:全由 \(0\) 组成,或全由 \(1\) 组成。不难发现,如果将这个序列复制正整数遍,再在某些位置插入一些 \(s\) 或 \(t\),只要在截取的时候将这些串单独截取掉,上面的性质仍然成立。
上面一段的最后一句话,实际上表示了,对于任意一种满足情况 \(1\) 的(情况 \(2\) 类似),确定了 \(a_c, a_d, b_c, b_d\) 的 \(c, d\) 串的安排方式,我们只关心 \(a_d-a_c\) 和 \(b_c-b_d\) 的值,而并不关心这些 A
,B
的具体位置。并且,我们能以此算出对应的 \(s, t\) 的方案数。具体来说,首先 \(|s|, |t|\) 分别是 \(m_s, m_t\) 的倍数,且 \(|s|=k\cdot m_s, |t|=k\cdot m_t\) 时,我们有 \(2^k\) 种方案(因为不互质的部分无法用上面的方式取到,所以这些位置恰好分成了 \(k\) 个独立取值的连通块)。这是一个简单的等比数列求和。
因此,我们可以暴力枚举 \(a_d-a_c\) 的取值。不难发现,由于 ?
在 \(c, d\) 中的数量是确定的,我们可以直接算出 \(b_c-b_d\) 的值。还有一个问题是,可能会有多组 \((a_c, a_d, b_c, b_d)\)。但是感性理解一下可以发现,如果要保持 \(a_d-a_c\) 和 \(b_c-b_d\) 不变,那么每在 \(c\) 中多将一个 ?
替换成 A
,就必须在 \(d\) 中少将一个 ?
替换成 B
。换句话说,\(a_c+b_d\) 的值是确定的,并且这种情况下,将 ?
替换为字母的方案数是一个只与 \(a_c+b_d\) 相关的组合数(实际上它就是范德蒙德恒等式)。
现在还剩情况 \(3\)。对于情况 \(3\),发现 \(|s|, |t|\) 已经没有了限制,因此我们要对每一个 \((|s|, |t|)\) 求出合法的 \(s, t\) 数量。实际上,与上面的讨论差不多,如果 \(\gcd(|s|, |t|)=1\),可以证明 \(s, t\) 也是要么全 \(0\),要么全 \(1\) 的。拓展一下,可以发现我们要求的就是 \(\sum_{|s|=1}^{n}\sum_{|t|=1}^n 2^{\gcd(|s|, |t|)}\)。这是非常简单的莫反,此处不再赘述。而将 ?
替换为字母的方案数,实际上与上面是相同的。
但是,这里还有最后一个坑点。如果 \(c=d\),也就是说每个位置的 A/B
都相等,那么任意一组长度不超过 \(n\) 的 \((s, t)\) 显然都是满足条件的。因此这一部分要从上面扣除,单独计算。
Code:
#include <bits/stdc++.h>
#define R register
#define mp make_pair
#define ll long long
#define pii pair<int, int>
using namespace std;
const int mod = 1e9 + 7, N = 310000, M = N << 1;
int n, m, k, sa, sb, ta, tb, sc, tc, ispr[N], mu[N];
ll fac[M], inv[M];
char s[N], t[N];
vector<int> prime;
inline int addMod(int a, int b) {
return (a += b) >= mod ? a - mod : a;
}
inline ll quickpow(ll base, ll pw) {
ll ret = 1;
while (pw) {
if (pw & 1) ret = ret * base % mod;
base = base * base % mod, pw >>= 1;
}
return ret;
}
template <class T>
inline void read(T &x) {
x = 0;
char ch = getchar(), w = 0;
while (!isdigit(ch)) w = (ch == '-'), ch = getchar();
while (isdigit(ch)) x = (x << 1) + (x << 3) + (ch ^ 48), ch = getchar();
x = w ? -x : x;
return;
}
inline void initComb(int n) {
fac[0] = 1;
for (R int i = 1; i <= n; ++i) fac[i] = fac[i - 1] * i % mod;
inv[n] = quickpow(fac[n], mod - 2);
for (R int i = n; i; --i) inv[i - 1] = inv[i] * i % mod;
}
inline ll comb(int n, int m) {
if (m < 0 || n < m) return 0;
return fac[n] * inv[m] % mod * inv[n - m] % mod;
}
int getGcd(int a, int b) {
return b ? getGcd(b, a % b) : a;
}
inline ll calc(int x) {
return addMod(quickpow(2, x + 1), mod - 2);
}
inline int sign(int x) {
return x < 0 ? -1 : x > 0;
}
void initPrime(int n) {
mu[1] = 1;
for (R int i = 2, k; i <= n; ++i) {
if (!ispr[i])
mu[i] = -1, prime.push_back(i);
for (auto &j : prime) {
if ((k = i * j) > n) break;
ispr[k] = 1;
if (i % j == 0) break;
mu[k] = addMod(mod, -mu[i]);
}
}
return;
}
inline ll sq(ll x) {
return x * x % mod;
}
int main() {
scanf("%s%s", s + 1, t + 1), read(k);
n = strlen(s + 1), m = strlen(t + 1);
if (n < m) swap(s, t), swap(n, m);
initComb(n + m), initPrime(k);
for (R int i = 1; i <= n; ++i)
sa += s[i] == 'A', sb += s[i] == 'B', sc += s[i] == '?';
for (R int i = 1; i <= m; ++i)
ta += t[i] == 'A', tb += t[i] == 'B', tc += t[i] == '?';
ll ans = 0;
for (R int i = sa - ta - tc; i <= sa - ta + sc; ++i) {
int j = m - n + i;
if (i == 0 && j == 0) {
ll w = 0, pw = 1;
for (R int d = 1; d <= k; ++d) {
pw = addMod(pw, pw);
for (R int u = 1, v = d; v <= k; ++u, v += d)
w = (w + pw * mu[u] % mod * sq(k / v)) % mod;
}
pw = 1;
for (R int d = 1; pw && d <= n; ++d) {
if (s[d] == '?' && t[d] == '?')
pw = addMod(pw, pw);
else if (s[d] != '?' && t[d] != '?' && s[d] != t[d])
pw = 0;
}
ans = (ans + w * (comb(sc + tc, ta + tc - sa + i) + mod - pw)) % mod;
ans = (ans + pw * sq(calc(k))) % mod;
}
if (sign(i) * sign(j) != 1) continue;
ans = (ans + calc(k / ((i < 0 ? j : i) / getGcd(i, j))) * comb(sc + tc, ta + tc - sa + i)) % mod;
}
cout << ans << endl;
return 0;
}