POJ 3415:后缀数组+单调栈优化
题意很简单,求两个字符串长度大于等于K的子串个数
一开始还是只会暴力。。发现n^2根本没法做。。。看了题解理解了半天才弄出来,太弱了。。。
思路:把两个字符串连接后做一次后缀数组,求出height
暴力的想法自然是枚举第一个子串的起始位置i和第二个子串的起始位置j,肯定会T的
看了题解才知道有单调栈这种有优化方法。。
将后缀分为A组(起始点为第一个字符串)、B组
设符合要求的lcp长度为a,则其对答案的贡献为a-k+1(长度为k~a的都是符合要求的)
一开始这里我也是有疑问的,比如说k=1,aa和aab两个suffix,lcp=2,这样的话贡献为2-1+1=2
但是我们可以看到有这两个suffix符合要求的是有3个的(长度1的2个,2的1个),我们似乎只统计了长度1的和长度2的各一个
确实如此。但是我们不能小范围的看问题,应该要结合整个sa[]和height[]
这里没统计不意味着之后不统计,还会有一个a后缀和一个ab后缀,这里又贡献了1.这样就补齐了~
我们知道两个suffix的lcp是height中对应段height[_rank[i]+1]~height[_rank[j]]的最小值,应用这个性质我们可以来维护单调递增的栈
为什么要这样呢?可以理解为栈里是可能被用到的候选序列,如果当前扫描到的height小于栈顶(候选最大值),则根据上面的性质,
可以得出大于height的值是无法做出贡献了(或者说贡献变小了),那累加器的值要更新
同时我们这里用到了一个优化,为了防止栈内元素过多,我们把同一个height值,捆绑一个个数num,这样可以提升统计效率
我们对A、B组各做一次,加起来就是答案
一开始我在这里没有完全理解,总感觉会重复统计。其实是没有的。为什么呢?
那统计B组的时候为例。
对于B组后缀j,我统计答案都是在sa[j]之前找。
比如说找到A组中的ii,jj,kk三个后缀是符合的,那必定有sa[ii]、sa[jj]、sa[kk]都小于sa[j]
所以在统计A组时,sa[ii]也是在sa[ii]之前找,不可能找到sa[j]
#include"cstdio" #include"queue" #include"cmath" #include"stack" #include"iostream" #include"algorithm" #include"cstring" #include"queue" #include"map" #include"set" #include"vector" #define LL long long #define mems(a,b) memset(a,b,sizeof(a)) #define ls pos<<1 #define rs pos<<1|1 #define max(a,b) (a)>(b)?(a):(b) using namespace std; const int MAXN = 200500; const int INF = 0x3f3f3f3f; char a[MAXN],b[MAXN]; int sa[MAXN],_rank[MAXN]; int wa[MAXN],wb[MAXN],wv[MAXN],Ws[MAXN]; int cmp(int *r,int a,int b,int l) {return r[a]==r[b]&&r[a+l]==r[b+l];} void get_sa(char *r,int *sa,int n,int m){ int i,j,p,*x=wa,*y=wb,*t; for(i=0;i<m;i++) Ws[i]=0; for(i=0;i<n;i++) Ws[x[i]=r[i]]++; for(i=1;i<m;i++) Ws[i]+=Ws[i-1]; for(i=n-1;i>=0;i--) sa[--Ws[x[i]]]=i; for(j=1,p=1;p<n;j*=2,m=p){ for(p=0,i=n-j;i<n;i++) y[p++]=i; for(i=0;i<n;i++) if(sa[i]>=j) y[p++]=sa[i]-j; for(i=0;i<n;i++) wv[i]=x[y[i]]; for(i=0;i<m;i++) Ws[i]=0; for(i=0;i<n;i++) Ws[wv[i]]++; for(i=1;i<m;i++) Ws[i]+=Ws[i-1]; for(i=n-1;i>=0;i--) sa[--Ws[wv[i]]]=y[i]; for(t=x,x=y,y=t,p=1,x[sa[0]]=0,i=1;i<n;i++) x[sa[i]]=cmp(y,sa[i-1],sa[i],j)?p-1:p++; } return; } int height[MAXN]; void get_height(char *r,int *sa,int n){ int i,j,k=0; for(i=1;i<=n;i++) _rank[sa[i]]=i; for(i=0;i<n;height[_rank[i++]]=k) for(k?k--:0,j=sa[_rank[i]-1];r[i+k]==r[j+k];k++); return; } int k; int main(){ //freopen("in.txt","r",stdin); while(scanf("%d",&k)&&k){ scanf("%s%s",a,b); int la=strlen(a); int lb=strlen(b); a[la]='$'; for(int i=0;i<lb;i++) a[i+la+1]=b[i]; int n=la+lb; //最后一个元素的下标 a[++n]='\0'; get_sa(a,sa,n+1,300); get_height(a,sa,n); LL ans=0,cnt=0; //cnt为累加器 int top=0; pair<int,int> s[MAXN]; for(int i=1;i<=n;i++){ if(height[i]<k) top=0,cnt=0; else{ int num=0; //统计同一height值的个数 if(sa[i-1]<la) num++,cnt+=height[i]-k+1; while(top&&height[i]<=s[top].first){ cnt-=s[top].second*(s[top].first-height[i]); //如果栈顶元素的height大于等于当前height,则贡献会变小s[top].first-height[i],更新累加器 num+=s[top--].second; //更新个数 } s[++top]=make_pair(height[i],num); if(sa[i]>la) ans+=cnt; } } top=0; cnt=0; for(int i=1;i<=n;i++){ if(height[i]<k) top=0,cnt=0; else{ int num=0; if(sa[i-1]>la) num++,cnt+=height[i]-k+1; while(top&&height[i]<=s[top].first){ cnt-=s[top].second*(s[top].first-height[i]); num+=s[top--].second; } s[++top]=make_pair(height[i],num); if(sa[i]<la) ans+=cnt; } } printf("%lld\n",ans); } return 0; }