bzoj4556: [Tjoi2016&Heoi2016]字符串 (后缀数组加主席树)

题目是给出一个字符串,每次询问一个区间[a,b]中所有的子串和另一个区间[c,d]的lcp最大值,首先求出后缀数组,对于lcp的最大值肯定是rank[c]的前驱和后继,但是对于这个题会出现问题,就是题目中有区间的限制。

For example: 

5 1

aaaab

1 2 3 5

对于这个样例,如果直接找到aab的前驱是 aaab,然后由于区间的原因答案是1,但是如果我们再往前找的话,找到aaaab,答案会变成2。那就出现了错误。考虑一下怎么做可以去除这种影响呢?

我们可以二分一下,首先对于[a,b]这个区间我们只考虑前一半区间[a,mid] 我们利用上面求前驱和后继的方法求出一个答案,如果答案大于mid我们就要考虑mid以前的区间,反之就考虑mid以后的区间。(这里的区间是指子串的左端点所在的区间)

为什么是对的?分类讨论一下,如果答案大于mid,那么我们不可能会考虑mid之后的区间,因为显然mid之后的区间到b的长度都小于mid,答案不会变优。如果答案小于mid,我们就不会在考虑mid之前的区间,因为算出答案小于mid,显然它不可能是由于区间限制得出的,因为我们只考虑前一半的区间,它到末尾的长度肯定大于等于mid,所以我们往后二分,因为随着区间的扩大我们找到的前驱和后继肯定会越来越接近rank[c],那么答案可能会变优,而我们往前判断答案肯定不会变优,所以往后二分。这样二分的正确性就ok了。(其实这个性质挺不好想的,我也是懵逼了半天,然而大奕哥用一个暴力碾掉了标算,bzoj上rank1......详情可以看博客 http://www.cnblogs.com/nbwzyzngyl/p/8215412.html  纯属侥幸......)

那么我们可以用主席树来办到求前驱和后继,这道题就解决了。—— by VANE

#include<bits/stdc++.h>
using namespace std;
const int N=100010;
const int M=4000005;
int n,m,a[N],tax[N],tp[N],height[N],st[N][20],Log[N];
int sa[N],rank[N],sum[M],ls[M],rs[M],siz,root[N],mm;
char s[N];
void rsort()
{
    for(int i=1;i<=m;++i) tax[i]=0;
    for(int i=1;i<=n;++i) tax[rank[tp[i]]]++;
    for(int i=1;i<=m;++i) tax[i]+=tax[i-1];
    for(int i=n;i;--i) sa[tax[rank[tp[i]]]--]=tp[i];
}
bool cmp(int *f,int x,int y,int w){return max(x,y)+w<=n&&f[x]==f[y]&&f[x+w]==f[y+w];}
void suffix()
{
    m=127;for(int i=1;i<=n;++i) tp[i]=i;
    for(int i=1;i<=n;++i) rank[i]=a[i];
    rsort();
    for(int w=1,p;w<n;w+=w)
    {
        p=0;for(int i=n-w+1;i<=n;++i) tp[++p]=i;
        for(int i=1;i<=n;++i) if(sa[i]>w) tp[++p]=sa[i]-w;
        rsort();swap(rank,tp);rank[sa[1]]=p=1;
        for(int i=2;i<=n;++i) rank[sa[i]]=cmp(tp,sa[i],sa[i-1],w)?p:++p;
        m=p;if(p>=n) break;
    }
    int j,k=0;
    for(int i=1;i<=n;height[rank[i++]]=k)
        for(k=k?k-1:0,j=sa[rank[i]-1];a[j+k]==a[i+k];++k);
    memset(st,127,sizeof st);
    for(int i=1;i<=n;++i) st[i][0]=height[i];
    for(int j=1;j<=20;++j)
        for(int i=1;i+(1<<j)-1<=n;++i)
            st[i][j]=min(st[i][j-1],st[i+(1<<j-1)][j-1]);
    for(int i=2,j=0;i<=n;++i)
    {
        if(i==(1<<j+1)) ++j;
        Log[i]=j;
    }
}
int lcp(int x,int y)
{
    if(x>y) swap(x,y);
    int len=Log[y-x+1];
    int ret=min(st[x][len],st[y-(1<<len)+1][len]);
    return ret;
}
void insert(int l,int r,int x,int &y,int z)
{
    y=++siz;
    sum[y]=sum[x]+1;
    ls[y]=ls[x];rs[y]=rs[x];
    if(l==r) return;
    int mid=l+r>>1;
    if(z<=mid) insert(l,mid,ls[x],ls[y],z);
    else insert(mid+1,r,rs[x],rs[y],z);
}
int query_sum(int l,int r,int x,int y,int le,int ri)
{
    if(le>ri) return 0;
    if(l==le&&r==ri) return sum[y]-sum[x];
    int mid=l+r>>1;
    if(ri<=mid) return query_sum(l,mid,ls[x],ls[y],le,ri);
    if(le>mid) return query_sum(mid+1,r,rs[x],rs[y],le,ri);
    return query_sum(l,mid,ls[x],ls[y],le,mid)+query_sum(mid+1,r,rs[x],rs[y],mid+1,ri);
}
int query_pos(int l,int r,int x,int y,int p)
{
    if(l==r) return l;
    int mid=l+r>>1;
    if(sum[ls[y]]-sum[ls[x]]>=p) return query_pos(l,mid,ls[x],ls[y],p);
    return query_pos(mid+1,r,rs[x],rs[y],p-sum[ls[y]]+sum[ls[x]]);
}
int main()
{
    //freopen("str.in","r",stdin);
    //freopen("str.out","w",stdout);
    scanf("%d%d",&n,&mm);
    scanf("%s",s+1);for(int i=1;i<=n;++i) a[i]=s[i];
    suffix();
    int a,b,c,d,x,pre,sub;
    for(int i=1;i<=n;++i) insert(1,n,root[i-1],root[i],rank[i]);
    while(mm--)
    {
        scanf("%d%d%d%d",&a,&b,&c,&d);
        int l=0,r=b-a+1,maxn=0,mid;
        bool flag=0;
        while(l<=r)
        {
            mid=l+r>>1;
            x=query_sum(1,n,root[a-1],root[b-mid],1,rank[c]-1);
            pre=x?query_pos(1,n,root[a-1],root[b-mid],x):-1;
            x=query_sum(1,n,root[a-1],root[b-mid],1,rank[c]);
            sub=query_pos(1,n,root[a-1],root[b-mid],1+x);
            if(x==sum[root[b-mid]]-sum[root[a-1]]) sub=-1;
            int ans=0;
            if(pre!=-1)
            {
                int o=min(lcp(pre+1,rank[c]),b-sa[pre]+1);
                ans=max(ans,min(o,d-c+1));
            }
            if(sub!=-1)
            {
                int o=min(lcp(rank[c]+1,sub),b-sa[sub]+1);
                ans=max(ans,min(o,d-c+1));
            }
            if(c>=a&&c<=b) ans=max(ans,min(d,b)-c+1);
            maxn=max(maxn,ans);
            if(ans>=mid) l=mid+1;
            else r=mid;
            if(flag) break;
            if(l==r) flag=1;
        }
        printf("%d\n",maxn);
    }
}

 

posted @ 2018-01-10 17:41  大奕哥&VANE  阅读(267)  评论(0编辑  收藏  举报