POJ3415 Common Substrings 后缀数组 + 单调栈
题目大意:
给出两个字符串S和T,求出两个字符串之间有多少长度大于K的公共子区间。
由于每一个子串都是包含在某一个后缀的前缀里面,求出sa和height了之后,我们可以将height进行分组,height < k为分割线,这样一来每个组内都是height >= k的后缀。我们知道长度不小于k的公共子串是两个后缀的前缀,并且它们一定在同一组内。所谓的分组就是for循环从2->n的时候遇到height<k的时候就开始下一组。
转化成easy版本进行思考:
先将问题转化为一个简单版本进行思考,即求一个字符串内求有多少长度大于等于K的公共子区间,我们知道后缀数组的height数组是用来维护排名相邻的后缀最长公共前缀的数组。那么如果这个数组是单调递增的,那么答案如何计算?
假设height数组为[0,1,2,3,4,5,6,7],K为3。我们从左往右走,那么当height小于3的时候肯定是不纳入答案计算的。当height等于3的时候对答案ans的贡献为height - K + 1 = 1,当我们走到height[i]等于4的时候会发现当前后缀造成的贡献不仅是height[i] - K + 1,而且要加上height[i - 1] - K + 1,因为当前后缀不仅和上一个后缀有公共前缀,还和上上个后缀有公共前缀,乃至上上上个后缀...有公共前缀,并且有些能产生贡献。当前考虑的是简化版本的单调递增height数组,所以再加上一个height[i - 1] - K + 1。维护每个后缀只和前面一个后缀造成贡献的和sum,在下一次统计答案ans = height - k + 1 + sum,然后更新sum += height - K + 1;
再考虑不是一直单调递增的height数组该如何进行统计,我们可以用单调栈维护一个公共前缀递增的单调栈,如果height递增我们就像上面一样更新前缀和。然后每次遇到height小于栈顶了我们就弹出栈顶元素并且更新前缀和贡献。如何更新弹出元素的贡献呢?因为我们知道任意两个后缀的公共前缀是他们之间所有高度数组的最小值。所以我们在弹出的时候我们需要让sum -(stack[top] - height)。这样一来第一次被弹出所减小的贡献就被消除了。同时我们还要在新加入栈的位置维护一下被弹出了多少个后缀,这些后缀被弹出之后不能认为这些后缀做不了贡献,我们知道他们在被当前这个height弹出之后实际上就是他们变成了和当前height长度一样的height数组,要弹出也是因为长度大于当前height的话会导致当前height下面的答案和这些被弹出的值贡献已经不是这些值了,因为中间出现了更小的height所以那些大的对下面的贡献也应该变成这个更小的height。
考虑初始版本该如何做,将两个字符拼在一起中间用'Z' + 1衔接,这样可以避免第一个字符串的结尾字符连上第二个字符串的头部导致前缀出现额外的公共部分。对于每个位置计算这个位置之前的第二个字符串对第一个字符串造成的贡献,我们会发现漏掉了每个位置计算每个位置之前的第一个字符串对第二个字符串造成的贡献,所以倒着再做一遍即可。
我们可以先算出第一个字符串的长度len,sa[i] <= len的就是第一个字符串,sa[i] > len就是第二个字符串(不用管衔接符,它无法和任何一个后缀拥有公共子串)。
代码如下:
#include <iostream>
#include <cstring>
#include <algorithm>
#include <vector>
#define ll long long
#define AC main(void)
#define HYS std::ios::sync_with_stdio(false);std::cin.tie(0);std::cout.tie(0);
const long double eps = 1e-9;
const int N = 2e5 + 10, M = 2e5 + 10;
int n , m, _;
char str[M];
int s[M];
struct SA{
std::vector<int> sa, x, rk, height, c, y;
SA(int n) : sa(n, 0), x(n, 0), rk(n, 0), height(n, 0), c(n, 0), y(n, 0) {};
inline void get_sa(){
for(int i = 1; i <= m; i ++) c[i] = 0;
for(int i = 1; i <= n; i ++) c[x[i] = s[i]] ++;
for(int i = 2; i <= m; i ++) c[i] += c[i - 1];
for(int i = n; i; i --) sa[c[x[i]] --] = i;//数组的值代表的排名,下标代表在原数组的哪个位置
for(int k = 1; k <= n; k <<= 1){
int num = 0;
for(int i = n - k + 1; i <= n; i ++) y[++ num] = i;
for(int i = 1; i <= n; i ++)
if(sa[i] > k) //排名为i的数组下标大于k
y[++ num] = sa[i] - k;//排名为i的第二关键字的第一关键字位置
for(int i = 1; i <= m; i ++) c[i] = 0;
for(int i = 1; i <= n; i ++) c[x[i]] ++;
for(int i = 2; i <= m; i ++) c[i] += c[i - 1];
for(int i = n; i; i --) sa[c[x[y[i]]] --] = y[i], y[i] = 0;
//sa记录上述排名对应的下标为 y[i],从后往前枚举第二关键字的排名,使得第一关键字相同的后缀也可以依靠第二关键字区分
swap(x, y);//接下来需要更新x数组,但是y数组没用了,所以把信息转移到y数组
x[sa[1]] = 1, num = 1;
for(int i = 2; i <= n; i ++)
x[sa[i]] = (y[sa[i - 1]] == y[sa[i]] && y[sa[i] + k] == y[sa[i - 1] + k]) ? num : ++ num;
if(num == n) break;
m = num;//更新离散化后的rank范围
}
}
inline void get_height(){
for(int i = 1; i <= n; i ++) rk[sa[i]] = i;
for(int i = 1, k = 0; i <= n; i ++){
if(rk[i] == 1) continue;
if(k) k --;
int j = sa[rk[i] - 1];///排名在i前一个的后缀的下标
while(i + k <= n && j + k <= n && s[i + k] == s[j + k]) k ++;
height[rk[i]] = k;
}
}
inline void init(){
for(int i = 1; i <= n; i ++) s[i] = str[i] - 'A' + 1;
m = 'z' - 'A' + 5;
get_sa();
get_height();
}
inline ll query(){//求本质不同子串数量
ll res = 0;
for(int i = 1; i <= n; i ++) res += n - sa[i] + 1 - height[i];
return res;
}
};
SA Sa(M);
int k;
int cnt[M], stk[M];
inline void solve(){
while(std::cin >> k){
if(!k) return ;
ll ans = 0;
std::cin >> str + 1;
int len = strlen(str + 1);
str[len + 1] = 'Z' + 1;
std::cin >> str + len + 2;
n = strlen(str + 1);
Sa.init();
ll tot = 0;
int top = 0;
//先计算第一个字符串对第二个字符串造成的贡献
for(int i = 2; i <= n; i ++){
if(Sa.height[i] < k){
tot = top = 0;
continue;
}
int num = (Sa.sa[i - 1] <= len);
//上一个后缀是第二个字符串(上一个字符是需要统计的当前height才有贡献)
if(Sa.sa[i - 1] <= len) tot += Sa.height[i] - k + 1;
while(top && stk[top] > Sa.height[i]){
num += cnt[top];
tot -= (stk[top] - Sa.height[i]) * cnt[top];
top --;
}
//std::cout << num << '\n';
stk[++ top] = Sa.height[i];
cnt[top] = num;
if(Sa.sa[i] > len) ans += tot;
}
top = tot = 0;
//计算第二个字符串对第一个字符串造成的贡献
for(int i = 2; i <= n; i ++){
if(Sa.height[i] < k){
tot = top = 0;
continue;
}
//上一个后缀是第二个字符串(上一个字符是需要统计的当前height才有贡献)
int num = (Sa.sa[i - 1] > len);
if(Sa.sa[i - 1] > len) tot += Sa.height[i] - k + 1;
while(top && stk[top] > Sa.height[i]){
num += cnt[top];
tot -= (stk[top] - Sa.height[i]) * cnt[top];
top --;
}
stk[++ top] = Sa.height[i];
cnt[top] = num;
if(Sa.sa[i] <= len) ans += tot;
}
std::cout << ans << '\n';
}
}
signed AC{
_ = 1;
//std::cin >> _;
while(_ --)
solve();
return 0;
}