hihocoder后缀自动机练习
求字符串子串种类数
题:http://hihocoder.com/problemset/problem/1445
分析:后缀自动机模板
#include<iostream> #include<algorithm> #include<cstring> #include<cstdio> using namespace std; typedef long long ll; const int M=1e6+3; int trans[M<<1][26],slink[M<<1],maxlen[M<<1]; int endpos[M<<1]; int last,now,root; void init(){ now=last=root=1; memset(trans,0,sizeof(trans)); memset(slink,0,sizeof(slink)); memset(maxlen,0,sizeof(maxlen)); } void extend(int c){ maxlen[++now]=maxlen[last]+1; int p=last,np=now; while(p&&!trans[p][c]){ trans[p][c]=np; p=slink[p]; } ///在之前构造的sam中出现了现在的后缀 if(!p) slink[np]=root; else{ int q=trans[p][c]; if(maxlen[p]+1!=maxlen[q]){///若p+c不是q中最大的字符串, ///新建个克隆节点,把p+c从q中挑出来 maxlen[++now]=maxlen[p]+1; int nq=now; memcpy(trans[nq],trans[q],sizeof(trans[q])); slink[nq]=slink[q]; slink[q]=slink[np]=nq; while(p&&trans[p][c]==q){ trans[p][c]=nq; p=slink[p]; } } else///否则np直接link连接q slink[np]=q; } last=np; endpos[np]=1; } char s[M]; int main(){ scanf("%s",s+1); init(); int n=strlen(s+1); for(int i=1;i<=n;i++) extend(s[i]-'a'); ll ans=0; for(int i=root+1;i<=now;i++) ans+=maxlen[i]-maxlen[slink[i]]; printf("%lld",ans); return 0; }
求1~i长度的字符串最多出现的次数
题: https://hihocoder.com/problemset/problem/1449
分析:
- 核心点:求出每个状态的endpos大小|endpos|将其设为f数组。这个依据slink来增加到父亲节点去,初始化:只有maxlen的串为主串的前缀时才为1
#include<iostream> #include<algorithm> #include<cstring> #include<cstdio> using namespace std; typedef long long ll; const int M=1e6+3; int trans[M<<1][26],slink[M<<1],maxlen[M<<1]; int endpos[M<<1]; ll ans[M],f[M<<1]; int last,tot,root; void init(){ tot=last=root=1; memset(trans,0,sizeof(trans)); memset(slink,0,sizeof(slink)); memset(maxlen,0,sizeof(maxlen)); } void extend(int c){ maxlen[++tot]=maxlen[last]+1; int p=last,np=tot; f[np] = 1;///maxlen的串为主串的前缀时才为1 while(p&&!trans[p][c]){ trans[p][c]=np; p=slink[p]; } ///在之前构造的sam中出现了现在的后缀 if(!p) slink[np]=root; else{ int q=trans[p][c]; if(maxlen[p]+1!=maxlen[q]){///若p+c不是q中最大的字符串, ///新建个克隆节点,把p+c从q中挑出来 maxlen[++tot]=maxlen[p]+1; int nq=tot; memcpy(trans[nq],trans[q],sizeof(trans[q])); slink[nq]=slink[q]; slink[q]=slink[np]=nq; while(p&&trans[p][c]==q){ trans[p][c]=nq; p=slink[p]; } } else///否则np直接link连接q slink[np]=q; } last=np; endpos[np]=1; } char s[M]; int que[M<<1],in[M<<1]; void tuopu(){ int l=1,r=0; for(int i=1;i<=tot;i++)in[slink[i]]++; for(int i=1;i<=tot;i++)if(!in[i]) que[++r]=i; while(l<=r){ int x=que[l++]; f[slink[x]]+=f[x]; if(--in[slink[x]]==0) que[++r]=slink[x]; } } int main(){ scanf("%s",s+1); init(); int n=strlen(s+1); for(int i=1;i<=n;i++) extend(s[i]-'a'); tuopu(); for(int i=1;i<=tot;i++) ans[maxlen[i]]=max(ans[maxlen[i]],f[i]); for(int i=n;i;i--)///因为可能状态中的maxlen没有能表示某些小的子串,而maxlen可以理解为伴随着一些子串的存在,所以用大的去更新小的 ans[i]=max(ans[i],ans[i+1]); for(int i=1;i<=n;i++) printf("%lld\n",ans[i]); return 0; }
求多串0~9子串之和(允许前导0)
题:https://hihocoder.com/problemset/problem/1457?sid=1596892
分析:用特殊符号拼接起来,然后一大串去建立sam,然后在trans上跑拓扑序,不跑特殊符号的,然后就可以进行*10转移
#include<iostream> #include<algorithm> #include<cstring> #include<cstdio> #include<queue> using namespace std; typedef long long ll; const int M=1e6+3; const int mod=1e9+7; int trans[M<<1][11],slink[M<<1]; ll maxlen[M<<1]; ll f[M<<1],cnt[M<<1]; int last,tot,root; char s[M<<1]; int in[M<<1]; ll res=0; void init(){ tot=last=root=1; memset(in,0,sizeof(in)); memset(f,0,sizeof(f)); memset(trans,0,sizeof(trans)); memset(slink,0,sizeof(slink)); memset(maxlen,0,sizeof(maxlen)); } void extend(int c){ maxlen[++tot]=maxlen[last]+1; int p=last,np=tot; while(p&&!trans[p][c]){ trans[p][c]=np; p=slink[p]; } if(!p) slink[np]=root; else{ int q=trans[p][c]; if(maxlen[p]+1!=maxlen[q]){ maxlen[++tot]=maxlen[p]+1; int nq=tot; memcpy(trans[nq],trans[q],sizeof(trans[q])); slink[nq]=slink[q]; slink[q]=slink[np]=nq; while(p&&trans[p][c]==q){ trans[p][c]=nq; p=slink[p]; } } else slink[np]=q; } last=np; } int vis[M<<1]; void tuopu(){ queue<int>que; while(!que.empty())que.pop(); que.push(root); while(!que.empty()){ int u=que.front(); que.pop(); for(int i=0;i<10;i++){ int v=trans[u][i]; if(!v) continue; ++in[v]; if(!vis[v]) que.push(v); vis[v]=1; } } que.push(1); cnt[1]=1; while(!que.empty()){ int u=que.front(); que.pop(); for(int i=0;i<10;i++){ int v=trans[u][i]; if(!v)continue; cnt[v]+=cnt[u]; f[v]=(f[v]+(f[u]*10+i*cnt[u])%mod)%mod; if(--in[v]==0) que.push(v); } } for(int i=1;i<=tot;i++) res=(res+f[i])%mod; } int main(){ init(); int n; scanf("%d",&n); for(int i=1;i<=n;i++){ scanf("%s",s+1); int m=strlen(s+1); for(int j=1;j<=m;j++) extend(s[j]-'0'); if(i!=n) extend(':'-'0'); } tuopu(); printf("%lld\n",res); return 0; }