AtCoder Grand Contest 022 E Median Replace
考虑对于一个确定的串怎么判断合法性。
容易发现删到某个时刻若 \(1\) 的个数大于 \(0\) 的个数了,因为我们肯定不会蠢到在不是全 \(1\) 的时候删 \(111\),所以 \(c_1 - c_0\) 在不是全 \(1\) 的时候至少是不会变小的。
所以我们的目标就是让 \(c_1 - c_0\) 尽可能大。
发现删 \(000\) 这个操作看起来很优,它能使 \(c_1 - c_0\) 增加 \(2\)。所以我们只要贪心地删 \(000\) 就可以了……吗?
考虑这个串:\(0010011\),删不了 \(000\),并且 \(c_1 - c_0 < 0\)。但是实际上这个串可以先删中间的 \(010\) 变成 \(00011\),再删 \(000\)。发现问题出在没有删 \(010\)。发现删 \(010\) 之后 \(c_1 - c_0\) 不变,并且它左侧和右侧的 \(0\) 可以拼在一起形成一个新的 \(000\)。感性理解一下这个策略是不劣的。
也就是说有 \(000\) 或者 \(010\) 就删。最后判 \(c_1 - c_0\) 是否 \(> 0\) 即可。
考虑计数。可以考虑设计一个自动机,并且记录当前考虑到的前缀的 \(c_1 - c_0\) 的值。我们有:
注意这里 \(1\) 可以变成 \(\varnothing\) 是因为这个 \(1\) 之后不会参与删除了,所以直接忽略它,并且让 \(c_1 - c_0\) 增加 \(1\)。
还有:
所以自动机状态数为 \(5\)。先把当前在自动机上的结点记入状态。还要记一个当前 \(c_1 - c_0\) 的值。
发现若 \(c_1 - c_0 \ge 3\) 那么把后面的全部删了之后一定满足 \(c_1 - c_0 > 0\)。同时由于我们的策略,不会出现 \(c_1 - c_0 \le -3\) 的情况。所以 \(c_1 - c_0 \in [-2, 2]\),这样就可以记入状态了。
时间复杂度 \(O(n)\)。
code
#include <bits/stdc++.h>
#define pb emplace_back
#define fst first
#define scd second
#define mkp make_pair
#define mems(a, x) memset((a), (x), sizeof(a))
using namespace std;
typedef long long ll;
typedef double db;
typedef unsigned long long ull;
typedef long double ldb;
typedef pair<ll, ll> pii;
const int maxn = 300100;
const int mod = 1000000007;
const int ch[5][2] = {1, 0, 2, 3, 1, 4, 1, 0, 2, 0}, g[5][2] = {-1, 1, -1, 1, 1, 1, -1, 1, -1, 1};
int n, f[maxn][5][5], pw[maxn], a[maxn];
char s[maxn];
inline void upd(int &x, int y) {
x = (x + y < mod ? x + y : x + y - mod);
}
void solve() {
scanf("%s", s + 1);
n = strlen(s + 1);
pw[0] = 1;
for (int i = 1; i <= n; ++i) {
pw[i] = pw[i - 1] * 2 % mod;
}
for (int i = n; i; --i) {
a[i] = a[i + 1] + (s[i] == '?');
}
f[0][2][0] = 1;
int ans = 0;
for (int i = 1; i <= n; ++i) {
for (int j = 0; j <= 4; ++j) {
for (int u = 0; u < 5; ++u) {
if (!f[i - 1][j][u]) {
continue;
}
for (int o = 0; o <= 1; ++o) {
if (s[i] == '0' + (o ^ 1)) {
continue;
}
int v = ch[u][o], t = g[u][o];
if (j + t == 5) {
ans = (ans + 1LL * f[i - 1][j][u] * pw[a[i + 1]]) % mod;
} else {
upd(f[i][j + t][v], f[i - 1][j][u]);
}
}
}
}
}
for (int i = 2; i <= 4; ++i) {
for (int j = 0; j < 5; ++j) {
upd(ans, f[n][i][j]);
}
}
printf("%d\n", ans);
}
int main() {
int T = 1;
// scanf("%d", &T);
while (T--) {
solve();
}
return 0;
}