[学习笔记]后缀自动机SAM

好抽象啊,早上看了两个多小时才看懂,\(\%\%\%Fading\) 早就懂了

讲解就算了吧……可以去看看其他人的博客

1、【模板】后缀自动机

\(siz\) 为该串出现的次数,\(l\) 为子串长度,每次乘一下就好了

\(Code\ Below:\)

#include <bits/stdc++.h>
#define ll long long
using namespace std;
const int maxn=2000000+10;
int n,a[maxn],c[maxn],last,cnt,ch[maxn][26],fa[maxn],l[maxn],siz[maxn];
char s[maxn];ll ans;

void insert(int c){
	int p=last,q=++cnt;last=q;l[q]=l[p]+1;
	for(;p&&!ch[p][c];p=fa[p]) ch[p][c]=q;
	if(!p) fa[q]=1;
	else {
		int r=ch[p][c];
		if(l[p]+1==l[r]) fa[q]=r;
		else {
			int s=++cnt;l[s]=l[p]+1;
			memcpy(ch[s],ch[r],sizeof(ch[r]));
			fa[s]=fa[r];fa[r]=fa[q]=s;
			for(;ch[p][c]==r;p=fa[p]) ch[p][c]=s;
		}
	}
	siz[q]=1;
}

int main()
{
	scanf("%s",s+1);n=strlen(s+1);last=cnt=1;
	for(int i=1;i<=n;i++) insert(s[i]-'a');
	for(int i=1;i<=cnt;i++) c[l[i]]++;
	for(int i=1;i<=cnt;i++) c[i]+=c[i-1];
	for(int i=1;i<=cnt;i++) a[c[l[i]]--]=i;
	for(int i=cnt;i>=1;i--){
		siz[fa[a[i]]]+=siz[a[i]];
		if(siz[a[i]]>1) ans=max(ans,1ll*siz[a[i]]*l[a[i]]);
	}
	printf("%lld\n",ans);
	return 0;
}

题意简述:\([TJOI2015]\) 弦论找第 \(k\) 小子串 \(t=0\) 的弱化版

直接像权值线段树找第 \(k\) 小就好了

\(Code\ Below:\)

#include <bits/stdc++.h>
using namespace std;
const int maxn=200000+10;
int n,q,a[maxn],c[maxn],last,cnt,ch[maxn][26],fa[maxn],l[maxn],siz[maxn];
char s[maxn];

void insert(int c){
	int p=last,q=++cnt;last=q;l[q]=l[p]+1;
	for(;p&&!ch[p][c];p=fa[p]) ch[p][c]=q;
	if(!p) fa[q]=1;
	else {
		int r=ch[p][c];
		if(l[p]+1==l[r]) fa[q]=r;
		else {
			int s=++cnt;l[s]=l[p]+1;
			memcpy(ch[s],ch[r],sizeof(ch[r]));
			fa[s]=fa[r];fa[r]=fa[q]=s;
			for(;ch[p][c]==r;p=fa[p]) ch[p][c]=s;
		}
	}
}

int main()
{
	scanf("%s",s+1);n=strlen(s+1);last=cnt=1;
	for(int i=1;i<=n;i++) insert(s[i]-'a');
	for(int i=1;i<=cnt;i++) c[l[i]]++;
	for(int i=1;i<=cnt;i++) c[i]+=c[i-1];
	for(int i=1;i<=cnt;i++) a[c[l[i]]--]=i;
	for(int i=cnt;i>=1;i--){
		siz[a[i]]=1;
		for(int j=0;j<26;j++)
			if(ch[a[i]][j]) siz[a[i]]+=siz[ch[a[i]][j]];
	}
	scanf("%d",&q);
	int k,p;
	while(q--){
		scanf("%d",&k);p=1;
		while(k){
			for(int i=0;i<26;i++){
				if(!ch[p][i]) continue;
				if(siz[ch[p][i]]>=k){
					putchar(i+'a');
					k--;p=ch[p][i];break;
				}
				else k-=siz[ch[p][i]];
			}
		}
		putchar('\n');
	}
	return 0;
}

3、[TJOI2015]弦论

