BZOJ 3238: [Ahoi2013]差异
BZOJ 3238: [Ahoi2013]差异
标签(空格分隔): OI-BZOJ OI-后缀自动机
Time Limit: 20 Sec
Memory Limit: 512 MB
Description
Input
一行,一个字符串S
Output
一行,一个整数,表示所求值
Sample Input
cacao
Sample Output
54
HINT
2<=N<=500000,S由小写英文字母组成
Solution####
后缀自动机,求出right集合大小,字符串S对答案的贡献为
-|S|*2*rightsize(S+'a')*rightsize(S+'b')-|S|*2*rightsize(S+'a')*rightsize(S+'c')....
先dp求出到达每个点不同的字符串的长度和,然后用结合律的思想算出加入两两不同字符后的rightsize乘积的和。
可以用后缀数组单调栈维护。
Code####
#include<iostream>
#include<stdio.h>
#include<stdlib.h>
#include<string.h>
#include<math.h>
#include<algorithm>
#include<queue>
#include<set>
#include<map>
#include<bitset>
#include<vector>
using namespace std;
#define PA pair<int,int>
int read()
{
int s=0,f=1;char ch=getchar();
while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
while(ch>='0'&&ch<='9'){s=(s<<1)+(s<<3)+ch-'0';ch=getchar();}
return s*f;
}
//smile please
const int N=500005;
struct samm
{
int last,total,n;
int L[N*2],ch[N*2][27],fa[N*2];
int val[N*2],sum[N*2],t[N],h[N*2];
int s1[N*2],s2[N*2];
samm(){total=1,last=1;}
void insert(int C)
{
int p=last,now=last=++total;
L[now]=L[p]+1;n=L[now];val[now]=1;
for(;p&&!ch[p][C];p=fa[p])
ch[p][C]=now;
if(!p)fa[now]=1;
else
if(L[ch[p][C]]==L[p]+1)
fa[now]=ch[p][C];
else
{int ne=++total,Q=ch[p][C];
memcpy(ch[ne],ch[Q],sizeof(ch[Q]));
fa[ne]=fa[Q];
L[ne]=L[p]+1;
fa[Q]=fa[now]=ne;
for(;p&&ch[p][C]==Q;p=fa[p])
ch[p][C]=ne;
}
}
void pre()
{for(int i=1;i<=total;i++)t[L[i]]++;
for(int i=1;i<=n;i++)t[i]+=t[i-1];
for(int i=total;i;i--)h[t[L[i]]--]=i;
for(int i=total,x;i;i--)sum[x=h[i]]=val[x];
for(int i=total,x;i;i--)
{x=h[i];
sum[fa[x]]+=sum[x];
}
sum[0]=0;
}
long long solve()
{
s2[1]=1;
for(int i=1,x;i<=total;i++)
{x=h[i];
for(int j=0;j<27;j++)
s1[ch[x][j]]+=s1[x]+s2[x],
s2[ch[x][j]]+=s2[x];
}
long long ans=0;
for(int i=1,x;i<=total;i++)
{x=h[i];long long su=0;
for(int j=0;j<27;j++)
ans+=su*sum[ch[x][j]]*2*s1[x],
su+=sum[ch[x][j]];
}
return ans;
}
}a;
char z[N];
int n,k,t;
int main()
{
//freopen(".in","r",stdin);
//freopen(".out","w",stdout);
scanf("%s",z);
n=strlen(z);
for(int i=0;i<n;i++)
a.insert(z[i]-'a');
a.insert(26);
a.pre();
cout<<((long long)(n+1)*n*(n-1)/2-a.solve());
//fclose(stdin);
//fclose(stdout);
return 0;
}