bzoj3473: 字符串 && bzoj3277串
3473: 字符串
Time Limit: 20 Sec Memory Limit: 256 MBSubmit: 121 Solved: 53
[Submit][Status][Discuss]
Description
给定n个字符串,询问每个字符串有多少子串(不包括空串)是所有n个字符串中至少k个字符串的子串?
Input
第一行两个整数n,k。
接下来n行每行一个字符串。
Output
一行n个整数,第i个整数表示第i个字符串的答案。
Sample Input
3 1
abc
a
ab
abc
a
ab
Sample Output
6 1 3
HINT
对于 100% 的数据,1<=n,k<=10^5,所有字符串总长不超过10^5,字符串只包含小写字母。
Source
很久之前做的题今天一看竟然不会做了。。。于是补篇题解。
首先把所有串连起来做一遍SA,求出hight,然后在后缀数组上从前往后扫。
那么现在要求的就是当前这个后缀有多少前缀是至少k个串的子串,这些前缀一定是连续的一段,因为如果Sx出现了k次,那么S也一定出现了k次。
设当前位是i,我们现在拥有后缀数组上一位的答案lastans,那么把它与hight[i]取一个min得到x,那么这位的答案至少是x。
然后考虑这位新出现的子串,那些包含这些子串的位置一定在i下面,那么维护一个指针使当前区间内刚好包含k个不同串的任意一个后缀,当i++时指针往后扫。
那么指针的位置与i用ST表求个区间RMQ,用x与这个区间最小值取max就是当前位的答案。
复杂度$nlogn$
感觉后缀自动机的做法很不科学
#include<iostream> #include<cstdio> #include<cstring> #include<algorithm> #define N 300005 #define ll long long using namespace std; int n,k; int sa[N],rank[N],wb[N],sum[N],c[N]; void getsa(int n,int m) { int *x=rank,*y=wb; for(int i=0;i<m;i++)sum[i]=0; for(int i=0;i<n;i++)sum[x[i]=c[i]]++; for(int i=1;i<m;i++)sum[i]+=sum[i-1]; for(int i=n-1;i>=0;i--)sa[--sum[x[i]]]=i; int p=1; for(int j=1;p<n;j<<=1,m=p) { p=0; for(int i=n-j;i<n;i++)y[p++]=i; for(int i=0;i<n;i++)if(sa[i]>=j)y[p++]=sa[i]-j; for(int i=0;i<m;i++)sum[i]=0; for(int i=0;i<n;i++)sum[x[i]]++; for(int i=1;i<m;i++)sum[i]+=sum[i-1]; for(int i=n-1;i>=0;i--)sa[--sum[x[y[i]]]]=y[i]; swap(x,y);x[sa[0]]=0;p=1; for(int i=1;i<n;i++) x[sa[i]]=y[sa[i]]==y[sa[i-1]]&&y[sa[i]+j]==y[sa[i-1]+j]?p-1:p++; } } int h[N]; void calh(int n) { for(int i=1;i<=n;i++)rank[sa[i]]=i; int kk=0; for(int i=0;i<n;i++) { if(kk)kk--; int j=sa[rank[i]-1]; while(c[i+kk]==c[j+kk])kk++; h[rank[i]]=kk; } return ; } int mn[N][20],lg[N]; void ST() { lg[0]=-1; for(int i=1;i<=n;i++)lg[i]=lg[i>>1]+1; for(int i=1;i<=n;i++)mn[i][0]=h[i]; for(int i=1;i<=19;i++) { for(int j=1;j<=n;j++) { if(j+(1<<(i-1))<=n)mn[j][i]=min(mn[j][i-1],mn[j+(1<<(i-1))][i-1]); else mn[j][i]=mn[j][i-1]; } }return ; } int qur(int l,int r) { int k=lg[r-l+1]; return min(mn[l][k],mn[r-(1<<k)+1][k]); } int be[N]; ll ans[N]; int len[N],sz[N]; int now[N],nw; void solve() { int l=0;int tmp=0; for(int i=1;i<=n;i++) { if(i!=1&&be[sa[i-1]]!=0) { now[be[sa[i-1]]]--; if(!now[be[sa[i-1]]])nw--; } while(l!=n&&nw<k) { l++; if(be[sa[l]]!=0) { now[be[sa[l]]]++; if(now[be[sa[l]]]==1)nw++; } } tmp=min(tmp,h[i]); if(nw==k) { if(be[sa[i]]) { int num; if(l!=i)num=qur(i+1,l); else num=sz[sa[i]]; tmp=max(tmp,num); } } ans[be[sa[i]]]+=tmp; } return ; } char s[N]; int main() { int cnt; scanf("%d%d",&cnt,&k); int m=256;n=-1; for(int i=1;i<=cnt;i++) { scanf("%s",s+1);len[i]=strlen(s+1); for(int j=1;j<=len[i];j++) { c[++n]=s[j]; be[n]=i; sz[n]=len[i]-j+1; } if(i!=cnt)c[++n]=m++; }n++; getsa(n+1,m);calh(n); ST(); solve(); for(int i=1;i<cnt;i++)printf("%lld ",ans[i]); printf("%lld",ans[cnt]); return 0; }