【[AHOI2013]差异】

这个题一看就是为后缀家族设计的

我们看到我们要求的这个柿子

\[\sum_{i=1}^n\sum_{j=i+1}^nT_i+T_j-2\times lcp(T_i,T_j) \]

显然的是前面的那些东西是个定值

就是保证每一个长度都会被其他长度算到,也就是算到\(n-1\)

于是把前面那些东西拿出来就是

\[\frac{(n+1)(n-1)n}{2} \]

之后再看后面那些东西

所有后缀的\(lcp\)的长度?

先来考虑一下如何求两个后缀的\(lcp\)

哈希+二分啊\(SA\)

对于后缀\(i,j\),他们的\(lcp\)长度就是\(min(heighht[rk[i]+1]...height[rk[j]])\)

于是现在的问题转化为求出\(height\)数组所有子区间的最小值的和

我们可以考虑一个动态往序列末尾加数的过程

也就是我们往末尾加一个数都会和之前所有的数形成一个新的区间

考虑快速算出这些区间的最小值的和

我们可以对每一个数存储一个\(a_i\),表示\(i\)到当前序列末尾的最小值是多少

我们每次加入一个数可以对更新一下所有的\(a_i\),把所有比当前加入的数大的\(a_i\)变成当前数就好了

这不就\(T\)了吗

我们发现我们只需要求出所有\(a_i\)的和,并不需要关心这个\(i\)来自哪里,于是我们可以把相等的\(a_i\)放在一起计算,也就是每次新加入一个数就暴力扫一遍把那些比当前加入数大的合并到一个\(a_i\)

看起来复杂度并不科学,但是最坏情况下就相当于是一个线段树的复杂度了,\(O(n)\)的,跑的还挺快的

代码

#include<iostream>
#include<cstring>
#include<cstdio>
#define re register
#define maxn 500005
#define LL long long
#define max(a,b) ((a)>(b)?(a):(b))
#define min(a,b) ((a)<(b)?(a):(b))
#define pt putchar(1)
inline int read()
{
	char c=getchar();int x=0;
	while(c<'0'||c>'9') c=getchar();
	while(c>='0'&&c<='9') x=(x<<3)+(x<<1)+c-48,c=getchar();return x;
}
int n,m,top;
LL ans=0,sum=0;
char S[maxn];
int tax[maxn],sa[maxn],rk[maxn],tp[maxn],height[maxn];
int L[maxn],R[maxn],st[maxn];
int a[maxn],cnt[maxn];
LL pre[maxn];
inline void qsort()
{
	for(re int i=0;i<=m;i++) tax[i]=0;
	for(re int i=1;i<=n;i++) tax[rk[i]]++;
	for(re int i=1;i<=m;i++) tax[i]+=tax[i-1];
	for(re int i=n;i;--i) sa[tax[rk[tp[i]]]--]=tp[i];
}
int main()
{
	scanf("%s",S+1);
	n=strlen(S+1);
	m=75;
	for(re int i=1;i<=n;i++) rk[i]=S[i]-'a'+1,tp[i]=i;
	qsort();
	for(re int w=1,p=0;p<n;m=p,w<<=1)
	{
		p=0;
		for(re int i=1;i<=w;i++) tp[++p]=n-w+i;
		for(re int i=1;i<=n;i++) if(sa[i]>w) tp[++p]=sa[i]-w;
		qsort();
		for(re int i=1;i<=n;i++) std::swap(tp[i],rk[i]);
		rk[sa[1]]=p=1;
		for(re int i=2;i<=n;i++) rk[sa[i]]=(tp[sa[i-1]]==tp[sa[i]]&&tp[sa[i-1]+w]==tp[sa[i]+w])?p:++p;
	}
	int k=0;
	for(re int i=1;i<=n;i++)
	{
		if(k) --k;
		int j=sa[rk[i]-1];
		while(S[i+k]==S[j+k]) ++k;
		height[rk[i]]=k;
	}
	ans+=height[2];
	a[1]=ans;cnt[1]=1,sum=ans;
	top=1;
	for(re int i=3;i<=n;i++)  
	{
		int now=1;
		while(top&&height[i]<=a[top]) 
			now+=cnt[top],sum-=a[top]*cnt[top],top--;
		cnt[++top]=now;
		a[top]=height[i];
		sum+=cnt[top]*a[top];
		ans+=sum;
	}
	printf("%lld\n",(LL)(n-1)*(LL)(n+1)*(LL)n/2ll-2ll*ans);
	return 0;
}
posted @ 2019-01-01 19:33  asuldb  阅读(230)  评论(0编辑  收藏  举报