后缀数组题。网上很多题解用的是单调栈，这里提供一种不用单调栈的做法（并查集合并）：
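首先用一个不在串中出现的分隔符将两个串连接起来，求出 height 数组，然后按 height 值从大到小进行合并。height[i] 表示 sa[i] 与 sa[i-1] 两个后缀的最长公共前缀长度，因此每次合并的就是排名相邻的 i 和 i-1 所在的集合。由于是从大到小处理，在合并较小的 height 时，公共前缀更长的后缀都已经合并到同一个集合中，于是可以直接统计：以 height 值 L 合并两个集合时，新连通的"A 串后缀–B 串后缀"对的最长公共前缀恰好为 L，每一对贡献 L-k+1 个长度不小于 k 的公共子串，即答案增加 (x1·y2 + x2·y1)·(L-k+1)，其中 x、y 分别是集合中属于 A 串、B 串的后缀个数。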
// Reads repeated test cases "k A B" (terminated by k == 0 or end of input)
// and prints, for each, the sum over (suffix of A, suffix of B) pairs of
// max(0, LCP - k + 1): the number of triples (i, j, len) with len >= k and
// A[i..i+len) == B[j..j+len).
//
// Approach (no monotonic stack): build the suffix array of A + sep + B, then
// union adjacent ranks in decreasing order of height. When two sets are
// joined at height L, every newly connected (A-suffix, B-suffix) pair has
// LCP exactly L, because all longer edges between them were merged earlier.
#include <algorithm>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <utility>
#include <vector>

namespace {

// Capacity for |A| + |B| + 2 symbols (separator and terminal).
const int MAXN = 200050;

// Number of distinct symbols fed to the suffix array. 0 is the terminal,
// 1 is the A/B separator, and input bytes are mapped to 2..257.
const int ALPHABET = 258;

// Suffix array by prefix doubling with radix sort, plus the height (LCP)
// array. After input(val, len, Max):
//   sa[0..len]      suffix start positions in sorted order (sa[0] == len,
//                   the terminal sentinel);
//   rank[sa[i]] = i for i in 1..len;
//   height[i]       LCP of suffixes sa[i] and sa[i-1], for i in 1..len.
struct SuffixArray {
    int wa[MAXN], wb[MAXN], wv[MAXN], ws[MAXN];
    int sa[MAXN];
    int rank[MAXN];
    int height[MAXN];
    int r[MAXN];
    int n;  // text length, excluding the appended 0 terminal
    int m;  // current number of radix-sort buckets

    // val[0..len) must hold symbols in [1, Max); 0 is reserved for the
    // terminal that is appended at r[len].
    void input(const int *val, int len, int Max) {
        for (int i = 0; i < len; ++i) r[i] = val[i];
        r[len] = 0;
        n = len;
        m = Max;
        calSa();
        calHeight();
    }

    // Nonzero when suffixes a and b share the same rank on both halves
    // (positions a, b and a + l, b + l) under the rank array `key`.
    int cmp(const int *key, int a, int b, int l) const {
        return key[a] == key[b] && key[a + l] == key[b + l];
    }

    void calSa() {
        int *x = wa;
        int *y = wb;
        int i, j, p;
        // Initial counting sort by first symbol.
        for (i = 0; i < m; ++i) ws[i] = 0;
        for (i = 0; i <= n; ++i) ++ws[x[i] = r[i]];
        for (i = 1; i < m; ++i) ws[i] += ws[i - 1];
        for (i = n; i >= 0; --i) sa[--ws[x[i]]] = i;
        // Double the compared prefix length until all ranks are distinct.
        for (j = 1, p = 1; p <= n; j *= 2, m = p) {
            // y = positions ordered by second key. Suffixes with no second
            // half (start > n - j) sort first.
            p = 0;
            for (i = n - j + 1; i <= n; ++i) y[p++] = i;
            for (i = 0; i <= n; ++i)
                if (sa[i] >= j) y[p++] = sa[i] - j;
            // Stable counting sort of y by first key.
            for (i = 0; i <= n; ++i) wv[i] = x[y[i]];
            for (i = 0; i < m; ++i) ws[i] = 0;
            for (i = 0; i <= n; ++i) ++ws[wv[i]];
            for (i = 1; i < m; ++i) ws[i] += ws[i - 1];
            for (i = n; i >= 0; --i) sa[--ws[wv[i]]] = y[i];
            // Recompute ranks for prefixes of length 2j into x; p ends as
            // the number of distinct ranks.
            std::swap(x, y);
            p = 1;
            x[sa[0]] = 0;
            for (i = 1; i <= n; ++i)
                x[sa[i]] = cmp(y, sa[i - 1], sa[i], j) ? p - 1 : p++;
        }
    }

    // Kasai's algorithm: walk the text in order, reusing the previous LCP
    // minus one as the starting match length.
    void calHeight() {
        for (int i = 1; i <= n; ++i) rank[sa[i]] = i;
        int k = 0;
        for (int i = 0; i < n; ++i) {
            if (k > 0) --k;
            const int j = sa[rank[i] - 1];
            while (r[i + k] == r[j + k]) ++k;  // stops at the unique 0 terminal
            height[rank[i]] = k;
        }
    }
};

SuffixArray SA;
char s1[MAXN], s2[MAXN];
int s[MAXN];                 // concatenated text A + sep + B as symbols
int id[MAXN];                // owner of each text position: 0 = A, 1 = B, 2 = separator/terminal
std::vector<int> e[MAXN];    // e[L] = ranks i with height[i] == L
int cntA[MAXN], cntB[MAXN];  // per union-find root: # of A-suffixes / B-suffixes in its set
int fa[MAXN];                // union-find parent, indexed by rank

// Union-find root lookup with path compression. It is iterative, so long
// parent chains cannot overflow the call stack.
int getfa(int v) {
    int root = v;
    while (fa[root] != root) root = fa[root];
    while (fa[v] != root) {
        const int next = fa[v];
        fa[v] = root;
        v = next;
    }
    return root;
}

// Returns the sum over (suffix of a, suffix of b) pairs of
// max(0, LCP - max(k, 1) + 1). Requires strlen(a) + strlen(b) + 2 <= MAXN.
long long countCommonSubstrings(const char *a, const char *b, int k) {
    // Height-0 edges must never be merged, so lengths below 1 are not counted.
    const int minLen = std::max(k, 1);
    const int len1 = static_cast<int>(std::strlen(a));
    const int len2 = static_cast<int>(std::strlen(b));

    // Bytes go through unsigned char and are shifted by 2. This keeps them
    // from colliding with the terminal (0) or separator (1), and keeps them
    // from indexing ws[] negatively. The relative order of bytes is preserved.
    int h = 0;
    for (int i = 0; i < len1; ++i) {
        id[h] = 0;
        s[h++] = static_cast<unsigned char>(a[i]) + 2;
    }
    id[h] = 2;
    s[h++] = 1;
    for (int i = 0; i < len2; ++i) {
        id[h] = 1;
        s[h++] = static_cast<unsigned char>(b[i]) + 2;
    }
    id[h] = 2;  // position h holds the terminal appended by SuffixArray::input
    SA.input(s, h, ALPHABET);

    for (int i = 0; i <= h; ++i) e[i].clear();
    for (int i = 1; i <= h; ++i) e[SA.height[i]].push_back(i);

    // Union-find nodes are ranks, so the owning string is looked up through
    // sa[]. Separator and terminal suffixes count for neither side.
    for (int i = 0; i <= h; ++i) {
        fa[i] = i;
        const int owner = id[SA.sa[i]];
        cntA[i] = (owner == 0) ? 1 : 0;
        cntB[i] = (owner == 1) ? 1 : 0;
    }

    long long ans = 0;
    for (int len = h; len >= minLen; --len) {
        const long long weight = len - minLen + 1;
        const std::vector<int> &ranks = e[len];
        for (std::size_t t = 0; t < ranks.size(); ++t) {
            const int u = ranks[t];
            const int ru = getfa(u);
            const int rv = getfa(u - 1);
            if (ru == rv) continue;
            // Cross pairs that become connected by this merge have LCP exactly
            // `len`, so each contributes len - minLen + 1 substrings.
            const long long newPairs =
                static_cast<long long>(cntA[ru]) * cntB[rv] +
                static_cast<long long>(cntB[ru]) * cntA[rv];
            ans += newPairs * weight;
            fa[ru] = rv;
            cntA[rv] += cntA[ru];
            cntB[rv] += cntB[ru];
        }
    }
    return ans;
}

}  // namespace

int main() {
    int k;
    // Stop on k == 0, or when input ends or is malformed. scanf returning EOF
    // (-1) must not be treated as success.
    while (std::scanf("%d", &k) == 1 && k != 0) {
        if (std::scanf("%s %s", s1, s2) != 2) break;
        // iostream prints long long portably, unlike the MSVC-only "%I64d".
        std::cout << countCommonSubstrings(s1, s2, k) << '\n';
    }
    return 0;
}