P9576 「TAOI-2」Ciallo~(∠・ω< )⌒★
题意
对字符串 \(s\) 的每个非空区间 \([l, r]\),删去 \(s_l \dots s_r\) 后统计字符串 \(t\) 在剩余串中的出现次数,求所有区间的出现次数之和。
Sol
不难注意到,剩余串中 \(t\) 的出现分为两类:
- 跨过删除位置的出现:剩余前缀的一段非空后缀与剩余后缀的一段非空前缀恰好拼成 \(t\)。
- 完全落在剩余前缀或剩余后缀之中的出现:它本身就是 \(t\) 在 \(s\) 中的一次出现,且与删去的区间不相交。
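第二类是平凡的:对 \(t\) 在 \(s\) 中的每个出现位置 \([p, p + |T| - 1]\),删去的区间只需完全位于其左侧 \([1, p - 1]\) 或右侧 \([p + |T|, n]\),贡献为 \(\frac{(p - 1)p}{2} + \frac{(n - p - |T| + 1)(n - p - |T| + 2)}{2}\)。用哈希判断每个位置是否匹配后即可直接计算。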
对于第一类,预处理 \(f_i\) 表示从 \(s_i\) 开始、与 \(t\) 的前缀相同的最长长度,\(g_j\) 表示以 \(s_j\) 结尾、与 \(t\) 的后缀相同的最长长度(均可用哈希加二分求出)。为避免与第二类重复计数(某一侧单独构成完整的 \(t\)),将两者都与 \(|T| - 1\) 取 \(\min\)。
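设拼接时取自前缀的部分为 \(s_i \dots s_{i+a-1}\),取自后缀的部分为 \(s_{j-|T|+a+1} \dots s_j\),则删去的区间为 \([i + a, j - |T| + a]\)。它非空当且仅当 \(j \ge i + |T|\)。合法的 \(a\) 共有 \(f_i + g_j - |T| + 1\) 个(当该值为正时)。故第一类的答案为:
\[\sum_{i = 1}^{n} \sum_{j = i + |T|}^{n} [f_i + g_j \ge |T|]\,(f_i + g_j - |T| + 1)\]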
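考虑扫描线快速计算:从右向左枚举 \(i\),每一步把 \(j = i + |T|\) 加入两个以 \(g_j + 1\) 为下标的树状数组,分别维护 \(j\) 的个数与 \(g_j\) 之和。条件 \(f_i + g_j \ge |T|\) 等价于 \(g_j \ge |T| - f_i\)。对满足该条件的下标区间查询个数 \(c\) 与和 \(S\),贡献即为 \(c\,(f_i - |T| + 1) + S\)。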
Code
#include <algorithm>
#include <array>
#include <cassert>
#include <cstdio>
#include <iostream>
#include <string>
#define int long long
using namespace std;
// Fast input used on the judge: when ONLINE_JUDGE is defined, getchar() is
// redefined to read from `buf`. Whenever the buffer is exhausted it refills it
// with fread() (up to 1 << 21 bytes at a time) and yields EOF once fread
// returns 0 bytes.
// NOTE(review): ubuf and u are not referenced by the code shown here.
#ifdef ONLINE_JUDGE
#define getchar() (p1 == p2 && (p2 = (p1 = buf) + fread(buf, 1, 1 << 21, stdin), p1 == p2) ? EOF : *p1++)
char buf[1 << 23], *p1 = buf, *p2 = buf, ubuf[1 << 23], *u = ubuf;
#endif
// Reads a decimal integer from getchar().
// Any non-digit characters before the first digit are skipped; if a '-' is
// among them, the result is negated. Returns 0 when input ends before any digit
// (previously the skip loop never terminated on EOF).
int read() {
    int p = 0, flg = 1;
    int c = getchar();  // int, not char, so EOF is representable
    while (c != EOF && (c < '0' || c > '9')) {
        if (c == '-') flg = -1;
        c = getchar();
    }
    while (c >= '0' && c <= '9') {
        p = p * 10 + (c - '0');
        c = getchar();
    }
    return p * flg;
}
// Reads the next maximal run of lowercase letters 'a'..'z' from getchar(),
// skipping any other characters before it. Returns an empty string when input
// ends before a letter is found instead of spinning forever on EOF.
std::string read_() {
    std::string ans;
    int c = getchar();  // int, not char, so EOF is representable
    while (c != EOF && (c < 'a' || c > 'z'))
        c = getchar();
    while (c >= 'a' && c <= 'z') {
        ans += static_cast<char>(c);
        c = getchar();
    }
    return ans;
}
// Writes x in decimal to stdout via putchar, with no trailing separator.
// The magnitude is computed in unsigned arithmetic, so the most negative value
// is printed correctly instead of overflowing on `x = -x`.
void write(int x) {
    unsigned long long mag = x;
    if (x < 0) {
        putchar('-');
        mag = 0ULL - mag;  // |x| modulo 2^64, well-defined for every x
    }
    char digits[24];
    int len = 0;
    do {
        digits[len++] = static_cast<char>('0' + mag % 10);
        mag /= 10;
    } while (mag);
    while (len) putchar(digits[--len]);
}
// N: capacity of the 1-based per-position arrays (prefix hashes, powers, BITs).
// mod: modulus of the polynomial rolling hash (base 131).
const int N = 8e5 + 5, mod = 147744151;
// Polynomial rolling hash with base 131 modulo `mod`. All stored and returned
// hash values are reduced into [0, mod).
namespace Hash {
// Hash of the whole string s; equals query(hs, 1, |s|) when s is stored
// 1-based (after a sentinel) in a prefix array built by getarray.
int gethash(const std::string &s) {
    int ans = 0;
    for (char x : s)
        ans = (ans * 131ll + x) % mod;
    return ans;
}
// idx[k] = 131^k mod `mod`, for 0 <= k < N.
std::array<int, N> idx;
void init() {
    idx[0] = 1;
    for (int i = 1; i < N; i++)
        idx[i] = idx[i - 1] * 131ll % mod;
}
// Fills isl[i] = hash of s[1..i] for 1 <= i < |s|. s[0] is a sentinel that is
// not hashed, and isl[0] = 0 is the empty-prefix hash.
void getarray(std::array<int, N> &isl, const std::string &s) {
    isl[0] = 0;
    for (int i = 1; i < (int)s.size(); i++)
        isl[i] = (isl[i - 1] * 131ll + s[i]) % mod;
}
// Hash of the 1-based inclusive range [l, r] from prefix array hs; an empty
// range (l == r + 1) hashes to 0. For l < 1 the range is out of bounds, so -1
// is returned: it never equals a real hash and avoids reading hs[l - 1] with
// a negative index.
int query(const std::array<int, N> &hs, int l, int r) {
    if (l < 1) return -1;
    // 1ll keeps the product in 64-bit even if `int` is 32-bit.
    return (hs[r] - hs[l - 1] * 1ll * idx[r - l + 1] % mod + mod) % mod;
}
}
namespace Bit1 {
array <int, N> edge;
int lowbit(int x) {
return x & -x;
}
void modify(int x, int y, int n) {
assert(x);
while (x <= n) {
edge[x] += y;
x += lowbit(x);
}
return;
}
int query(int x) {
int ans = 0;
while (x) {
ans += edge[x];
x -= lowbit(x);
}
return ans;
}
}
namespace Bit2 {
array <int, N> edge;
int lowbit(int x) {
return x & -x;
}
void modify(int x, int y, int n) {
while (x <= n) {
edge[x] += y;
x += lowbit(x);
}
return;
}
int query(int x) {
int ans = 0;
while (x) {
ans += edge[x];
x -= lowbit(x);
}
return ans;
}
}
// Per-position match lengths, filled in main:
// - f[i] is the length of the longest prefix of t that starts at s[i].
// - g[i] is the length of the longest suffix of t that ends at s[i].
// main caps both at |t| - 1.
array <int, N> f, g;
// 1-based prefix hashes of s and t, filled by Hash::getarray in main.
array <int, N> hsS, hsT;
signed main() {
string s = " " + read_(), t = " " + read_();
int n = s.size() - 1, m = t.size() - 1;
Hash::init();
Hash::getarray(hsS, s), Hash::getarray(hsT, t);
for (int i = 1; i <= n; i++) {
int l = i, r = n;
int ans = i;
while (l <= r) {
int mid = (l + r) >> 1;
if (Hash::query(hsS, i, mid) ==
Hash::query(hsT, 1, mid - i + 1)) ans = mid + 1, l = mid + 1;
else r = mid - 1;
}
f[i] = ans - i;
l = i - m - 1, r = i;
ans = i;
while (l <= r) {
int mid = (l + r) >> 1;
/* write() */
if (Hash::query(hsS, mid, i) ==
Hash::query(hsT, m - i + mid, m)) ans = mid - 1, r = mid - 1;
else l = mid + 1;
}
g[i] = i - ans;
f[i] = min(f[i], m - 1);
g[i] = min(g[i], m - 1);
}
int ans = 0;
for (int i = 1; i <= n - m + 1; i++)
if (Hash::query(hsS, i, i + m - 1) == Hash::query(hsT, 1, m))
ans += (i - 1) * i / 2 + (n - i - m + 2) * (n - i - m + 1) / 2;
/* write(ans), puts("@"); */
/* for (int i = 1; i <= n; i++) */
/* for (int j = i + m; j <= n; j++) */
/* ans += (f[i] + g[j] >= m) * (f[i] + g[j] - m + 1); */
/* write(ans), puts(""); */
for (int i = n - m; i; i--) {
Bit1::modify(g[i + m] + 1, 1, m + 1);
Bit2::modify(g[i + m] + 1, g[i + m], m + 1);
int tp = m - f[i], len = Bit1::query(m + 1) - Bit1::query(tp);
ans += len * f[i] - len * m + len + Bit2::query(m + 1) - Bit2::query(tp);
/* write(Bit2::query(m + 1) - Bit2::query(tp)), puts(""); */
}
write(ans), puts("");
return 0;
}