CodeForces 1286E Fedya the Potter Strikes Back
KMP 好题。
思路
合法的子区间其实就是原串的 \(\mathrm{border}\),考虑维护 \(\mathrm{border}\) 的集合。每次加入一个字符,就保留原来合法的 \(\mathrm{border}\) 并加入新的合法 \(\mathrm{border}\)(如果 \(s_1 = s_i\))。
重点在于如何删掉不合法的 \(\mathrm{border}\)。设 \(fa_{i,j}\) 表示 \([1,pos+1]\) 能被 \([i-pos+1,i]\) 和字符 \(j\) 拼接成的 \(pos\) 最大值,每次可以 \(O(26)\) 处理。枚举 \(i - 1\) 后接上的 \(\ne s_i\) 的字符,暴力跳 \(fa\) 删前缀即可。用一个 map
维护每个 \(w_i\) 出现的次数,则删完之后把 map
中 \(> w_i\) 的权值全部修改为 \(w_i\) 即可。
线段树维护区间 \(w_i\) 最小值,时间复杂度为 \(O(n \log n + 26n)\)。
代码
code
/*
p_b_p_b txdy
AThousandMoon txdy
AThousandSuns txdy
hxy txdy
*/
#include <bits/stdc++.h>
#define pb push_back
#define fst first
#define scd second
using namespace std;
typedef long long ll;
typedef pair<ll, ll> pii;
struct bignum {
const int P = 10000;
int a[10], n;
void add(ll x) {
int pos = 0;
while (x) {
a[pos] += x % P;
x /= P;
++pos;
}
for (int i = 0; i < 9; ++i) {
a[i + 1] += a[i] / P;
a[i] %= P;
}
n = 9;
while (n && !a[n]) {
--n;
}
}
ll mod(ll p) {
ll res = 0;
for (int i = n; ~i; --i) {
res = (res * P + a[i]) % p;
}
return res;
}
void write() {
printf("%d", a[n]);
for (int i = n - 1; ~i; --i) {
printf("%04d", a[i]);
}
putchar('\n');
}
} ans;
const int maxn = 600100;
int n, a[maxn], b[maxn], tree[maxn << 2], fail[maxn];
int fa[maxn][26];
map<ll, ll> mp;
void pushup(int x) {
tree[x] = min(tree[x << 1], tree[x << 1 | 1]);
}
void update(int rt, int l, int r, int x, int y) {
if (l == r) {
tree[rt] = y;
return;
}
int mid = (l + r) >> 1;
if (x <= mid) {
update(rt << 1, l, mid, x, y);
} else {
update(rt << 1 | 1, mid + 1, r, x, y);
}
pushup(rt);
}
int query(int rt, int l, int r, int ql, int qr) {
if (ql <= l && r <= qr) {
return tree[rt];
}
int mid = (l + r) >> 1, res = 2e9;
if (ql <= mid) {
res = min(res, query(rt << 1, l, mid, ql, qr));
}
if (qr > mid) {
res = min(res, query(rt << 1 | 1, mid + 1, r, ql, qr));
}
return res;
}
void solve() {
char op[9];
scanf("%d%s%d", &n, op, &b[1]);
a[1] = op[0] - 'a';
update(1, 1, n, 1, b[1]);
ans.add(b[1]);
ans.write();
ll cur = 0;
for (int i = 2, j = 0; i <= n; ++i) {
scanf("%s%d", op, &b[i]);
a[i] = (op[0] - 'a' + ans.mod(26)) % 26;
b[i] ^= ans.mod(1LL << 30);
update(1, 1, n, i, b[i]);
ans.add(query(1, 1, n, 1, i));
if (a[1] == a[i]) {
cur += b[i];
++mp[b[i]];
}
while (j && a[i] != a[j + 1]) {
j = fail[j];
}
if (a[i] == a[j + 1]) {
++j;
}
fail[i] = j;
for (int k = 0; k < 26; ++k) {
fa[i][k] = fa[fail[i]][k];
}
fa[i][a[fail[i] + 1]] = fail[i];
for (int k = 0; k < 26; ++k) {
if (a[i] == k) {
continue;
}
for (int u = fa[i - 1][k]; u; u = fa[u][k]) {
int mn = query(1, 1, n, i - u, i - 1);
cur -= mn;
--mp[mn];
}
}
ll cnt = 0;
vector<ll> tv;
for (auto it = mp.upper_bound(b[i]); it != mp.end(); ++it) {
tv.pb(it->fst);
cnt += it->scd;
cur -= it->fst * it->scd;
}
for (ll x : tv) {
mp.erase(x);
}
cur += cnt * b[i];
mp[b[i]] += cnt;
ans.add(cur);
ans.write();
}
}
int main() {
int T = 1;
// scanf("%d", &T);
while (T--) {
solve();
}
return 0;
}