BZOJ3473: 字符串
3473: 字符串
Time Limit: 20 Sec Memory Limit: 256 MBSubmit: 109 Solved: 47
[Submit][Status]
Description
给定n个字符串,询问每个字符串有多少子串(不包括空串)是所有n个字符串中至少k个字符串的子串?
Input
第一行两个整数n,k。
接下来n行每行一个字符串。
Output
一行n个整数,第i个整数表示第i个字符串的答案。
Sample Input
3 1
abc
a
ab
abc
a
ab
Sample Output
6 1 3
HINT
对于 100% 的数据,1<=n,k<=10^5,所有字符串总长不超过10^5,字符串只包含小写字母。
Source
题解:
神题一道。。。
继续搬运题解:by云神
首先将所有字符串串在一次做SA,然后我们对于sa上,枚举每个串的每个后缀,求出有几个该后缀的前缀符合条件,那么就要判定区间里面有多少个不同的数,所幸的是这里只需要求是否该数目>=k,所以对于每个位置记录个L(x),表示[L(x),x]中刚好有k个不同的数,且L(x)最大(参考了CF官方题解),然后CF上的题解是对于每个后缀二分出长度,然后是O(n log^2 n的算法),但是O(n log^2 n)在本题仍然会TLE,那么我们发现枚举后缀的时候,如果后缀c+S有n个前缀合法(c表示一个字符,s表示一个串),那么对于后缀S,至少有n-1个前缀合法(如果c+S有n个前缀出现不小于k次,那么其子串也是),那么我们就用类似求SA里的height一样的方法,记录一下前面的后缀的合法前缀数,然后这样的总复杂度就成了均摊O(n log n),可以AC。
一些注释写在代码里
代码:
1 #include<cstdio> 2 #include<cstdlib> 3 #include<cmath> 4 #include<cstring> 5 #include<algorithm> 6 #include<iostream> 7 #include<vector> 8 #include<map> 9 #include<set> 10 #include<queue> 11 #include<string> 12 #define inf 1000000000 13 #define maxn 250000+5 14 #define maxm 500+100 15 #define eps 1e-10 16 #define pa pair<int,int> 17 #define for0(i,n) for(int i=0;i<=(n);i++) 18 #define for1(i,n) for(int i=1;i<=(n);i++) 19 #define for2(i,x,y) for(int i=(x);i<=(y);i++) 20 #define for3(i,x,y) for(int i=(x);i>=(y);i--) 21 #define mod 1000000007 22 using namespace std; 23 inline int read() 24 { 25 int x=0,f=1;char ch=getchar(); 26 while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();} 27 while(ch>='0'&&ch<='9'){x=10*x+ch-'0';ch=getchar();} 28 return x*f; 29 } 30 int n,m,q,c[maxn],t1[maxn],t2[maxn],sa[maxn],rk[maxn],h[maxn]; 31 int st[maxn][20],rec[maxn],cnt[maxn],num[maxn],beg[maxn],end[maxn]; 32 char s[maxn]; 33 void getsa(int m) 34 { 35 int *x=t1,*y=t2; 36 for0(i,m)c[i]=0; 37 for0(i,n)c[x[i]=s[i]]++; 38 for1(i,m)c[i]+=c[i-1]; 39 for3(i,n,0)sa[--c[x[i]]]=i; 40 for(int k=1;k<=n+1;k<<=1) 41 { 42 int p=0; 43 for2(i,n-k+1,n)y[p++]=i; 44 for0(i,n)if(sa[i]>=k)y[p++]=sa[i]-k; 45 for0(i,m)c[i]=0; 46 for0(i,n)c[x[y[i]]]++; 47 for1(i,m)c[i]+=c[i-1]; 48 for3(i,n,0)sa[--c[x[y[i]]]]=y[i]; 49 swap(x,y);p=0;x[sa[0]]=0; 50 for1(i,n)x[sa[i]]=y[sa[i]]==y[sa[i-1]]&&y[sa[i]+k]==y[sa[i-1]+k]?p:++p; 51 if(p>=n)break; 52 m=p; 53 } 54 for1(i,n)rk[sa[i]]=i; 55 for(int i=0,k=0,j;i<n;h[rk[i++]]=k) 56 for(k?k--:0,j=sa[rk[i]-1];s[i+k]==s[j+k];k++); 57 } 58 void getst() 59 { 60 for1(i,n)st[i][0]=h[i]; 61 int k=log2(n); 62 for1(i,k)for1(j,n-(1<<i)+1)st[j][i]=min(st[j][i-1],st[j+(1<<(i-1))][i-1]); 63 } 64 inline int rmq(int x,int y)//求rmq 65 { 66 int k=log2(y-x+1); 67 return min(st[x][k],st[y-(1<<k)+1][k]); 68 } 69 inline bool check(int x,int y)//二分出S[x...x+y-1]在整个串中的左右端点 70 { 71 int l,r,mid,ll,rr; 72 if(h[x+1]<y)rr=x; 73 else 74 { 75 l=x+1;r=n; 76 while(l<=r) 77 { 78 mid=(l+r)>>1; 79 if(rmq(x+1,mid)>=y)l=mid+1;else r=mid-1; 80 } 81 rr=r; 82 } 83 if(h[x]<y)ll=x; 84 else 85 { 86 l=1;r=x-1; 87 while(l<=r) 88 { 89 mid=(l+r)>>1; 90 //if(x==37)cout<<l<<' '<<mid<<' '<<r<<' '<<rmq(mid+1,x)<<endl; 91 if(rmq(mid+1,x)>=y)r=mid-1;else l=mid+1; 92 } 93 ll=l; 94 } 95 //if(x==37)cout<<y<<' '<<ll<<' '<<rr<<' '<<rec[rr]<<' '<<ll<<endl; 96 return rec[rr]>=ll;//判断这个范围内是否有k个不同的num值,即出现在不同的k个串中 97 } 98 int main() 99 { 100 freopen("input.txt","r",stdin); 101 freopen("output.txt","w",stdout); 102 m=read();q=read();n=-1; 103 for1(i,m) 104 { 105 n++;beg[i]=n; 106 scanf("%s",s+n); 107 n=strlen(s);s[n]=' ';end[i]=n-1; 108 } 109 //printf("%s\n",s); 110 getsa(128); 111 getst(); 112 for1(i,m)for2(j,beg[i],end[i])num[j]=i;//标记所属 113 int t=1,k=0; 114 for1(i,n)if(num[sa[i]])//不能是空字符 115 { 116 if(!cnt[num[sa[i]]])k++; 117 cnt[num[sa[i]]]++; 118 if(k>=q) 119 { 120 for(;k-(cnt[num[sa[t]]]==1)>=q;k-=(cnt[num[sa[t]]]==1),--cnt[num[sa[t++]]]); 121 rec[i]=t; 122 } 123 } 124 /*for1(i,n) 125 { 126 cout<<i<<' '<<h[i]<<' '; 127 for2(j,sa[i],n)cout<<s[j]; 128 cout<<endl; 129 }*/ 130 for1(i,m) 131 { 132 long long ans=0;int k=0; 133 for2(j,beg[i],end[i]) 134 { 135 for(k?k--:0;k+1<=end[i]-j+1&&check(rk[j],k+1);k++);//类似于height数组的求法 136 //if(i==1&&s[j]=='b')cout<<"AAAAAAA"<<' '<<k<<' '<<rk[j]<<endl; 137 ans+=(long long)k; 138 } 139 printf("%lld",ans); 140 if(i!=m)printf(" "); 141 } 142 return 0; 143 }