【文文殿下】[BZOJ3277] 串
Description
字符串是oi界常考的问题。现在给定你n个字符串,询问每个字符串有多少子串(不包括空串)是所有n个字符串中
至少k个字符串的子串(注意包括本身)
Input
第一行两个整数n,k。
接下来n行每行一个字符串。
n,k,l<=100000
Output
输出一行n个整数,第i个整数表示第i个字符串的答案。
Sample Input
3 1
abc
a
ab
Sample Output
6 1 3
题解
多个字符串,考虑建广义后缀自动机。
对于每个节点,记录它在每个字符串出现的次数。
但是为了防止重复记录(一个字符串在他的SAM上可能多次匹配到同一个点),我们对每个节点记录一个“上一次统计到的是哪个字符串”,这样子进行有序增加。
当这个字符串跑到某一个节点时,我们把从这个节点,沿着parent树直到根部全部增加。因为一个串在匹配的时候,有可能绕过了他的parent直接匹配到该节点。
然后,对于每一个串,再跑一边。当他的右端点匹配到某个状态以后,立即进行“清算”,把它能够匹配的后缀全部统计。即:沿着Parent树全部加入答案。
但是这样是n^2的,我们预处理前缀和可以做到O(n)
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<string>
typedef long long ll;
const int maxn = 2e5+20;
ll ans=0;
int tmp[maxn];
int par[maxn],mx[maxn],tr[maxn][26];
char A[maxn>>1];
int f[maxn];
int cnt = 1,last = 1,Right[maxn];
int c[maxn],id[maxn];
int pre[maxn];
std::string S[maxn];
int n,k;
void extend(int x) {
int p = last;
if(tr[p][x]&&mx[tr[p][x]]==mx[p]+1) {last=tr[p][x];return;}
int np=++cnt;
mx[np]=mx[p]+1;
while(p&&!tr[p][x]) tr[p][x]=np,p=par[p];
if(!p) par[np]=1;
else {
int q = tr[p][x];
if(mx[q]==mx[p]+1) par[np]=q;
else {
int nq = ++cnt;
mx[nq]=mx[p]+1;
memcpy(tr[nq],tr[q],sizeof tr[q]);
par[nq]=par[q];
par[q]=par[np]=nq;
while(p&&tr[p][x]==q) tr[p][x]=nq,p=par[p];
}
}
last = np;
return;
}
inline void topsort() {
for(int i = 1;i<=cnt;++i) ++c[mx[i]];
for(int i = 1;i<=cnt;++i) c[i]+=c[i-1];
for(int i = 1;i<=cnt;++i) id[c[mx[i]]--]=i;
return;
}
int main() {
scanf("%d%d",&n,&k);
for(int i = 1;i<=n;++i) {
scanf("%s",A);
S[i]=std::string(A);
int len = S[i].length();
last = 1;
for(int j = 0;j<len;++j) extend(S[i][j]-'a');
}
for(int i = 1;i<=n;++i) {
int len = S[i].length();
int cur = 1;
for(int j = 0;j<len;++j) {
int c = S[i][j]-'a';
while(cur&&!tr[cur][c]) cur=par[cur];
if(!cur) cur=1;
else cur = tr[cur][c];
int tmp = cur;
if(tmp&&pre[tmp]!=i) pre[tmp]=i,++Right[tmp],tmp=par[tmp];
}
}
topsort();
//Right[1]=0;
for(int i = 1;i<=cnt;++i) f[id[i]]=f[par[id[i]]]+(Right[id[i]]>=k?mx[id[i]]-mx[par[id[i]]]:0);
for(int i = 1;i<=n;++i) {
int len = S[i].length();
int cur=1,L=0;
ll ans = 0;
for(int j = 0;j<len;++j) {
int c = S[i][j]-'a';
while(cur&&!tr[cur][c]) cur=par[cur];
if(!cur) cur=1,L=0;
else L=std::min(L,mx[cur])+1,cur = tr[cur][c];
if(L==mx[cur]||Right[cur]<k)
ans+=f[cur];
else {
ans+=f[cur];
ans-=mx[cur];
ans+=L;
}
}
printf("%lld ",ans);
}
return 0;
}