Luogu P3181 【[HAOI2016]找相同字符】
Description
Solution
这题就是让求两个串的相同子串的个数。
众所周知,字符串所有的子串就是字符串所有的后缀的所有前缀。
利用这个性质我们可以将问题转化,变成求两个字符串的所有后缀的\(lcp\)的长度的和。
求后缀的\(lcp\)我们可以使用\(SA\)。
将两个字符串连接起来,中间位置放一个\(1\)隔开,两个字符串分别染上不同的颜色。
现在答案就变成了求每一堆中颜色不同的后缀的\(lcp\)长度,考虑\(height\)数组的性质,即\(lcp(i, j) = min_{k = i + 1}^j height_k\),我们发现在每个块内\(lcp\)的长度具有单调性。
考虑每一堆中不同颜色之间的\(lcp\)长度就相当于先对于一种颜色算在它之前出现的另外一种颜色和它的\(lcp\),再按照另外一种颜色算一遍。
我们用单调栈维护这个\(lcp\)长度,用\(sum\)记录单调栈内元素对答案的总贡献,\(stack\)中保存两个信息——\(height\)值的大小和有多少个贡献和它相同的元素。
每次我们将当前的后缀\(i\)加入单调栈,如果\(sa_{i - 1}\)的颜色不是我们要算贡献的颜色,那么它就可以对当前颜色造成贡献,\(sum += height_i\)。
这时我们将现在的\(height_i\)入栈,考虑如果\(height_i <= height_{top}\),那么栈顶对答案的贡献就会变为\(height_i\),所以可以将\(height_i\)和\(height_{top}\)合并,对答案的贡献是\(height_i\),贡献个数是\(num_i + num_{top}\)。
合并完之后如果\(sa_i\)的颜色是我们要算贡献的颜色就加上所有栈内的贡献。
Code
#include <iostream>
#include <cstdio>
#include <cstring>
using namespace std;
#define int long long
typedef long long ll;
const int N = 400050;
int k, m, n, s[N], x[N], y[N], cnt[N], sa[N], rnk[N], height[N], col[N], l, stack[N][2], top;
char a[N], b[N];
ll ans, sum;
void Rsort()
{
for (int i = 0; i <= m; i++) cnt[i] = 0;
for (int i = 1; i <= n; i++) cnt[x[i]]++;
for (int i = 1; i <= m; i++) cnt[i] += cnt[i - 1];
for (int i = n; i >= 1; i--) sa[cnt[x[y[i]]]--] = y[i];
return;
}
void SA()
{
m = 128;
for (int i = 1; i <= n; i++) x[i] = s[i], y[i] = i;
Rsort();
m = 0;
for (int k = 1; m < n; k <<= 1)
{
m = 0;
for (int i = n - k + 1; i <= n; i++) y[++m] = i;
for (int i = 1; i <= n; i++) if (sa[i] > k) y[++m] = sa[i] - k;
Rsort();
for (int i = 1; i <= n; i++) y[i] = x[i];
m = x[sa[1]] = 1;
for (int i = 2; i <= n; i++)
{
if(y[sa[i]] == y[sa[i - 1]] && y[sa[i] + k] == y[sa[i - 1] + k])
x[sa[i]] = m;
else x[sa[i]] = ++m;
}
}
for (int i = 1; i <= n; i++) rnk[sa[i]] = i;
for (int i = 1; i <= n; i++)
{
height[rnk[i]] = height[rnk[i - 1]] - 1;
if (height[rnk[i]] < 0) height[rnk[i]] = 0;
while (s[i + height[rnk[i]]] == s[sa[rnk[i] - 1] + height[rnk[i]]]) height[rnk[i]]++;
}
}
signed main()
{
n = 0; ans = 0;
scanf("%s", a + 1);
l = strlen(a + 1);
for (int i = 1; i <= l; i++) s[++n] = a[i] + 1, col[i] = 1;
s[++n] = 1;
scanf("%s", b + 1);
l = strlen(b + 1);
for (int i = 1; i <= l; i++) s[++n] = b[i] + 1, col[n] = 2;
SA();
for (int color = 1; color <= 2; color++)
{
top = sum = 0;
for (int i = 1; i <= n; i++)
if (height[i] == 0) sum = top = 0;
else
{
int cnt = 0;
if (col[sa[i - 1]] != color)
{
cnt++;
sum += height[i];
}
while(top && height[i] <= stack[top][1])
{
cnt += stack[top][2];
sum -= 1LL * stack[top][2] * (stack[top][1] - height[i]);
top--;
}
stack[++top][1] = height[i]; stack[top][2] = cnt;
if (col[sa[i]] == color) ans += sum;
}
}
printf("%lld\n", ans);
return 0;
}