【bzoj3473】字符串 【后缀自动机+树状数组】
题意:给定n个字符串,询问每个字符串有多少子串(不包括空串)是所有n个字符串中至少k个字符串的子串?(本质相同重复计算)
题解:首先我们把这n个字符串的广义后缀自动机建立出来,然后处理出每个状态出现在n个串的多少个之中。接着把每个串在后缀自动机跑一遍,统计即可。
如何处理出每个状态出现在n个串的多少个之中?
如果一个状态x出现在某个串y中,那么fail[x]一定也出现在y中,因为fail[x]是x的一个后缀。设val[i]代表状态i来自于哪个串。所以如果我们把fail链倒过来建一棵树,状态x出现在n个串之中的个数就是x的子树中的val的不同个数。我们就可以把这棵树dfs一次,处理出每个节点的dfs序区间,就把这个问题转化为了查询一个区间有多少个不同的数字,跟HH的项链那题一模一样。用树状数组处理一下就好了。
如何统计?
设ans[x]表示x出现在n个串之中的个数。只需要对每个串在后缀自动机上走,如果ans[now]小于k的话now就不停地跳fail。这就相当于把当前匹配到的不停地截短。然后答案累加上len[now]即可。至于为什么,请读者自行思考。而且匹配的过程中,不会失配,这也很显然。
时间复杂度: n log n
代码实现:
#include<cstdio>
#include<algorithm>
#include<string>
#include<cstring>
#include<vector>
using namespace std;
const int N=100005;
int n,k,l;
string s[N];
char str[N];
bool cmp(int a,int b);
struct SAM{
int last,tot,len[N*2],fail[N*2],val[N*2],ch[N*2][26],c[N*2],a[N*2];
int idx,in[N*2],out[N*2],pos[N*2],nxt[N*2],ck[N*2],ans[N*2];
vector<int> e[N*2];
SAM(){
last=tot=1;
}
void insert(int x,int id){
int p=last,np=++tot;
len[np]=len[p]+1;
last=np;
val[np]=id;
for(;p&&!ch[p][x];p=fail[p]){
ch[p][x]=np;
}
if(!p){
fail[np]=1;
}else{
int q=ch[p][x];
if(len[q]==len[p]+1){
fail[np]=q;
}else{
int nq=++tot;
len[nq]=len[p]+1;
memcpy(ch[nq],ch[q],sizeof(ch[q]));
fail[nq]=fail[q];
fail[q]=fail[np]=nq;
for(;p&&ch[p][x]==q;p=fail[p]){
ch[p][x]=nq;
}
}
}
}
void dfs(int u){
in[u]=++idx;
pos[idx]=u;
for(int i=0;i<e[u].size();i++){
dfs(e[u][i]);
}
out[u]=idx;
}
int lowbit(int x){
return x&(-x);
}
void add(int i){
while(i<=tot){
c[i]++;
i+=lowbit(i);
}
}
int sum(int i){
int res=0;
while(i){
res+=c[i];
i-=lowbit(i);
}
return res;
}
void build(){
for(int i=2;i<=tot;i++){
e[fail[i]].push_back(i);
}
dfs(1);
for(int i=1;i<=tot;i++){
a[i]=i;
}
sort(a+1,a+tot+1,cmp);
for(int i=tot;i>=1;i--){
if(val[pos[i]]){
nxt[i]=ck[val[pos[i]]];
ck[val[pos[i]]]=i;
}
}
for(int i=1;i<=tot;i++){
if(ck[i]){
add(ck[i]);
}
}
for(int i=1,j=1;i<=tot;i++){
while(j<in[a[i]]){
if(nxt[j]){
add(nxt[j]);
}
j++;
}
ans[a[i]]=sum(out[a[i]])-sum(in[a[i]]-1);
}
}
long long query(const char *s,int l){
long long res=0;
int now=1;
for(int i=0;i<l;i++){
now=ch[now][s[i]-'a'];
while(now&&ans[now]<k){
now=fail[now];
}
if(!now){
now=1;
continue;
}
res+=len[now];
}
return res;
}
}sam;
bool cmp(int a,int b){
return sam.in[a]==sam.in[b]?sam.out[a]<sam.out[b]:sam.in[a]<sam.in[b];
}
int main(){
scanf("%d%d",&n,&k);
for(int i=1;i<=n;i++){
scanf("%s",str);
s[i]=str;
l=strlen(str);
sam.last=1;
for(int j=0;j<l;j++){
sam.insert(str[j]-'a',i);
}
}
sam.build();
for(int i=1;i<=n;i++){
printf("%lld ",sam.query(s[i].c_str(),s[i].size()));
}
puts("");
return 0;
}