题目大意:

给定A,B两种字符串,问他们当中的长度大于k的公共子串的个数有多少个

 

这道题目本身理解不难,将两个字符串合并后求出它的后缀数组

然后利用后缀数组求解答案

这里一开始看题解说要用栈的思想,觉得很麻烦就不做了,后来在比赛中又遇到就后悔了,到今天看了很久才算看懂

首先建一个栈,从栈底到栈顶都保证是单调递增的

我们用一个tot记录当前栈中所有项和一个刚进入的子串匹配所能得到的总的子串的数目(当然前提是,当前进入的子串height值比栈顶还大,那么和栈中任意一个子串匹配都保持当前栈中记录的那时候入栈的height值)

但是若height不比栈顶大,说明从栈顶开始到刚好比它小的这一段tot有多加的部分,这部分就是height值多出来的那块,然后把这部分都视作height值为当前的height值,因为后面子串进入,它的height值总是取决于那段区间的最小值,所以不会产生影响,这样就可以把所有比当前height大的都弹出栈,这样就达到了O(n)的复杂度

这里用q[][]手写栈

q[i][0]表示栈中第i号元素记录时候的height值,q[i][1]表示在这个height值上覆盖了q[i][1]个子串

  1 #include <cstdio>
  2 #include <cstring>
  3 #include <iostream>
  4 using namespace std;
  5 #define INF 0x3f3f3f3f
  6 #define ll long long
  7 const int MAXN = 2*100010;
  8 int sa[MAXN] , rank[MAXN] , height[MAXN];
  9 int wa[MAXN] , wb[MAXN] , wsf[MAXN] , wv[MAXN];
 10 int a[MAXN] , k;
 11 char str1[MAXN] , str2[MAXN];
 12 int q[MAXN][2];
 13 
 14 int cmp(int *r , int a , int b , int l)
 15 {
 16     return r[a]==r[b] && r[a+l]==r[b+l];
 17 }
 18 
 19 void getSa(int *r , int *sa , int n , int m)
 20 {
 21     int *x = wa , *y = wb , *t;
 22     for(int i=0 ; i<m ; i++) wsf[i]=0;
 23     for(int i=0 ; i<n ; i++) wsf[x[i]=r[i]]++;
 24     for(int i=1 ; i<m ; i++) wsf[i]+=wsf[i-1];
 25     for(int i=n-1 ; i>=0 ; i--) sa[--wsf[x[i]]] = i;
 26 
 27     int i,j,p=1;
 28     for(j=1 ; p<n ; j*=2 , m=p)
 29     {
 30         for(p=0 , i=n-j ; i<n ; i++) y[p++] = i;
 31         for(i=0 ; i<n ; i++) if(sa[i]>=j) y[p++] = sa[i]-j;
 32 
 33         for(i=0 ; i<n ; i++) wv[i]=x[y[i]];
 34         for(i=0 ; i<m ; i++) wsf[i]=0;
 35         for(i=0 ; i<n ; i++) wsf[wv[i]]++;
 36         for(i=1 ; i<m ; i++) wsf[i]+=wsf[i-1];
 37         for(i=n-1 ; i>=0 ; i--) sa[--wsf[wv[i]]] = y[i];
 38 
 39         for(t=x , x=y , y=t , x[sa[0]]=0 , p=1 , i=1; i<n ; i++)
 40             x[sa[i]] = cmp(y , sa[i-1] , sa[i] , j)?p-1:p++;
 41     }
 42     return ;
 43 }
 44 
 45 void callHeight(int *r , int *sa , int n)
 46 {
 47     for(int i=0 ; i<=n ; i++) rank[sa[i]]=i;
 48     int i , j , k=0;
 49     for(i=0 ; i<n ; height[rank[i++]]=k)
 50         for(j=sa[rank[i]-1] , k?k--:0 ; r[i+k]==r[j+k] ; k++) ;
 51     return;
 52 }
 53 
 54 ll solve(int len1 , int len2)
 55 {
 56     ll ans = 0;
 57     //B串中的子串不断匹配rank比其高的A子串
 58     int top = 0;
 59     ll tot =0 , cnt = 0;
 60     for(int i=1 ; i<=len1+len2+1 ; i++){
 61         if(height[i]<k){
 62             top = tot = 0;
 63             continue;
 64         }
 65         cnt = 0;
 66         if(sa[i-1]<len1){
 67             cnt ++;
 68             tot += height[i]-k+1;
 69         }
 70         while(top&&height[i]<=q[top][0]){
 71             tot -= q[top][1]*(q[top][0]-height[i]);
 72             cnt += q[top][1];
 73             top--;
 74         }
 75         q[++top][0] = height[i];
 76         q[top][1] = cnt;
 77         if(sa[i]>len1) ans+=tot;
 78     }
 79     //A串中的子串不断匹配rank比其高的B子串
 80     tot = top = 0;
 81     for(int i=1 ; i<=len1+len2+1 ; i++){
 82         if(height[i]<k){
 83             top = tot = 0;
 84             continue;
 85         }
 86         cnt = 0;
 87         if(sa[i-1]>len1){
 88             cnt ++;
 89             tot += height[i]-k+1;
 90         }
 91         while(top&&height[i]<=q[top][0]){
 92             tot -= q[top][1]*(q[top][0]-height[i]);
 93             cnt += q[top][1];
 94             top--;
 95         }
 96         q[++top][0] = height[i];
 97         q[top][1] = cnt;
 98         if(sa[i]<len1) ans+=tot;
 99     }
100     return ans;
101 }
102 
103 int main()
104 {
105    // freopen("a.in" , "r" , stdin);
106 
107     while(scanf("%d" , &k) , k)
108     {
109         scanf("%s%s" , str1 , str2);
110         int len1 = strlen(str1) , len2 = strlen(str2);
111         for(int i=0 ; i<len1 ; i++) a[i] = (int)str1[i];
112         a[len1] = 259;
113         for(int i=0 ; i<len2 ; i++) a[i+len1+1] = (int)str2[i];
114         a[len1+len2+1] = 0;
115 
116         getSa(a , sa , len1+len2+2 , 260);
117         callHeight(a , sa , len1+len2+1);
118 
119       //  for(int i=0 ; i<len1+len2+2 ; i++) cout<<"rank i: "<<i<<" "<<rank[i]<<endl;
120       //  for(int i=1 ; i<len1+len2+2 ; i++) cout<<"xixi: "<<height[i]<<endl;
121         ll ans = solve(len1 , len2);
122         printf("%I64d\n" , ans);
123     }
124     return 0;
125 }

 

 posted on 2015-05-20 01:26  Love风吟  阅读(999)  评论(0编辑  收藏  举报