\(t=1\) 的时候就是像模板一样记录一下 \(siz\),拓扑一遍

\(Code\ Below:\)

#include <bits/stdc++.h>
using namespace std;
const int maxn=1000000+10;
int n,q,t,a[maxn],c[maxn],last,cnt,ch[maxn][26],fa[maxn],l[maxn],siz[maxn],sum[maxn];
char s[maxn];

void insert(int c){
    int p=last,q=++cnt;last=q;l[q]=l[p]+1;
    for(;p&&!ch[p][c];p=fa[p]) ch[p][c]=q;
    if(!p) fa[q]=1;
    else {
        int r=ch[p][c];
        if(l[p]+1==l[r]) fa[q]=r;
        else {
            int s=++cnt;l[s]=l[p]+1;
            memcpy(ch[s],ch[r],sizeof(ch[r]));
            fa[s]=fa[r];fa[r]=fa[q]=s;
            for(;ch[p][c]==r;p=fa[p]) ch[p][c]=s;
        }
    }
    siz[q]=1;
}

int main()
{
    scanf("%s",s+1);n=strlen(s+1);last=cnt=1;
    for(int i=1;i<=n;i++) insert(s[i]-'a');
    for(int i=1;i<=cnt;i++) c[l[i]]++;
    for(int i=1;i<=cnt;i++) c[i]+=c[i-1];
    for(int i=1;i<=cnt;i++) a[c[l[i]]--]=i;
    int k,p=1;
    scanf("%d%d",&t,&k);
    for(int i=cnt;i>=1;i--){
        if(t) siz[fa[a[i]]]+=siz[a[i]];
        else siz[a[i]]=1;
    }
    siz[1]=0;
    for(int i=cnt;i>=1;i--){
        sum[a[i]]=siz[a[i]];
        for(int j=0;j<26;j++)
            if(ch[a[i]][j]) sum[a[i]]+=sum[ch[a[i]][j]];
    }
    if(k>sum[1]){
        printf("-1\n");
        return 0;
    }
    while(k){
        for(int i=0;i<26;i++){
            if(!ch[p][i]) continue;
            if(sum[ch[p][i]]>=k){
                putchar(i+'a');
                k-=siz[ch[p][i]];p=ch[p][i];break;
            }
            else k-=sum[ch[p][i]];
        }
    }
    putchar('\n');
    return 0;
}

4、SP1811 LCS - Longest Common Substring

题意:找两个串的最长公共子串

类似 \(kmp\) 一样匹配两个串,失配的话一直向上跳

\(Code\ Below:\)

#include <bits/stdc++.h>
using namespace std;
const int maxn=500000+10;
int n,a[maxn],c[maxn],last,cnt,ch[maxn][26],fa[maxn],l[maxn],ans;
char s[maxn];

void insert(int c){
	int p=last,q=++cnt;last=q;l[q]=l[p]+1;
	for(;p&&!ch[p][c];p=fa[p]) ch[p][c]=q;
	if(!p) fa[q]=1;
	else {
		int r=ch[p][c];
		if(l[p]+1==l[r]) fa[q]=r;
		else {
			int s=++cnt;l[s]=l[p]+1;
			memcpy(ch[s],ch[r],sizeof(ch[r]));
			fa[s]=fa[r];fa[r]=fa[q]=s;
			for(;ch[p][c]==r;p=fa[p]) ch[p][c]=s; 
		}
	}
}

int main()
{
	scanf("%s",s+1);n=strlen(s+1);last=cnt=1;
	for(int i=1;i<=n;i++) insert(s[i]-'a');
	scanf("%s",s+1);n=strlen(s+1);
	int p=1,c,len=0;
	for(int i=1;i<=n;i++){
		c=s[i]-'a';
		if(ch[p][c]) len++,p=ch[p][c];
		else {
			for(;p&&!ch[p][c];p=fa[p]);
			if(p) len=l[p]+1,p=ch[p][c];
			else len=0,p=1;
		}
		ans=max(ans,len);
	}
	printf("%d\n",ans);
	return 0;
}

5、SP1812 LCS2 - Longest Common Substring II

题意:找多个串的最长公共子串

