题解 CF1450H1 【Multithreading (Easy Version)】
首先假设没有 \(?\) 。
首先蒟蒻先给出一个很容易想到的贪心:
每一次选择相邻的两个颜色相同的位置,然后把他们配对
如果配对不上,那么现在的状态一定就是 \(wbwbwbwb...\) ,需要有 \(\frac{cnt}{2}\) 个交点。\(cnt\) 为 \(b\) 的个数。
于是整个问题的答案就是 \(|\frac{cnta - cntb}{2}|\) ,\(cnta\) 是 \(b\) 在下标为奇数位置出现的次数,\(cntb\) 是 \(b\) 在下标为偶数位置出现的次数。
如何证明这个是最小的?
首先可以让整张图同色的相交的线段变成不相交,答案严格不增。
考虑数学归纳法:
-
一张空图,或一张只有白色点的图,答案 \(= 0 \ge |\frac{cnta - cntb}{2}|\)
-
非空图,首先找到两个相连的黑点,且在他们在黑点中的位置相邻(一定可以找到):
-
两个点一个在下标奇数处一个在下标偶数处:那么两个点一边都是白点,可以自己匹配,另一边的答案 \(\ge |\frac{(cnta - 1) - (cntb - 1)}{2}| = |\frac{cnta - cntb}{2}|\) 。于是成立。
-
两个点都在下标奇数处或都在下标偶数处:那么两个点一边都是白点但是会多出一个点,那么这个白点和其对应点会和现在找到的这条相连的边产生一个交点。答案(不妨选中的是两个下标奇数的点) \(\ge 1 + |\frac{(cnta - 2) - cntb}{2}| = 1 + |\frac{cnta - cntb}{2} - 1| \ge |\frac{cnta - cntb}{2} |\)
-
证毕。
统计现在已经填了的 \(b\) 奇数下标减偶数下标为 \(cnta\),\(?\) 在奇数下标的个数为 \(cntb\),\(?\) 在偶数下标的个数为 \(cntc\)。
直接枚举 ?
中 b
的个数即可得到下面柿子:
这个东西就是可以卷积的了。
#include<bits/stdc++.h>
#define L(i, j, k) for(int i = j, i##E = k; i <= i##E; i++)
#define R(i, j, k) for(int i = j, i##E = k; i >= i##E; i--)
#define ll long long
#define ull unsigned long long
#define db long double
#define pii pair<int, int>
#define mkp make_pair
using namespace std;
const int N = 6e5 + 7;
const int mod = 998244353;
const int G = 3;
const int iG = (mod + 1) / G;
int qpow(int x, int y) {
int res = 1;
for(; y; x = 1ll * x * x % mod, y >>= 1) if(y & 1) res = 1ll * res * x % mod;
return res;
}
int ny(int x) { return qpow(x, mod - 2); }
int pp[N];
void NTT(int *f, int len, int flag) {
L(i, 0, len - 1) if(pp[i] < i) swap(f[pp[i]], f[i]);
for(int i = 2; i <= len; i <<= 1)
for(int j = 0, l = (i >> 1), ch = qpow(flag == 1 ? G : iG, (mod - 1) / i); j < len; j += i)
for(int k = j, now = 1; k < j + l; k ++) {
int pa = f[k], pb = 1ll * f[k + l] * now % mod;
f[k] = (pa + pb) % mod, f[k + l] = (pa - pb + mod) % mod, now = 1ll * now * ch % mod;
}
if(flag == -1) {
int nylen = ny(len);
L(i, 0, len - 1) f[i] = 1ll * f[i] * nylen % mod;
}
}
void George1123(int *f, int *g, int *ans, int lena, int lenb) { // Tornado king
int lim; for(lim = 1; lim <= lena + lenb; lim <<= 1);
L(i, 0, lim - 1) pp[i] = ((pp[i >> 1] >> 1) | ((i & 1) * (lim >> 1)));
NTT(f, lim, 1), NTT(g, lim, 1);
L(i, 0, lim - 1) ans[i] = 1ll * f[i] * g[i] % mod;
NTT(ans, lim, -1);
}
int n, m, cnta, cntb, cntc, jc[N], njc[N], ans, f[N], g[N], h[N];
int C(int x, int y) { return 1ll * jc[x] * njc[y] % mod * njc[x - y] % mod; }
char s[N];
int main() {
scanf("%d%d", &n, &m);
scanf("%s", s + 1);
jc[0] = njc[0] = 1;
L(i, 1, n) jc[i] = 1ll * jc[i - 1] * i % mod, njc[i] = ny(jc[i]);
L(i, 1, n) {
if(s[i] == 'b') {
if(i % 2) cnta ++;
else cnta --;
}
else if(s[i] == '?') {
if(i % 2) cntb ++;
else cntc ++;
}
}
L(i, 0, cntb) f[i] = C(cntb, i);
L(i, 0, cntc) g[i] = C(cntc, i);
George1123(f, g, h, cntb, cntc);
L(i, 0, cntb + cntc) if((cnta + i - cntc) % 2 == 0) (ans += 1ll * abs(cnta + i - cntc) / 2 * h[i] % mod) %= mod;
ans = 1ll * ans * qpow(ny(2), cntb + cntc - 1) % mod;
printf("%d\n", ans);
return 0;
}
现在已经可以通过 H1
了,但是给一种更好写,为 H2
做基础的方法:
考虑如何优化得到的卷积式 \(f_t = \sum\limits_{i = 0}^{a} C_{a}^i \sum\limits_{j = 0}^{b}C_{b}^{j} [i + j = t]\)
观察规律发现 \(f_t = C_{a+b}^t\) ,如何证明呢?
考虑其组合意义。\(C_{a+b}^t\) 相当于从 \(a + b\) 个球中取出 \(t\) 个,那么我们可以枚举从前 \(a\) 个取出的个数和前 \(b\) 个取出的个数,也就是 \(\sum\limits_{i = 0}^{a} C_{a}^i \sum\limits_{j = 0}^{b}C_{b}^{j} [i + j = t]\)。
(其实就是范德蒙德卷积)
于是答案就是 \(\frac{1}{2^{cntb + cntc}} \sum\limits_{i = 0, i \equiv cnta - cntc \pmod 2}^{cnta + cntb} |cnta - cntc + i| C_{cntb + cntc}^{i}\)
这个式子看起来很不舒服,让 \(cnt = cntb + cntc, t = cntc - cnta\)
\(ans = \frac{1}{2^{cnt}} \sum\limits_{i = 0, i \equiv t \pmod 2}^{cnt} |i - t| C_{cnt}^{i}\)
所以我们得到了更简洁的代码:
#include<bits/stdc++.h>
#define L(i, j, k) for(int i = j, i##E = k; i <= i##E; i++)
#define R(i, j, k) for(int i = j, i##E = k; i >= i##E; i--)
#define ll long long
#define ull unsigned long long
#define db long double
#define pii pair<int, int>
#define mkp make_pair
using namespace std;
const int N = 2e5 + 7;
const int mod = 998244353;
int qpow(int x, int y) {
int res = 1;
for(; y; x = 1ll * x * x % mod, y >>= 1) if(y & 1) res = 1ll * res * x % mod;
return res;
}
int ny(int x) { return qpow(x, mod - 2); }
int n, m, t, cnt, jc[N], njc[N], ans;
int C(int x, int y) { return 1ll * jc[x] * njc[y] % mod * njc[x - y] % mod; }
char s[N];
int main() {
scanf("%d%d", &n, &m);
scanf("%s", s + 1);
jc[0] = njc[0] = 1;
L(i, 1, n) jc[i] = 1ll * jc[i - 1] * i % mod, njc[i] = ny(jc[i]);
L(i, 1, n) {
if(s[i] == 'b') {
if(i % 2) t --;
else t ++;
}
else if(s[i] == '?') {
cnt ++;
if(i % 2 == 0) t ++;
}
}
L(i, 0, cnt) if((t + i) % 2 == 0) (ans += 1ll * abs(t - i) * C(cnt, i) % mod) %= mod;
ans = 1ll * ans * qpow(ny(2), cnt) % mod;
printf("%d\n", ans);
return 0;
}