AtCoder Regular Contest 110 E Shorten ABC
考虑把 \(\text{A}\) 看成 \(1\),\(\text{B}\) 看成 \(2\),\(\text{C}\) 看成 \(3\),那么一次操作相当于选择一个 \(a_i \ne a_{i+1}\) 的 \(i\),将 \(a_i\) 和 \(a_{i+1}\) 替换成一个数 \(a_i \oplus a_{i+1}\)。
那么题目相当于把 \(a\) 划分成若干段,满足每段的异或和不为 \(0\) 且不是同一种字符或者长度为 \(1\)。将每段的异或和排成一个新数组 \(b\),对所有本质不同的 \(b\) 计数。
我们反过来观察对于一个固定的 \(b\),它有没有可能被形成。
考虑贪心,每次跳到下一个区间异或和为 \(b_i\) 的位置。并且要求最后剩下的一段异或和为 \(0\)。
设 \(nxt_{i,j}\) 为最小的 \(k\) 满足 \(a_i \oplus a_{i+1} \oplus \cdots \oplus a_k = j\),这个是可以线性预处理的。
那么设 \(f_i\) 为将前 \(i\) 个字符划分为若干段,并且每一段都是所有异或和相同的段中右端点最靠左的段。
转移是 \(f_{nxt_{i+1,j}} \gets f_i\)。
注意后面还有一段异或和为 \(0\),因此答案不仅仅是 \(f_n\)。
注意特判 \(a_i\) 都相等的情况,此时前面的 \(f_i\) 传递不到 \(f_n\),答案为 \(1\)。
code
// Problem: E - Shorten ABC
// Contest: AtCoder - AtCoder Regular Contest 110(Sponsored by KAJIMA CORPORATION)
// URL: https://atcoder.jp/contests/arc110/tasks/arc110_e
// Memory Limit: 1024 MB
// Time Limit: 2000 ms
//
// Powered by CP Editor (https://cpeditor.org)
#include <bits/stdc++.h>
#define pb emplace_back
#define fst first
#define scd second
#define mems(a, x) memset((a), (x), sizeof(a))
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef long double ldb;
typedef pair<ll, ll> pii;
const int maxn = 1000100;
const ll mod = 1000000007;
ll n, f[maxn], nxt[maxn][4], a[maxn];
char s[maxn];
void solve() {
scanf("%lld%s", &n, s + 1);
for (int i = 1; i <= n; ++i) {
if (s[i] == 'A') {
a[i] = 1;
} else if (s[i] == 'B') {
a[i] = 2;
} else {
a[i] = 3;
}
}
bool flag = 1;
for (int i = 2; i <= n; ++i) {
flag &= (a[i] == a[1]);
}
if (flag) {
puts("1");
return;
}
nxt[n + 1][0] = nxt[n + 1][1] = nxt[n + 1][2] = nxt[n + 1][3] = n + 1;
for (int i = n; i; --i) {
for (int j = 1; j <= 3; ++j) {
if (a[i] == j) {
nxt[i][j] = i;
} else {
nxt[i][j] = nxt[i + 1][j ^ a[i]];
}
}
}
f[0] = 1;
for (int i = 1; i <= n; ++i) {
for (int j = 1; j <= 3; ++j) {
f[nxt[i][j]] = (f[nxt[i][j]] + f[i - 1]) % mod;
}
}
ll x = 0, ans = 0;
for (int i = n; i; --i) {
if (x == 0) {
ans = (ans + f[i]) % mod;
}
x ^= a[i];
}
printf("%lld\n", ans);
}
int main() {
int T = 1;
// scanf("%d", &T);
while (T--) {
solve();
}
return 0;
}