类似双串 \(LCS\) 一样,就是在匹配的时候记录下来长度,先取个 \(min\),再在 \(min\) 中取个 \(max\)

\(Code\ Below:\)

#include <bits/stdc++.h>
using namespace std;
const int maxn=200000+10;
int n,a[maxn],c[maxn],last,cnt,ch[maxn][26],fa[maxn],l[maxn],Max[maxn],Min[maxn],ans;
char s[maxn];

void insert(int c){
	int p=last,q=++cnt;last=q;l[q]=l[p]+1;
	for(;p&&!ch[p][c];p=fa[p]) ch[p][c]=q;
	if(!p) fa[q]=1;
	else {
		int r=ch[p][c];
		if(l[p]+1==l[r]) fa[q]=r;
		else {
			int s=++cnt;l[s]=l[p]+1;
			memcpy(ch[s],ch[r],sizeof(ch[r]));
			fa[s]=fa[r];fa[r]=fa[q]=s;
			for(;ch[p][c]==r;p=fa[p]) ch[p][c]=s;
		}
	}
}

int main()
{
	scanf("%s",s+1);n=strlen(s+1);last=cnt=1;
	for(int i=1;i<=n;i++) insert(s[i]-'a');
	for(int i=1;i<=cnt;i++) c[l[i]]++;
	for(int i=1;i<=cnt;i++) c[i]+=c[i-1];
	for(int i=1;i<=cnt;i++) a[c[l[i]]--]=i;
	for(int i=1;i<=cnt;i++) Min[i]=l[i];
	int len,p,c;
	while(scanf("%s",s+1)!=EOF){
		len=0;p=1;n=strlen(s+1);
		for(int i=1;i<=n;i++){
			c=s[i]-'a';
			if(ch[p][c]) p=ch[p][c],Max[p]=max(Max[p],++len);
			else {
				for(;p&&!ch[p][c];p=fa[p]);
				if(p) len=l[p],p=ch[p][c],Max[p]=max(Max[p],++len);
				else p=1,len=0;
			}
		}
		for(int i=cnt;i>=1;i--){
			Min[a[i]]=min(Min[a[i]],Max[a[i]]);
			if(Max[a[i]]&&fa[a[i]]) Max[fa[a[i]]]=l[fa[a[i]]];
			Max[a[i]]=0;
		}
	}
	for(int i=2;i<=cnt;i++) ans=max(ans,Min[i]);
	printf("%d\n",ans);
	return 0;
}

6、[ZJOI2015]诸神眷顾的幻想乡

题意:问多少条本质不同路径(可以翻转)

边搜边建。因为叶子结点只有 \(20\) 个,每次就以度数为 \(1\) 的点开始搜索,建一个广义后缀自动机。

#include <bits/stdc++.h>
#define ll long long
using namespace std;
const int maxn=200000+10;
int n,m,a[maxn],in[maxn],ch[maxn*20][26],fa[maxn*20],l[maxn*20],cnt;
int head[maxn],to[maxn<<1],nxt[maxn<<1],tot;ll ans;

inline int read(){
	register int x=0,f=1;char ch=getchar();
	while(!isdigit(ch)){if(ch=='-')f=-1;ch=getchar();}
	while(isdigit(ch)){x=(x<<3)+(x<<1)+ch-'0';ch=getchar();}
	return (f==1)?x:-x;
}
inline void add(int x,int y){
	to[++tot]=y;
	nxt[tot]=head[x];
	head[x]=tot;
}

int insert(int c,int p){
	int q=++cnt;l[q]=l[p]+1;
	for(;p&&!ch[p][c];p=fa[p]) ch[p][c]=q;
	if(!p) fa[q]=1;
	else {
		int r=ch[p][c];
		if(l[p]+1==l[r]) fa[q]=r;
		else {
			int s=++cnt;l[s]=l[p]+1;
			memcpy(ch[s],ch[r],sizeof(ch[r]));
			fa[s]=fa[r];fa[r]=fa[q]=s;
			for(;p&&ch[p][c]==r;p=fa[p]) ch[p][c]=s;
		}
	}
	return q;
}

