[CF794G] Replace All【组合数学】【数论】

哎,我是不是有场 CF 没写题解,咕了咕了(

先考虑对于没有 ? 的情况,即已知了一种将 \(c, d\) 中的 ? 替换为字母的方案后,如何求合法的 \(01\) 串二元组 \((s, t)\) 数量。

\(a_x, b_x\) 分别表示字符串 \(x\)AB 的数量,考虑以下两种情况:

  1. \(a_c< a_d\)\(b_c> b_d\)
  2. \(a_c>a_d\)\(b_c<b_d\)
  3. \(a_c=a_d\)\(b_c=b_d\)

注意到 \(|s|, |t|\geq 1\),因此对于这三种以外的情况,无论 \(|s|, |t|\) 如何安排,都不可能让替换完的 \(01\) 序列长度相等。而前两种情况实际上是对称的,因此这里只考虑第一种。

为了让替换完的 \(01\) 序列长度相等,可以得到 \((a_d-a_c)|s|=(b_c-b_d)|t|\)。设 \(g=\gcd(a_d-a_c, b_c-b_d)\),不难发现,\(|s|\) 的最小值 \(m_s=\frac{b_c-b_d}{g}\),对应的 \(|t|\) 的最小值 \(m_t=\frac{a_d-a_c}{g}\),不妨设 \(m_t\leq m_s\)

由于 \(\gcd(m_s, m_t)=1\),不难发现,任意的 \(x\in [0, m_s)\),都有唯一的 \(y\in [0, m_s)\),使得 \(y\cdot m_t\equiv x\pmod {m_s}\)。换句话说就是,对于一个连续的 \(m_t\)\(s\) 组成的序列,我们不断地从最前面截取长度为 \(m_t\) 的一段,截取下来的恰好是 \(s\) 的每一个循环(即,将 \(s\) 的某个前缀取下来,原封不动地接到 \(s\) 的后面)。不难发现,在这种情况下,\(s, t\) 只有两种方案:全由 \(0\) 组成,或全由 \(1\) 组成。不难发现,如果将这个序列复制正整数遍,再在某些位置插入一些 \(s\)\(t\),只要在截取的时候将这些串单独截取掉,上面的性质仍然成立。

上面一段的最后一句话,实际上表示了,对于任意一种满足情况 \(1\) 的(情况 \(2\) 类似),确定了 \(a_c, a_d, b_c, b_d\)\(c, d\) 串的安排方式,我们只关心 \(a_d-a_c\)\(b_c-b_d\) 的值,而并不关心这些 AB 的具体位置。并且,我们能以此算出对应的 \(s, t\) 的方案数。具体来说,首先 \(|s|, |t|\) 分别是 \(m_s, m_t\) 的倍数,且 \(|s|=k\cdot m_s, |t|=k\cdot m_t\) 时,我们有 \(2^k\) 种方案(因为不互质的部分无法用上面的方式取到,所以这些位置恰好分成了 \(k\) 个独立取值的连通块)。这是一个简单的等比数列求和。

因此,我们可以暴力枚举 \(a_d-a_c\) 的取值。不难发现,由于 ?\(c, d\) 中的数量是确定的,我们可以直接算出 \(b_c-b_d\) 的值。还有一个问题是,可能会有多组 \((a_c, a_d, b_c, b_d)\)。但是感性理解一下可以发现,如果要保持 \(a_d-a_c\)\(b_c-b_d\) 不变,那么每在 \(c\) 中多将一个 ? 替换成 A,就必须在 \(d\) 中少将一个 ? 替换成 B。换句话说,\(a_c+b_d\) 的值是确定的,并且这种情况下,将 ? 替换为字母的方案数是一个只与 \(a_c+b_d\) 相关的组合数(实际上它就是范德蒙德恒等式)。

现在还剩情况 \(3\)。对于情况 \(3\),发现 \(|s|, |t|\) 已经没有了限制,因此我们要对每一个 \((|s|, |t|)\) 求出合法的 \(s, t\) 数量。实际上,与上面的讨论差不多,如果 \(\gcd(|s|, |t|)=1\),可以证明 \(s, t\) 也是要么全 \(0\),要么全 \(1\) 的。拓展一下,可以发现我们要求的就是 \(\sum_{|s|=1}^{n}\sum_{|t|=1}^n 2^{\gcd(|s|, |t|)}\)。这是非常简单的莫反,此处不再赘述。而将 ? 替换为字母的方案数,实际上与上面是相同的。

但是,这里还有最后一个坑点。如果 \(c=d\),也就是说每个位置的 A/B 都相等,那么任意一组长度不超过 \(n\)\((s, t)\) 显然都是满足条件的。因此这一部分要从上面扣除,单独计算。

Code:

#include <bits/stdc++.h>
#define R register
#define mp make_pair
#define ll long long
#define pii pair<int, int>
using namespace std;
const int mod = 1e9 + 7, N = 310000, M = N << 1;

int n, m, k, sa, sb, ta, tb, sc, tc, ispr[N], mu[N];
ll fac[M], inv[M];
char s[N], t[N];
vector<int> prime;

inline int addMod(int a, int b) {
	return (a += b) >= mod ? a - mod : a;
}

inline ll quickpow(ll base, ll pw) {
	ll ret = 1;
	while (pw) {
		if (pw & 1) ret = ret * base % mod;
		base = base * base % mod, pw >>= 1;
	}
	return ret;
}

template <class T>
inline void read(T &x) {
	x = 0;
	char ch = getchar(), w = 0;
	while (!isdigit(ch)) w = (ch == '-'), ch = getchar();
	while (isdigit(ch)) x = (x << 1) + (x << 3) + (ch ^ 48), ch = getchar();
	x = w ? -x : x;
	return;
}

inline void initComb(int n) {
	fac[0] = 1;
	for (R int i = 1; i <= n; ++i) fac[i] = fac[i - 1] * i % mod;
	inv[n] = quickpow(fac[n], mod - 2);
	for (R int i = n; i; --i) inv[i - 1] = inv[i] * i % mod;
}

inline ll comb(int n, int m) {
	if (m < 0 || n < m) return 0;
	return fac[n] * inv[m] % mod * inv[n - m] % mod;
}

int getGcd(int a, int b) {
	return b ? getGcd(b, a % b) : a;
}

inline ll calc(int x) {
	return addMod(quickpow(2, x + 1), mod - 2);
}

inline int sign(int x) {
	return x < 0 ? -1 : x > 0;
}

void initPrime(int n) {
	mu[1] = 1;
	for (R int i = 2, k; i <= n; ++i) {
		if (!ispr[i])
			mu[i] = -1, prime.push_back(i);
		for (auto &j : prime) {
			if ((k = i * j) > n) break;
			ispr[k] = 1;
			if (i % j == 0) break;
			mu[k] = addMod(mod, -mu[i]);
		}
	}
	return;
}

inline ll sq(ll x) {
	return x * x % mod;
}

int main() {
	scanf("%s%s", s + 1, t + 1), read(k);
	n = strlen(s + 1), m = strlen(t + 1);
	if (n < m) swap(s, t), swap(n, m);
	initComb(n + m), initPrime(k);
	for (R int i = 1; i <= n; ++i)
		sa += s[i] == 'A', sb += s[i] == 'B', sc += s[i] == '?';
	for (R int i = 1; i <= m; ++i)
		ta += t[i] == 'A', tb += t[i] == 'B', tc += t[i] == '?';
	ll ans = 0;
	for (R int i = sa - ta - tc; i <= sa - ta + sc; ++i) {
		int j = m - n + i;
		if (i == 0 && j == 0) {
			ll w = 0, pw = 1;
			for (R int d = 1; d <= k; ++d) {
				pw = addMod(pw, pw);
				for (R int u = 1, v = d; v <= k; ++u, v += d)
					w = (w + pw * mu[u] % mod * sq(k / v)) % mod;
			}
			pw = 1;
			for (R int d = 1; pw && d <= n; ++d) {
				if (s[d] == '?' && t[d] == '?')
					pw = addMod(pw, pw);
				else if (s[d] != '?' && t[d] != '?' && s[d] != t[d])
					pw = 0;
			}
			ans = (ans + w * (comb(sc + tc, ta + tc - sa + i) + mod - pw)) % mod;
			ans = (ans + pw * sq(calc(k))) % mod;
		}
		if (sign(i) * sign(j) != 1) continue;
		ans = (ans + calc(k / ((i < 0 ? j : i) / getGcd(i, j))) * comb(sc + tc, ta + tc - sa + i)) % mod;
	}
	cout << ans << endl;
	return 0;
}
posted @ 2020-04-21 23:43  suwakow  阅读(208)  评论(0编辑  收藏  举报
Live2D