后缀自动机
关于SAM的介绍和构建见这几篇博客,这里主要是SAM的应用以及题目:
OI-wiki
洛谷日报
应用:
1.求一个串出现次数
模板题
利用parent tree的性质,将每个叶子(其实就是所有前缀)的size设为1,一个点内所有串的出现次数即为子树内size大小(即叶子个数)。
正确性:
显然出现次数即为终点集合大小,每个叶子节点对应唯一一个终点(因为是一个前缀)。
code:
#include<bits/stdc++.h>
using namespace std;
const int maxn=1e6+10;
int n,ans,tot,last,cnt;
int head[maxn<<1],size[maxn<<1];
char s[maxn];
struct edge{int to,nxt;}e[maxn<<2];
inline void edge_add(int u,int v)
{
e[++cnt].nxt=head[u];
head[u]=cnt;
e[cnt].to=v;
}
struct Sam
{
int fa,len;
int ch[26];
}sam[maxn<<1];
inline void sam_init(){sam[0].len=0;sam[0].fa=-1;last=0;}
inline void sam_add(int c)
{
int now=++tot;size[now]=1;
sam[now].len=sam[last].len+1;
int p=last;
while(~p&&!sam[p].ch[c])sam[p].ch[c]=now,p=sam[p].fa;
if(p==-1){sam[now].fa=0;last=now;return;}
int q=sam[p].ch[c];
if(sam[q].len==sam[p].len+1)sam[now].fa=q;
else
{
int nowq=++tot;
sam[nowq].len=sam[p].len+1;
memcpy(sam[nowq].ch,sam[q].ch,sizeof(sam[q].ch));
sam[nowq].fa=sam[q].fa;sam[q].fa=sam[now].fa=nowq;
while(~p&&sam[p].ch[c]==q)sam[p].ch[c]=nowq,p=sam[p].fa;
}
last=now;
}
void dfs(int x)
{
for(int i=head[x];i;i=e[i].nxt)
dfs(e[i].to),size[x]+=size[e[i].to];
if(size[x]>1)ans=max(ans,sam[x].len*size[x]);
}
int main()
{
scanf("%s",s+1);n=strlen(s+1);
sam_init();
for(int i=1;i<=n;i++)sam_add(s[i]-'a');
for(int i=1;i<=tot;i++)edge_add(sam[i].fa,i);
dfs(0);
printf("%d",ans);
return 0;
}
2.求一个串是否出现过
因为所有子串会在SAM被唯一表示,因此沿着SAM上的边走,发现失配即不存在。
3.求一个串不同子串的个数
模板题
有两种做法:
<1>利用parent tree的性质
对于每个非空集的节点\(i\)求\(sam[i].len-sam[sam[i].fa].len\),加起来就是答案。
正确性:
因为SAM上没有重复的字符串,所有状态的字符串加起来就是答案,又因为一个集合的字串长度是连续的,于是可以通过\(len\)相减得到。
于是\(\sum\limits_{i=1}^{tot}sam[i].len-sam[sam[i].fa].len\)就是答案
code:
#include<bits/stdc++.h>
using namespace std;
#define int long long
const int maxn=1e5+10;
int n,ans,last,tot;
char s[maxn];
struct Sam
{
int fa,len;
int ch[26];
}sam[maxn<<1];
inline void sam_init(){sam[0].len=0,sam[0].fa=-1;last=0;}
inline void sam_add(int c)
{
int now=++tot;
sam[now].len=sam[last].len+1;
int p=last;
while(~p&&!sam[p].ch[c])sam[p].ch[c]=now,p=sam[p].fa;
if(p==-1){sam[now].fa=0;last=now;return;}
int q=sam[p].ch[c];
if(sam[q].len==sam[p].len+1)sam[now].fa=q;
else
{
int nowq=++tot;
sam[nowq].len=sam[p].len+1;
memcpy(sam[nowq].ch,sam[q].ch,sizeof(sam[q].ch));
sam[nowq].fa=sam[q].fa;sam[q].fa=sam[now].fa=nowq;
while(~p&&sam[p].ch[c]==q)sam[p].ch[c]=nowq,p=sam[p].fa;
}
last=now;
}
signed main()
{
scanf("%lld%s",&n,s+1);
sam_init();
for(int i=1;i<=n;i++)sam_add(s[i]-'a');
for(int i=1;i<=tot;i++)ans+=sam[i].len-sam[sam[i].fa].len;
printf("%lld",ans);
return 0;
}
同理:P4070 [SDOI2016]生成魔咒
<2>DAG上DP
注意到答案即为从空集点(0)出发的路径条数,又因为SAM是个DAG图,因此可以DP。
设\(f_x\)表示从\(x\)出发的路径条数,显然\(f_0-1\)就是答案,\(-1\)是为了减去空串。
转移有:
\(f_x=1+\sum\limits_{ \exists\ edge(x,y)}f_y\)
code:
#include<bits/stdc++.h>
using namespace std;
#define int long long
const int maxn=1e5+10;
int n,ans,last,tot;
int a[maxn<<1],f[maxn<<1];
char s[maxn];
struct Sam
{
int fa,len;
int ch[26];
}sam[maxn<<1];
inline bool cmp(int x,int y){return sam[x].len>sam[y].len;}
inline void sam_init(){sam[0].len=0,sam[0].fa=-1;last=0;}
inline void sam_add(int c)
{
int now=++tot;
sam[now].len=sam[last].len+1;
int p=last;
while(~p&&!sam[p].ch[c])sam[p].ch[c]=now,p=sam[p].fa;
if(p==-1){sam[now].fa=0;last=now;return;}
int q=sam[p].ch[c];
if(sam[q].len==sam[p].len+1)sam[now].fa=q;
else
{
int nowq=++tot;
sam[nowq].len=sam[p].len+1;
memcpy(sam[nowq].ch,sam[q].ch,sizeof(sam[q].ch));
sam[nowq].fa=sam[q].fa;sam[q].fa=sam[now].fa=nowq;
while(~p&&sam[p].ch[c]==q)sam[p].ch[c]=nowq,p=sam[p].fa;
}
last=now;
}
signed main()
{
scanf("%lld%s",&n,s+1);
sam_init();
for(int i=1;i<=n;i++)sam_add(s[i]-'a');
for(int i=0;i<=tot;i++)a[i]=i;
sort(a,a+tot+1,cmp);
for(int i=0;i<=tot;i++)
{
f[a[i]]=1;
for(int j=0;j<26;j++)
if(sam[a[i]].ch[j])f[a[i]]+=f[sam[a[i]].ch[j]];
}
printf("%lld",f[0]-1);
return 0;
}
4.求不同串的长度和
依然有两种做法:
<1>同3.,可以DP求出。
设\(f_i\)表示\(i\)的出发的路径条数,\(g_i\)表示从\(i\)出发的路径总长度。
\(f_i\)的转移在3.中给出,\(g_i\)的转移如下:
\(g_x=\sum\limits_{ \exists\ edge(x,y)}f_y+g_y\)
<2>同3.,利用parent tree的性质
每个节点\(i\)对应的的后缀总长是\(\frac{len_i(len_i+1)}{2}\)(等差数列求和),减去父亲节点的该值即为当前节点的答案,求和即可。
5.求字典序第k大子串
先求出\(f_i\)表示从\(i\)出发的串个数,用类似平衡树上二分的方法在SAM上跑即可。
code:
#include<bits/stdc++.h>
using namespace std;
const int maxn=90010;
int T,n,tot,last,cnt;
int head[maxn<<1],f[maxn<<1],in[maxn<<1];
char s[maxn];
struct edge{int to,nxt;}e[maxn<<2];
struct Sam
{
int fa,len;
int ch[26];
}sam[maxn<<1];
inline bool cmp(int x,int y){return sam[x].len>sam[y].len;}
inline void add(int u,int v)
{
e[++cnt].nxt=head[u];
head[u]=cnt;
e[cnt].to=v;
in[v]++;
}
inline void sam_init(){sam[0].fa=-1,sam[0].len=0;last=0;}
inline void sam_add(int c)
{
int now=++tot;
sam[now].len=sam[last].len+1;
int p=last;
while(~p&&!sam[p].ch[c])sam[p].ch[c]=now,p=sam[p].fa;
if(p==-1){sam[now].fa=0;last=now;return;}
int q=sam[p].ch[c];
if(sam[q].len==sam[p].len+1)sam[now].fa=q;
else
{
int nowq=++tot;
sam[nowq].len=sam[p].len+1;
memcpy(sam[nowq].ch,sam[q].ch,sizeof(sam[q].ch));
sam[nowq].fa=sam[q].fa;sam[q].fa=sam[now].fa=nowq;
while(~p&&sam[p].ch[c]==q)sam[p].ch[c]=nowq,p=sam[p].fa;
}
last=now;
}
inline void topsort()
{
queue<int>q;
for(int i=0;i<=tot;i++)if(!in[i])q.push(i);
while(!q.empty())
{
int x=q.front();q.pop();
f[x]=1;
for(int i=0;i<26;i++)if(sam[x].ch[i])f[x]+=f[sam[x].ch[i]];
for(int i=head[x];i;i=e[i].nxt)
{
int y=e[i].to;
if(!(--in[y]))q.push(y);
}
}
f[0]--;
}
inline void solve(int k)
{
int now=0;
while(k)
{
for(int i=0;i<26;i++)
{
if(!sam[now].ch[i])continue;
if(f[sam[now].ch[i]]<k)k-=f[sam[now].ch[i]];
else
{
putchar(i+'a');
now=sam[now].ch[i];
k--;break;
}
}
}
}
int main()
{
scanf("%s",s+1);n=strlen(s+1);
sam_init();
for(int i=1;i<=n;i++)sam_add(s[i]-'a');
for(int i=0;i<=tot;i++)
for(int j=0;j<26;j++)
if(sam[i].ch[j])add(sam[i].ch[j],i);
topsort();
scanf("%d",&T);
while(T--)
{
int k;scanf("%d",&k);
solve(k);puts("");
}
return 0;
}
扩展到本质相同的子串:P3975 [TJOI2015]弦论
5.最小循环移位
复制一份拆入SAM中,就变为找最小的长为n的子串,贪心即可。
6.第一次出现的位置
即求出每个状态的endpos中最小的那个。
对每个状态\(s\)维护\(firstpos(s)\),表示s的endpos中最小的那个。
code(已对拍):
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int maxn=1e6+10;
int n,m,last,tot;
char s[maxn],t[maxn];
struct Sam
{
int fa,len,firpos;
int ch[26];
}sam[maxn<<1];
inline void sam_init(){sam[0].fa=-1,sam[0].len=0;last=0;}
inline void sam_add(int c)
{
int now=++tot;sam[now].len=sam[last].len+1;sam[now].firpos=sam[now].len;
int p=last;
while(~p&&!sam[p].ch[c])sam[p].ch[c]=now,p=sam[p].fa;
if(p==-1){sam[now].fa=0;last=now;return;}
int q=sam[p].ch[c];
if(sam[q].len==sam[p].len+1)sam[now].fa=q;
else
{
int nowq=++tot;
sam[nowq].len=sam[p].len+1;sam[nowq].firpos=sam[q].firpos;
memcpy(sam[nowq].ch,sam[q].ch,sizeof(sam[q].ch));
sam[nowq].fa=sam[q].fa;sam[q].fa=sam[now].fa=nowq;
while(~p&&sam[p].ch[c]==q)sam[p].ch[c]=nowq,p=sam[p].fa;
}
last=now;
}
inline int query(char* t)
{
int len=strlen(t+1);
int now=0;
for(int i=1;i<=len;i++)
{
int c=t[i]-'a';
now=sam[now].ch[c];
}
return sam[now].firpos;
}
int main()
{
//freopen("test.in","r",stdin);
//freopen("test.out","w",stdout);
scanf("%s",s+1);n=strlen(s+1);
sam_init();
for(int i=1;i<=n;i++)sam_add(s[i]-'a');
scanf("%d",&m);
while(m--)
{
scanf("%s",t+1);
int len=strlen(t+1),res;
res=query(t);
if(~res)printf("%d\n",res-len+1);
else puts("-1");
}
return 0;
}
7.一个串的所有出现位置
建出parent树,遍历该串的子树,遇见遇叶子节点就输出。
8.最短未出现子串
注意字符集给定,空串已出现过。
显然答案是从源点走到一个没有出边的节点再随便选个字符接上,于是我们要求的其实是从源点到一个最近的没有出边的节点的距离+1。
设\(f_x\)表示从x到最近的没有出边的节点的距离+1,显然有:
\(f_x=1+\min_{\exists (x,y)}f_y\)
\(f_0\)即为答案,输出只需要通过\(f\)退回去即可。
9.两字符串最长公共子串
考虑线对一个串建出SAM,之后求另一个串的每一个前缀与第一个串能匹配的最长后缀\(l_i\),显然\(\max(l_i)\)即为答案。
匹配的过程:
设当前匹配到第i个前缀,当前节点为\(now\),已经匹配的长度为\(nowl\)。
如果\(now\)有\(s1_i\)这条出边,就令\(now=sam[now].ch[s1_i],nowl++\)
否则就一直跳\(now\)的\(fa\)(即遍历\(now\)的所有后缀),同时令\(nowl=len_{now}\),直到匹配或者\(now=0\)
模板题
code:
#include<bits/stdc++.h>
using namespace std;
const int maxn=250010;
int n,m,last,tot;
char s1[maxn],s2[maxn];
struct Sam
{
int fa,len;
int ch[26];
}sam[maxn<<1];
inline void sam_init(){sam[0].fa=-1,sam[0].len=0;last=0;}
inline void sam_add(int c)
{
int now=++tot;sam[now].len=sam[last].len+1;
int p=last;
while(~p&&!sam[p].ch[c])sam[p].ch[c]=now,p=sam[p].fa;
if(p==-1){sam[now].fa=0;last=now;return;}
int q=sam[p].ch[c];
if(sam[q].len==sam[p].len+1)sam[now].fa=q;
else
{
int nowq=++tot;
sam[nowq].len=sam[p].len+1;
memcpy(sam[nowq].ch,sam[q].ch,sizeof(sam[q].ch));
sam[nowq].fa=sam[q].fa,sam[q].fa=sam[now].fa=nowq;
while(~p&&sam[p].ch[c]==q)sam[p].ch[c]=nowq,p=sam[p].fa;
}
last=now;
}
inline int query(char* s,int len)
{
int res=0,now=0,nowl=0;
for(int i=1;i<=len;i++)
{
int c=s[i]-'a';
while(now&&!sam[now].ch[c])now=sam[now].fa,nowl=sam[now].len;
if(sam[now].ch[c])now=sam[now].ch[c],nowl++;
res=max(res,nowl);
}
return res;
}
int main()
{
scanf("%s%s",s1+1,s2+1);
n=strlen(s1+1),m=strlen(s2+1);
sam_init();
for(int i=1;i<=n;i++)sam_add(s1[i]-'a');
printf("%d\n",query(s2,m));
return 0;
}
10.多个串的最长公共子串
OI_wiki上的做法没看懂。
考虑扩展下9.的做法:
先建出第一个串的sam,之后让每个串和它匹配。
考虑对每个点\(x\)记如下信息:
\(minn_x\)表示\(x\)节点的最长串与所有串匹配的最小长度。
\(maxx_x\)表示在和某一个串\(s_i\)(注意这是一个在匹配过程中使用的,并不是全局的)匹配时,\(x\)这个节点的最长串能和\(s_i\)匹配的最长长度。
我们最后只要对所有节点的\(minn\)求个\(max\)即为答案。
当我们和\(s_i\)用9.的方法匹配后,我们要注意每个点的\(maxx\)并不一定满足它的定义,因为它在parent树上的儿子匹配的长度可能大于它,因为儿子能匹配,所以它肯定也能匹配,因此最后\(maxx_x\)要和它的子树取\(max\)(注意上界是自己的长度)。
code:
#include<bits/stdc++.h>
using namespace std;
const int maxn=100010;
int n,m,tot,last,ans=0;
int a[maxn<<1],b[maxn<<1],maxx[maxn<<1],minn[maxn<<1];
char s[maxn];
struct SAM
{
int fa,len;
int ch[26];
}sam[maxn<<1];
inline void sam_init(){sam[0].fa=-1;sam[0].len=0;last=0;}
inline void sam_add(int c)
{
int now=++tot;sam[now].len=sam[last].len+1;
int p=last;
while(~p&&!sam[p].ch[c])sam[p].ch[c]=now,p=sam[p].fa;
if(p==-1){sam[now].fa=0;last=now;return;}
int q=sam[p].ch[c];
if(sam[q].len==sam[p].len+1)sam[now].fa=q;
else
{
int nowq=++tot;
sam[nowq].len=sam[p].len+1;
memcpy(sam[nowq].ch,sam[q].ch,sizeof(sam[q].ch));
sam[nowq].fa=sam[q].fa,sam[q].fa=sam[now].fa=nowq;
while(~p&&sam[p].ch[c]==q)sam[p].ch[c]=nowq,p=sam[p].fa;
}
last=now;
}
inline void work(char* s)
{
int len=strlen(s+1),now=0,nowl=0;
for(int i=1;i<=len;i++)
{
int c=s[i]-'a';
while(now&&!sam[now].ch[c])now=sam[now].fa,nowl=sam[now].len;
if(sam[now].ch[c])now=sam[now].ch[c],nowl++;
maxx[now]=max(maxx[now],nowl);
}
for(int i=tot+1;i;i--)
{
int x=a[i];
if(~sam[x].fa)maxx[sam[x].fa]=max(maxx[sam[x].fa],min(maxx[x],sam[sam[x].fa].len));
minn[x]=min(minn[x],maxx[x]);maxx[x]=0;
}
}
int main()
{
//freopen("test.in","r",stdin);
//freopen("test.out","w",stdout);
scanf("%s",s+1);n=strlen(s+1);
sam_init();
for(int i=1;i<=n;i++)sam_add(s[i]-'a');
for(int i=0;i<=tot;i++)b[sam[i].len]++;
for(int i=1;i<=n;i++)b[i]+=b[i-1];
for(int i=0;i<=tot;i++)a[b[sam[i].len]--]=i;
//for(int i=0;i<=tot;i++)cerr<<i<<' '<<sam[i].fa<<endl;
memset(minn,0x3f,sizeof(minn));
while(~scanf("%s",s+1))work(s);
for(int i=0;i<=tot;i++)ans=max(ans,minn[i]);
printf("%d",ans);
return 0;
}