void dfs(int x,int f,int p){
	p=insert(a[x],p);
	for(int i=head[x],y;i;i=nxt[i]){
		y=to[i];
		if(y==f) continue;
		dfs(y,x,p);
	}
}

int main()
{
	n=read(),m=read();cnt=1;
	for(int i=1;i<=n;i++) a[i]=read();
	int x,y;
	for(int i=1;i<n;i++){
		x=read(),y=read();
		add(x,y);add(y,x);
		in[x]++;in[y]++;
	}
	for(int i=1;i<=n;i++)
		if(in[i]==1) dfs(i,0,1);
	for(int i=1;i<=cnt;i++)
		ans+=l[i]-l[fa[i]];
	printf("%lld\n",ans);
	return 0;
}

7、[HEOI2016/TJOI2016]字符串

题意:\(q\) 次询问,问在 \([a,b]\) 的子串与 \([c,d]\) 的最长前缀长度

前缀不好处理,那么就干脆翻转一下好了。

最长前缀不好处理,那么就干脆二分答案一下好了。

二分后转化为存在性问题,那么就干脆套个线段树合并好了。

建一个后缀自动机,\(l\) 的长度就倍增一下。

时间复杂度 \(O(n\log^2 n)\)

\(Code\ Below:\)

#include <bits/stdc++.h>
using namespace std;
const int maxn=500000+10;
int n,m,a[maxn],b[maxn],last,cnt,ch[maxn][26],fa[maxn],l[maxn];
int pos[maxn],f[maxn][21],T[maxn],L[maxn*60],R[maxn*60],sum[maxn*60],tot;
char s[maxn];

inline int read(){
	register int x=0,f=1;char ch=getchar();
	while(!isdigit(ch)){if(ch=='-')f=-1;ch=getchar();}
	while(isdigit(ch)){x=(x<<3)+(x<<1)+ch-'0';ch=getchar();}
	return (f==1)?x:-x;
}

void insert(int c){
	int p=last,q=++cnt;last=q;l[q]=l[p]+1;
	for(;p&&!ch[p][c];p=fa[p]) ch[p][c]=q;
	if(!p) fa[q]=1;
	else {
		int r=ch[p][c];
		if(l[p]+1==l[r]) fa[q]=r;
		else {
			int s=++cnt;l[s]=l[p]+1;
			memcpy(ch[s],ch[r],sizeof(ch[r]));
			fa[s]=fa[r];fa[r]=fa[q]=s;
			for(;p&&ch[p][c]==r;p=fa[p]) ch[p][c]=s;
		}
	}
}

void pushup(int now){
	sum[now]=sum[L[now]]+sum[R[now]];
}

void update(int &now,int l,int r,int x){
	if(!now) now=++tot;
	if(l == r){sum[now]++;return ;}
	int mid=(l+r)>>1;
	if(x <= mid) update(L[now],l,mid,x);
	else update(R[now],mid+1,r,x);
	pushup(now);
}

int merge(int x,int y,int l,int r){
	if(!x||!y) return x+y;
	if(l == r){sum[x]+=sum[y];return x;}
	int mid=(l+r)>>1,z=++tot;
	L[z]=merge(L[x],L[y],l,mid);
	R[z]=merge(R[x],R[y],mid+1,r);
	pushup(z);
	return z;
}

int query(int now,int Le,int Ri,int l,int r){
	if(!now) return 0;
	if(Le <= l && r <= Ri) return sum[now];
	int mid=(l+r)>>1,ans=0;
	if(Le <= mid) ans+=query(L[now],Le,Ri,l,mid);
	if(Ri > mid) ans+=query(R[now],Le,Ri,mid+1,r);
	return ans;
}

int check(int len,int x,int L,int R){
	for(int i=20;i>=0;i--)
		if(f[x][i]&&l[f[x][i]]>=len) x=f[x][i];
	return query(T[x],L+len-1,R,1,n);
}

