Luogu-4248 [AHOI2013]差异

\(\sum_{i<j}len(i)+len(j)\)比较简单,稍微想想就出来了,问题在于怎么求任意两个后缀的\(lcp\)长度之和

因为求\(lcp\)实际上就是一个对\(h\)数组求区间最小值的过程,这就可以考虑计算对于每一个\(h\),他对答案做出的贡献,可以看出以\(h[x]\)作为最小值的区间\([l,r]\)中,任意一对\(i\in[l,x],j\in[x,r]\)\(lcp\)都是他,总的对数就是贡献。\(l,r\)可用单调队列来快速求出。

需要注意区间\([l,x],[x,r]\)中,最好一个是满足\(h[i]<h[x]\),一个满足\(h[i]<=h[x]\),防止重复

#include<map>
#include<cmath>
#include<queue>
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
typedef long long ll;
const int maxn=1e6+100;
struct SA{
    int sa[maxn],tp[maxn],rk[maxn],tax[maxn],h[maxn],n,m,st[maxn],top,l[maxn],r[maxn];
    char s[maxn];
    void Qsort(){
        for(int i=0;i<=m;i++) tax[i]=0;
        for(int i=1;i<=n;i++) tax[rk[i]]++;
        for(int i=1;i<=m;i++) tax[i]+=tax[i-1];
        for(int i=n;i>=1;i--)
            sa[tax[rk[tp[i]]]--]=tp[i];
    }
    void getsa(){
        m=200;
        for(int i=1;i<=n;i++)
            rk[i]=s[i],tp[i]=i;
        Qsort();
        for(int p=1,w=1;p<n;m=p,w<<=1){
            p=0;
            for(int i=1;i<=w;i++) tp[++p]=n+i-w;
            for(int i=1;i<=n;i++) if(sa[i]>w) tp[++p]=sa[i]-w;
            Qsort();
            swap(tp,rk);
            rk[sa[1]]=p=1;
            for(int i=2;i<=n;i++)
                rk[sa[i]]=tp[sa[i]]==tp[sa[i-1]]&&tp[sa[i]+w]==tp[sa[i-1]+w]?p:++p;
        }
    }
    void geth(){
        for(int i=1,j,p=0;i<=n;h[rk[i++]]=p)
        for(p?p--:p,j=sa[rk[i]-1];s[i+p]==s[j+p];p++);
    }
    ll work(){
        ll ans=0;
        for(int i=1;i<=n;i++) ans+=1ll*i*(n-1);
        h[0]=-0x7fffffff,top=0;
        for(int i=1;i<=n;i++){
            while(h[st[top]]>=h[i]) top--;
            l[i]=st[top]+1;
            st[++top]=i;
        }
        h[n+1]=-0x7fffffff,top=0,st[top]=n+1;
        for(int i=n;i>=1;i--){
            while(h[st[top]]>h[i]) top--;
            r[i]=st[top]-1;
            st[++top]=i;
        }
        for(int i=1;i<=n;i++)
            ans-=2ll*(i-l[i]+1)*(r[i]-i+1)*h[i];
        return ans;
    }
}sa;
int main(){
//	freopen(".in","r",stdin);
    scanf("%s",sa.s+1),sa.n=strlen(sa.s+1);
    sa.getsa(),sa.geth();
    printf("%lld\n",sa.work());
    return 0;
}

posted @ 2018-11-29 22:19  nianheng  阅读(108)  评论(0编辑  收藏  举报