【[HAOI2016]找相同字符】
其实这道题跟[AHOI2013]差异很像
其实这个问题的本质就是让你算所有后缀的\(lcp\)长度之和,但是得来自两个不同的字符串
先把两个字符串拼起来做一遍\(SA\),由于我们多算了来自于同一个串内的情况
于是在分别对这两个串建\(SA\),减掉这两次算出来的答案
现在的问题转化为求出\(height\)数组所有子区间的最小值的和
我们可以考虑一个动态往序列末尾加数的过程
也就是我们往末尾加一个数都会和之前所有的数形成一个新的区间
考虑快速算出这些区间的最小值的和
我们可以对每一个数存储一个\(a_i\),表示\(i\)到当前序列末尾的最小值是多少
我们每次加入一个数可以对更新一下所有的\(a_i\),把所有比当前加入的数大的\(a_i\)变成当前数就好了
这不就\(T\)了吗
我们发现我们只需要求出所有\(a_i\)的和,并不需要关心这个\(i\)来自哪里,于是我们可以把相等的\(a_i\)放在一起计算,也就是每次新加入一个数就暴力扫一遍把那些比当前加入数大的合并到一个\(a_i\)里
看起来复杂度并不科学,但是最坏情况下就相当于是一个线段树的复杂度了,\(O(n)\)的,跑的还挺快的
#include<iostream>
#include<cstring>
#include<cstdio>
#include<algorithm>
#define re register
#define LL long long
#define maxn 400005
#define max(a,b) ((a)>(b)?(a):(b))
#define min(a,b) ((a)<(b)?(a):(b))
int het[maxn],sa[maxn],rk[maxn],tp[maxn],tax[maxn];
char S[maxn],T[maxn];
int L1,L2,n,m;
LL a[maxn],cnt[maxn],top;
LL tot;
inline void qsort()
{
for(re int i=0;i<=m;i++) tax[i]=0;
for(re int i=1;i<=n;i++) tax[rk[i]]++;
for(re int i=1;i<=m;i++) tax[i]+=tax[i-1];
for(re int i=n;i;--i) sa[tax[rk[tp[i]]]--]=tp[i];
}
inline LL SA()
{
LL ans=0,sum=0;
memset(rk,0,sizeof(rk)),memset(sa,0,sizeof(sa)),memset(het,0,sizeof(het)),memset(tp,0,sizeof(tp));
memset(a,0,sizeof(a)),memset(cnt,0,sizeof(cnt)),top=0;
m=2500;
for(re int i=1;i<=n;i++) rk[i]=S[i],tp[i]=i;
qsort();
for(re int w=1,p=0;p<n;m=p,w<<=1)
{
p=0;
for(re int i=1;i<=w;i++) tp[++p]=n-w+i;
for(re int i=1;i<=n;i++) if(sa[i]>w) tp[++p]=sa[i]-w;
qsort();
for(re int i=1;i<=n;i++) std::swap(tp[i],rk[i]);
rk[sa[1]]=p=1;
for(re int i=2;i<=n;i++) rk[sa[i]]=(tp[sa[i-1]]==tp[sa[i]]&&tp[sa[i-1]+w]==tp[sa[i]+w])?p:++p;
}
int k=0;
for(re int i=1;i<=n;i++)
{
if(k) --k;
int j=sa[rk[i]-1];
while(S[i+k]==S[j+k]) ++k;
het[rk[i]]=k;
}
for(re int i=2;i<=n;i++)
{
LL now=1;
while(top&&het[i]<=a[top])
now+=cnt[top],sum-=a[top]*cnt[top],top--;
cnt[++top]=now;
a[top]=het[i];
sum+=cnt[top]*a[top];
ans+=sum;
}
return ans;
}
int main()
{
scanf("%s",S+1);scanf("%s",T+1);
L1=strlen(S+1),L2=strlen(T+1);n=L1+L2+1;
S[L1+1]='z'+1;
for(re int i=1;i<=L2;i++) S[i+L1+1]=T[i];
tot+=SA();
for(re int i=L1+1;i<=n;i++) S[i]=0;
n=L1;tot-=SA();
for(re int i=1;i<=L2;i++) S[i]=T[i];
for(re int i=L2+1;i<=n;i++) S[i]=0;
n=L2;
tot-=SA();
printf("%lld\n",tot);
return 0;
}