bzoj4556: [Tjoi2016&Heoi2016]字符串 (后缀数组加主席树)
题目是给出一个字符串,每次询问一个区间[a,b]中所有的子串和另一个区间[c,d]的lcp最大值,首先求出后缀数组,对于lcp的最大值肯定是rank[c]的前驱和后继,但是对于这个题会出现问题,就是题目中有区间的限制。
For example:
5 1
aaaab
1 2 3 5
对于这个样例,如果直接找到aab的前驱是 aaab,然后由于区间的原因答案是1,但是如果我们再往前找的话,找到aaaab,答案会变成2。那就出现了错误。考虑一下怎么做可以去除这种影响呢?
我们可以二分一下,首先对于[a,b]这个区间我们只考虑前一半区间[a,mid] 我们利用上面求前驱和后继的方法求出一个答案,如果答案大于mid我们就要考虑mid以前的区间,反之就考虑mid以后的区间。(这里的区间是指子串的左端点所在的区间)
为什么是对的?分类讨论一下,如果答案大于mid,那么我们不可能会考虑mid之后的区间,因为显然mid之后的区间到b的长度都小于mid,答案不会变优。如果答案小于mid,我们就不会在考虑mid之前的区间,因为算出答案小于mid,显然它不可能是由于区间限制得出的,因为我们只考虑前一半的区间,它到末尾的长度肯定大于等于mid,所以我们往后二分,因为随着区间的扩大我们找到的前驱和后继肯定会越来越接近rank[c],那么答案可能会变优,而我们往前判断答案肯定不会变优,所以往后二分。这样二分的正确性就ok了。(其实这个性质挺不好想的,我也是懵逼了半天,然而大奕哥用一个暴力碾掉了标算,bzoj上rank1......详情可以看博客 http://www.cnblogs.com/nbwzyzngyl/p/8215412.html 纯属侥幸......)
那么我们可以用主席树来办到求前驱和后继,这道题就解决了。—— by VANE
#include<bits/stdc++.h> using namespace std; const int N=100010; const int M=4000005; int n,m,a[N],tax[N],tp[N],height[N],st[N][20],Log[N]; int sa[N],rank[N],sum[M],ls[M],rs[M],siz,root[N],mm; char s[N]; void rsort() { for(int i=1;i<=m;++i) tax[i]=0; for(int i=1;i<=n;++i) tax[rank[tp[i]]]++; for(int i=1;i<=m;++i) tax[i]+=tax[i-1]; for(int i=n;i;--i) sa[tax[rank[tp[i]]]--]=tp[i]; } bool cmp(int *f,int x,int y,int w){return max(x,y)+w<=n&&f[x]==f[y]&&f[x+w]==f[y+w];} void suffix() { m=127;for(int i=1;i<=n;++i) tp[i]=i; for(int i=1;i<=n;++i) rank[i]=a[i]; rsort(); for(int w=1,p;w<n;w+=w) { p=0;for(int i=n-w+1;i<=n;++i) tp[++p]=i; for(int i=1;i<=n;++i) if(sa[i]>w) tp[++p]=sa[i]-w; rsort();swap(rank,tp);rank[sa[1]]=p=1; for(int i=2;i<=n;++i) rank[sa[i]]=cmp(tp,sa[i],sa[i-1],w)?p:++p; m=p;if(p>=n) break; } int j,k=0; for(int i=1;i<=n;height[rank[i++]]=k) for(k=k?k-1:0,j=sa[rank[i]-1];a[j+k]==a[i+k];++k); memset(st,127,sizeof st); for(int i=1;i<=n;++i) st[i][0]=height[i]; for(int j=1;j<=20;++j) for(int i=1;i+(1<<j)-1<=n;++i) st[i][j]=min(st[i][j-1],st[i+(1<<j-1)][j-1]); for(int i=2,j=0;i<=n;++i) { if(i==(1<<j+1)) ++j; Log[i]=j; } } int lcp(int x,int y) { if(x>y) swap(x,y); int len=Log[y-x+1]; int ret=min(st[x][len],st[y-(1<<len)+1][len]); return ret; } void insert(int l,int r,int x,int &y,int z) { y=++siz; sum[y]=sum[x]+1; ls[y]=ls[x];rs[y]=rs[x]; if(l==r) return; int mid=l+r>>1; if(z<=mid) insert(l,mid,ls[x],ls[y],z); else insert(mid+1,r,rs[x],rs[y],z); } int query_sum(int l,int r,int x,int y,int le,int ri) { if(le>ri) return 0; if(l==le&&r==ri) return sum[y]-sum[x]; int mid=l+r>>1; if(ri<=mid) return query_sum(l,mid,ls[x],ls[y],le,ri); if(le>mid) return query_sum(mid+1,r,rs[x],rs[y],le,ri); return query_sum(l,mid,ls[x],ls[y],le,mid)+query_sum(mid+1,r,rs[x],rs[y],mid+1,ri); } int query_pos(int l,int r,int x,int y,int p) { if(l==r) return l; int mid=l+r>>1; if(sum[ls[y]]-sum[ls[x]]>=p) return query_pos(l,mid,ls[x],ls[y],p); return query_pos(mid+1,r,rs[x],rs[y],p-sum[ls[y]]+sum[ls[x]]); } int main() { //freopen("str.in","r",stdin); //freopen("str.out","w",stdout); scanf("%d%d",&n,&mm); scanf("%s",s+1);for(int i=1;i<=n;++i) a[i]=s[i]; suffix(); int a,b,c,d,x,pre,sub; for(int i=1;i<=n;++i) insert(1,n,root[i-1],root[i],rank[i]); while(mm--) { scanf("%d%d%d%d",&a,&b,&c,&d); int l=0,r=b-a+1,maxn=0,mid; bool flag=0; while(l<=r) { mid=l+r>>1; x=query_sum(1,n,root[a-1],root[b-mid],1,rank[c]-1); pre=x?query_pos(1,n,root[a-1],root[b-mid],x):-1; x=query_sum(1,n,root[a-1],root[b-mid],1,rank[c]); sub=query_pos(1,n,root[a-1],root[b-mid],1+x); if(x==sum[root[b-mid]]-sum[root[a-1]]) sub=-1; int ans=0; if(pre!=-1) { int o=min(lcp(pre+1,rank[c]),b-sa[pre]+1); ans=max(ans,min(o,d-c+1)); } if(sub!=-1) { int o=min(lcp(rank[c]+1,sub),b-sa[sub]+1); ans=max(ans,min(o,d-c+1)); } if(c>=a&&c<=b) ans=max(ans,min(d,b)-c+1); maxn=max(maxn,ans); if(ans>=mid) l=mid+1; else r=mid; if(flag) break; if(l==r) flag=1; } printf("%d\n",maxn); } }