int main()
{
	n=read(),m=read();
	scanf("%s",s+1);last=cnt=1;
	reverse(s+1,s+n+1);
	for(int i=1;i<=n;i++){
		insert(s[i]-'a');pos[i]=last;
		update(T[pos[i]],1,n,i);
	}
	for(int i=1;i<=cnt;i++) b[l[i]]++;
	for(int i=1;i<=cnt;i++) b[i]+=b[i-1];
	for(int i=1;i<=cnt;i++) a[b[l[i]]--]=i;
	for(int i=cnt;i>=1;i--)
		if(fa[a[i]]) T[fa[a[i]]]=merge(T[fa[a[i]]],T[a[i]],1,n);
	for(int i=1;i<=cnt;i++) f[i][0]=fa[i];
	for(int j=1;j<=20;j++)
		for(int i=1;i<=cnt;i++) f[i][j]=f[f[i][j-1]][j-1];
	int a,b,c,d,l,r,mid,ans;
	while(m--){
		a=n-read()+1,b=n-read()+1,c=n-read()+1,d=n-read()+1;
		l=0,r=min(a-b+1,c-d+1),ans=0;
		while(l<=r){
			mid=(l+r)>>1;
			if(check(mid,pos[c],b,a)) l=mid+1,ans=mid;
			else r=mid-1;
		}
		printf("%d\n",ans);
	}
	return 0;
}

8、[HAOI2016]找相同字符

题意:问两个串的公共子串个数

如核心代码 \((l[i]-l[fa[i]])\times siz[i]\)

	for(int i=2;i<=cnt;i++) sum[a[i]]=sum[fa[a[i]]]+(ll)(l[a[i]]-l[fa[a[i]]])*siz[a[i]];
	scanf("%s",s+1);n=strlen(s+1);
	int len=0,p=1,c;
	for(int i=1;i<=cnt;i++){
		c=s[i]-'a';
		for(;p&&!ch[p][c];p=fa[p]);
		if(!p) p=1,len=0;
		else {
			len=min(len,l[p])+1;p=ch[p][c];
			ans+=sum[fa[p]]+(ll)(len-l[fa[p]])*siz[p];
		}
	}

\(Code\ Below:\)

#include <bits/stdc++.h>
#define ll long long
using namespace std;
const int maxn=400000+10;
int n,a[maxn],b[maxn],last,cnt,ch[maxn][26],fa[maxn],l[maxn],siz[maxn];
char s[maxn];ll sum[maxn],ans;

void insert(int c){
	int p=last,q=++cnt;last=q;l[q]=l[p]+1;
	for(;p&&!ch[p][c];p=fa[p]) ch[p][c]=q;
	if(!p) fa[q]=1;
	else {
		int r=ch[p][c];
		if(l[p]+1==l[r]) fa[q]=r;
		else {
			int s=++cnt;l[s]=l[p]+1;
			memcpy(ch[s],ch[r],sizeof(ch[r]));
			fa[s]=fa[r];fa[r]=fa[q]=s;
			for(;p&&ch[p][c]==r;p=fa[p]) ch[p][c]=s;
		}
	}
	siz[q]=1;
}

int main()
{
	scanf("%s",s+1);n=strlen(s+1);last=cnt=1;
	for(int i=1;i<=n;i++) insert(s[i]-'a');
	for(int i=1;i<=cnt;i++) b[l[i]]++;
	for(int i=1;i<=cnt;i++) b[i]+=b[i-1];
	for(int i=cnt;i>=1;i--) a[b[l[i]]--]=i;
	for(int i=cnt;i>=1;i--) siz[fa[a[i]]]+=siz[a[i]];
	for(int i=2;i<=cnt;i++) sum[a[i]]=sum[fa[a[i]]]+(ll)(l[a[i]]-l[fa[a[i]]])*siz[a[i]];
	scanf("%s",s+1);n=strlen(s+1);
	int len=0,p=1,c;
	for(int i=1;i<=cnt;i++){
		c=s[i]-'a';
		for(;p&&!ch[p][c];p=fa[p]);
		if(!p) p=1,len=0;
		else {
			len=min(len,l[p])+1;p=ch[p][c];
			ans+=sum[fa[p]]+(ll)(len-l[fa[p]])*siz[p];
		}
	}
	printf("%lld\n",ans);
	return 0;
}
posted @ 2018-12-05 19:05  Owen_codeisking  阅读(229)  评论(0编辑  收藏  举报