[基本操作]后缀自动机
来介绍一些基本操作
首先,介绍一下 Suffix Automaton
后缀自动机大概由两部分组成—— DAWG 和 Parent Tree
1.DAWG
DAWG 的中文名字叫做“单词的有向无环图”
它由一个初始节点 init ,若干条转移边,若干个节点组成
DAWG 表示的是状态的转移关系,我们可以记一个点能识别的终止位置集合为 $end-pos(i)$,每个点的子串是一个前缀的一些后缀,这些后缀的长度都在 [minlen,maxlen] 这个区间里
2.Parent Tree
Parent Tree 类似 AC 自动机的 fail 树,是由 $end-pos$ 集合的包含关系构成的一棵树,满足 fa[i] 的 maxlen + 1 等于 i 的 minlen
由这个我们可以知道对 DAWG 拓扑排序相当于对 maxlen 数组快速排序/基数排序,由这个我们也可以知道其实并不用记录每个点的 minlen
Parent Tree 是反串的后缀树
由此我们可以做题
bzoj3879 SvT
给你一个串和若干组询问,每组询问包括若干个后缀,你要求出这些后缀两两间最长公共前缀长度的和
sol:后缀 i 和后缀 j 的 lcp 相当于后缀树上 i 的位置和 j 的位置的 LCA 深度
我们把串反过来,然后建 SAM
然后我们虚树 + 树形 dp 就可以了
#include<bits/stdc++.h> #define LL long long using namespace std; inline int read() { int x = 0,f = 1;char ch = getchar(); for(;!isdigit(ch);ch = getchar())if(ch == '-')f = -f; for(;isdigit(ch);ch = getchar())x = 10 * x + ch - '0'; return x * f; } const int maxn = 1200010; const LL mod = 23333333333333333LL; int n,a[maxn],pos[maxn],rnk[maxn]; int tr[maxn][26]; int fa[maxn],len[maxn],dfn,root,last; char s[maxn]; void extend(int c) { int p = last,np = last = ++dfn; len[np] = len[p] + 1; while(p && !tr[p][c])tr[p][c] = np,p = fa[p]; if(!p)fa[np] = root; else { int q = tr[p][c]; if(len[q] == len[p] + 1)fa[np] = q; else { int nq = ++dfn; len[nq] = len[p] + 1;memcpy(tr[nq],tr[q],sizeof(tr[nq]));fa[nq] = fa[q],fa[np] = fa[q] = nq; while(p && tr[p][c] == q)tr[p][c] = nq,p = fa[p]; } } } int first[maxn],to[maxn],nx[maxn],cnt; LL ans; int val[maxn]; inline void add(int u,int v){to[++cnt] = v;nx[cnt] = first[u];first[u] = cnt;} inline void ins(int u,int v){add(u,v);add(v,u);} int size[maxn],dep[maxn],ff[maxn],bl[maxn],ind[maxn],_tim; inline void dfs1(int x) { size[x] = 1;ind[x] = ++_tim; for(int i=first[x];i;i=nx[i]) { if(to[i] == ff[x])continue; ff[to[i]] = x; dep[to[i]] = dep[x] + 1; dfs1(to[i]); size[x] += size[to[i]]; } } inline void dfs2(int x,int col) { bl[x] = col; int k = 0; for(int i=first[x];i;i=nx[i]) if(to[i] != ff[x] && size[to[i]] > size[k])k = to[i]; if(!k)return; dfs2(k,col); for(int i=first[x];i;i=nx[i]) if(to[i] != ff[x] && to[i] != k)dfs2(to[i],to[i]); } inline int lca(int x,int y) { while(bl[x] != bl[y]) { if(dep[bl[x]] < dep[bl[y]])swap(x,y); x = ff[bl[x]]; }return dep[y] < dep[x] ? y : x; } inline bool cmp(const int &x,const int &y){return ind[x] < ind[y];} int stk[maxn],f[maxn]; inline void dp(int x) { f[x] = val[x] ? 1 : 0; for(int i=first[x];i;i=nx[i]) { dp(to[i]); ans += (LL)f[x] * f[to[i]] * len[x]; f[x] += f[to[i]]; } first[x] = 0; //cout<<x<<endl; } int main() { #ifdef Ez3real freopen("ww.in","r",stdin); #endif root = last = ++dfn; n = read();int q = read();scanf("%s",s + 1); reverse(s + 1,s + n + 1); for(int i=1;i<=n;i++)extend(s[i] - 'a'),pos[n - i + 1] = last; for(int i=1;i<=dfn;i++)add(fa[i],i);dfs1(root);dfs2(root,root); memset(first,0,sizeof(first)); while(q--) { cnt = 0; int k = read();//cout<<k<<"!!"<<endl; for(int i=1;i<=k;i++)a[i] = pos[read()]; sort(a + 1,a + k + 1,cmp); int nn = 0;a[++nn] = a[1]; for(int i=2;i<=k;i++) if(a[i] != a[i - 1])a[++nn] = a[i]; for(int i=1;i<=nn;i++)val[a[i]] = 1; int top = 0; for(int i=1;i<=nn;i++) { if(!top){stk[++top] = a[i];continue;} int x = a[i],l = lca(x,stk[top]); while(ind[l] < ind[stk[top]]) { if(ind[l] >= ind[stk[top - 1]]) { add(l,stk[top--]); if(l != stk[top])stk[++top] = l; break; }else add(stk[top - 1],stk[top]),top--; } stk[++top] = x; } while(top > 1)add(stk[top - 1],stk[top]),top--; ans = 0; dp(stk[1]); printf("%lld\n",ans); for(int i=1;i<=nn;i++)val[a[i]] = 0; } }
upd:把后缀数组里那几个地方拿出来建一个“虚后缀数组”,求出“虚 height ”
然后单调栈就可以了。。。
#include<bits/stdc++.h> #define LL long long using namespace std; inline int read() { int x=0,f=1;char ch; for(ch=getchar();!isdigit(ch);ch=getchar())if(ch=='-')f=-f; for(;isdigit(ch);ch=getchar())x=10*x+ch-'0'; return x*f; } const int maxn = 1e6 + 10; const LL mod = 23333333333333333LL; int n,m; char s[maxn];LL ans; #define equ(x) (y[sa[i] + x] == y[sa[i - 1] + x]) int rnk[maxn],tmp[maxn],sa[maxn],hei[maxn]; int x[maxn],y[maxn],wc[maxn]; void radix_sort(int m) { for(int i=1;i<=m;i++)wc[i] = 0; for(int i=1;i<=n;i++)wc[x[y[i]]]++; for(int i=1;i<=m;i++)wc[i] += wc[i - 1]; for(int i=n;i>=1;i--)sa[wc[x[y[i]]]--] = y[i]; } void makesa(char *s,int n,int m) { for(int i=1;i<=n;i++)x[i] = s[i],y[i] = i;radix_sort(m); for(int j=1,p=0;j<=n;j<<=1,m = p,p = 0) { for(int i=n-j+1;i<=n;i++)y[++p] = i; for(int i=1;i<=n;i++) if(sa[i] > j)y[++p] = sa[i] - j; radix_sort(m);swap(x,y);x[sa[1]] = p = 1; for(int i=2;i<=n;i++)x[sa[i]] = equ(0) && equ(j) ? p : ++p; if(p == n)break; } for(int i=1;i<=n;i++)rnk[sa[i]] = i; for(int i=1,k=0;i<=n;hei[rnk[i++]]=k) for(k ? k-- : 0;i+k<=n && s[i+k] == s[sa[rnk[i]-1]+k];k++); } int st[maxn][25],lg[maxn]; void initST() { lg[0] = -1;lg[1] = 0; for(int i=2;i<=n;i++)lg[i] = lg[i >> 1] + 1; for(int i=1;i<=n;i++)st[i][0] = hei[i]; for(int i=1;i<=lg[n];i++) for(int j=1;j + (1 << i) - 1<=n;j++)st[j][i] = min(st[j][i - 1],st[j + (1 << i - 1)][i - 1]); } inline int lcp(int u,int v) { if(u == v)return n - u + 1; u = rnk[u],v = rnk[v]; if(u > v)swap(u,v);u++; int loog = lg[v - u + 1]; return min(st[u][loog],st[v - (1 << loog) + 1][loog]); } inline int rmq(int u,int v) { int loog = lg[v - u + 1]; return min(st[u][loog],st[v - (1 << loog) + 1][loog]); } int k,q[maxn],stk[maxn],tot[maxn]; int main() { n = read(),m = read(); scanf("%s",s + 1); makesa(s,n,200);initST(); while(m--) { k = read();ans = 0; for(int i=1;i<=k;i++)q[i] = rnk[read()]; sort(q + 1,q + k + 1); k = unique(q + 1,q + k + 1) - q - 1; int top = 0;LL sum = 0,cur,cnt; for(int i=2;i<=k;i++) { cur = rmq(q[i - 1] + 1,q[i]),cnt = 0; while(top && cur <= stk[top]) { cnt += tot[top]; sum = ((sum - (LL)stk[top] * (LL)tot[top])%mod + mod) % mod; top--; } stk[++top] = cur;tot[top] = cnt + 1; sum = ((sum + (LL)stk[top] * (LL)tot[top])% mod + mod) % mod; (ans += sum) %= mod; } ans = ((ans % mod) + mod) % mod; cout<<ans<<endl; } }
bzoj2882 工艺
求字符串的最小表示法,也就是说,把字符串组成一个环,从任意一个位置开始读一圈,求读出来的字符串字典序最小的方案
sol:环 -> 二倍链
然后从 init 开始走,每次走字典序最小的那个转移边,走 n 步就是最小表示法
#include<bits/stdc++.h> #define LL long long using namespace std; inline int read() { int x = 0,f = 1;char ch = getchar(); for(;!isdigit(ch);ch = getchar())if(ch == '-')f = -f; for(;isdigit(ch);ch = getchar())x = 10 * x + ch - '0'; return x * f; } const int maxn = 1600010; int n,a[maxn]; map<int,int> tr[maxn]; int fa[maxn],len[maxn],dfn,root,last; void extend(int c) { int p = last,np = last = ++dfn; len[np] = len[p] + 1; while(p && !tr[p][c])tr[p][c] = np,p = fa[p]; if(!p)fa[np] = root; else { int q = tr[p][c]; if(len[q] == len[p] + 1)fa[np] = q; else { int nq = ++dfn; len[nq] = len[p] + 1,tr[nq] = tr[q],fa[nq] = fa[q],fa[np] = fa[q] = nq; while(p && tr[p][c] == q)tr[p][c] = nq,p = fa[p]; } } } int main() { root = last = ++dfn; n = read();map<int,int>::iterator it; for(int i=1;i<=n;i++)a[n + i] = a[i] = read(); int kn = n + n,p = root; for(int i=1;i<kn;i++)extend(a[i]); while(n--) { it = tr[p].begin(); printf("%d",it -> first); if(n)putchar(' '); p = it -> second; } }
bzoj4516 生成魔咒
一开始有一个空串,每次加入一个字符,询问当前本质不同的子串数量
sol:一个串本质不同的子串数量就是 $\sum maxlen_i - maxlen_{fa[i]}$
因为每次只会加一个字符,每次只要维护增量即可,也就是每次把新的那个 np 对答案的贡献加进去
#include<bits/stdc++.h> #define LL long long const int maxn = 200050; using namespace std; map<int,int> tr[maxn]; int val[maxn],fa[maxn]; int n,m; LL LastAns; struct SAM { int SIZE,last,root; SAM(){SIZE = last = root = 1;} inline int cal(int x){return val[x] - val[fa[x]];} inline void extend(int x) { int p = last,np = last = ++SIZE; val[np] = val[p] + 1; while(p && !tr[p][x]) tr[p][x] = np,p = fa[p]; if(!p)fa[np] = root, LastAns += cal(np); else { int q = tr[p][x]; if(val[p] + 1 == val[q]){fa[np] = q;LastAns += cal(np);} else { int nq = ++SIZE; val[nq] = val[p] + 1;tr[nq] = tr[q]; fa[nq] = fa[q];LastAns += cal(nq) - cal(q); fa[np] = fa[q] = nq;LastAns += cal(np) + cal(q); while(p && tr[p][x] == q)tr[p][x] = nq, p = fa[p]; } } } }S; int main() { scanf("%d",&n);int x; for(int i=1;i<=n;i++) { scanf("%d",&x); S.extend(x); printf("%lld\n",LastAns); } }
更新
bzoj4199 品酒大会
给一个字符串 S ,每个位置都有一个权值 $w_i$ ,对每一个 $i ∈ [1,n]$ ,求出 $lcp(a,b) = i$ 的后缀数量和 $w_a \times w_b$ 的最大值
sol:lcp -> 后缀树上 lca
第一问就是枚举一下 lca 然后枚举 lca 的相邻子节点 dp 一下就可以了
第二问记一下每个点的最大最小然后乘一下
#include<bits/stdc++.h> #define LL long long using namespace std; inline int read() { int x = 0,f = 1;char ch = getchar(); for(;!isdigit(ch);ch = getchar())if(ch == '-') f = -f; for(;isdigit(ch);ch = getchar())x = 10 * x + ch - '0'; return x * f; } void fre() { freopen("mydata.in","r",stdin); freopen("mydata.out","w",stdout); } const int maxn = 600010; int n,a[maxn]; char s[maxn]; int tr[maxn][26],mxlen[maxn],fa[maxn],dfn,last,root; int val[maxn]; LL cnt[maxn],ans[maxn],mn[maxn],mx[maxn]; int size[maxn]; void extend(int c,int v) { int p = last,np = last = ++dfn; mxlen[np] = mxlen[p] + 1; size[np]++;mx[np] = mn[np] = v; for(;p && !tr[p][c];p = fa[p])tr[p][c] = np; if(!p)fa[np] = root; else { int q = tr[p][c]; if(mxlen[q] == mxlen[p] + 1)fa[np] = q; else { int nq = ++dfn; mxlen[nq] = mxlen[p] + 1; memcpy(tr[nq],tr[q],sizeof(tr[nq])); fa[nq] = fa[q]; fa[np] = fa[q] = nq; for(;p && tr[p][c] == q;p = fa[p])tr[p][c] = nq; } } } int first[maxn],to[maxn << 1],nx[maxn << 1],cwt; inline void add(int u,int v) { to[++cwt] = v; nx[cwt] = first[u]; first[u] = cwt; } void dfs(int x) { if(!mx[x] && !mn[x])mx[x] = -1e16,mn[x] = 1e16; for(int i=first[x];i;i=nx[i]) { dfs(to[i]); if(mx[x] != -1e16 && mn[x] != 1e16 && mx[to[i]] != -1e16 && mn[to[i]] != 1e16) ans[mxlen[x]] = max(ans[mxlen[x]],max(mx[x] * mx[to[i]],mn[x] * mn[to[i]])); cnt[mxlen[x]] += 1ll * size[x] * size[to[i]];size[x] += size[to[i]]; mx[x] = max(mx[x],mx[to[i]]);mn[x] = min(mn[x],mn[to[i]]); } } int main() { #ifdef Ez3real fre(); #endif root = last = ++dfn; n = read();scanf("%s",s + 1); //reverse(s + 1,s + n + 1); for(int i=1;i<=n;i++) { a[i] = read(); //extend(s[i] - 'a',a[i]); } for(int i=n;i>=1;i--)extend(s[i] - 'a',a[i]); for(int i=2;i<=dfn;i++)add(fa[i],i); //for(int i=0;i<=n;i++)ans[i] = -1e16; memset(ans,-63,sizeof(ans)); dfs(1); for(int i=n-1;i>=0;i--)cnt[i] += cnt[i + 1],ans[i] = max(ans[i],ans[i + 1]); for(int i=0;i<n;i++) { if(cnt[i]) printf("%lld %lld\n",cnt[i],ans[i]); else puts("0 0"); } }
bzoj4566 找相同字符
给两个字符串 $S_1,S_2$ 求他们有多少个不同的相同子串,两个子串有一个字符位置不同就算不同
sol:广义后缀自动机,记一下第一个串到过多少点,第二个串到过多少点,如果两个串都到过一个点,答案就加上这个点的子串个数
#include<bits/stdc++.h> #define LL long long using namespace std; inline int read() { int x = 0,f = 1;char ch = getchar(); for(;!isdigit(ch);ch = getchar())if(ch == '-') f = -f; for(;isdigit(ch);ch = getchar())x = 10 * x + ch - '0'; return x * f; } void fre() { freopen("mydata.in","r",stdin); freopen("mydata.out","w",stdout); } const int maxn = 800010; char s1[maxn],s2[maxn]; int n,m; int tr[maxn][26],fa[maxn],mxlen[maxn],root,last,dfn; int c1[maxn],c2[maxn]; void extend(int c) { int p = last,np = last = ++dfn; mxlen[np] = mxlen[p] + 1; for(;p && !tr[p][c];p = fa[p])tr[p][c] = np; if(!p)fa[np] = root; else { int q = tr[p][c]; if(mxlen[q] == mxlen[p] + 1)fa[np] = q; else { int nq = ++dfn; mxlen[nq] = mxlen[p] + 1; memcpy(tr[nq],tr[q],sizeof(tr[nq])); fa[nq] = fa[q]; fa[q] = fa[np] = nq; for(;p && tr[p][c] == q;p = fa[p])tr[p][c] = nq; } } } int c[maxn],rk[maxn]; void getsize() { for(int i=1;i<=dfn;i++)++c[mxlen[i]]; for(int i=1;i<=n;i++)c[i] += c[i - 1]; for(int i=1;i<=dfn;i++)rk[c[mxlen[i]]--] = i; for(int i=dfn;i>=1;i--) { c1[fa[rk[i]]] += c1[rk[i]]; c2[fa[rk[i]]] += c2[rk[i]]; } } int main() { #ifdef Ez3real fre(); #endif root = last = ++dfn; scanf("%s%s",s1 + 1,s2 + 1); n = strlen(s1 + 1);m = strlen(s2 + 1); for(int i=1;i<=n;i++)extend(s1[i] - 'a'); last = root; for(int i=1;i<=m;i++)extend(s2[i] - 'a'); int now = root; for(int i=1;i<=n;i++) { int p = s1[i] - 'a'; now = tr[now][p]; c1[now]++; } now = root; for(int i=1;i<=m;i++) { int p = s2[i] - 'a'; now = tr[now][p]; c2[now]++; }getsize(); LL ans = 0; for(int i=1;i<=dfn;i++){ans += 1ll * c1[i] * c2[i] * (mxlen[i] - mxlen[fa[i]]);} cout<<ans; }
bzoj3998 弦论
对于一个给定长度为 N 的字符串,求它的第 K 小子串是什么。
不同位置的相同子串可以算多个,也可以算一个
sol:后缀自动机,每个点搞一个权值,如果算多个,就是这个点 end-pos 集合大小,算一个,就是 1
然后按字典序搜一下就可以了
#include<bits/stdc++.h> #define LL long long using namespace std; inline int read() { int x = 0,f = 1;char ch = getchar(); for(;!isdigit(ch);ch = getchar())if(ch == '-') f = -f; for(;isdigit(ch);ch = getchar())x = 10 * x + ch - '0'; return x * f; } void fre() { freopen("mydata.in","r",stdin); freopen("mydata.out","w",stdout); } int n,t,k; const int maxn = 1e6 + 10; char s[maxn]; int tr[maxn][26],fa[maxn],mxlen[maxn],root,last,dfn; int c[maxn],size[maxn],rk[maxn],sum[maxn]; void extend(int c) { int p = last,np = last = ++dfn; mxlen[np] = mxlen[p] + 1;size[np] = 1; for(;p && !tr[p][c];p = fa[p])tr[p][c] = np; if(!p)fa[np] = root; else { int q = tr[p][c]; if(mxlen[q] == mxlen[p] + 1)fa[np] = q; else { int nq = ++dfn; mxlen[nq] = mxlen[p] + 1; fa[nq] = fa[q]; memcpy(tr[nq],tr[q],sizeof(tr[q])); fa[np] = fa[q] = nq; for(;p && tr[p][c] == q;p = fa[p])tr[p][c] = nq; } } } void build() { for(int i=1;i<=dfn;i++)++c[mxlen[i]]; for(int i=1;i<=n;i++)c[i] += c[i - 1]; for(int i=dfn;i;i--)rk[c[mxlen[i]]--] = i; for(int i=dfn;i;i--) { if(t == 1)size[fa[rk[i]]] += size[rk[i]]; else size[fa[rk[i]]] = 1; } size[1] = 0; for(int i=dfn;i;i--) { sum[rk[i]] = size[rk[i]]; for(int j=0;j<26;j++) sum[rk[i]] += sum[tr[rk[i]][j]]; } } void dfs(int x,int k) { if(k <= size[x])return; k -= size[x]; for(int i=0;i<26;i++) { if(!tr[x][i])continue; if(k <= sum[tr[x][i]]) { putchar('a' + i); dfs(tr[x][i],k); return; }k -= sum[tr[x][i]]; } } int main() { #ifdef Ez3real fre(); #endif root = last = ++dfn; scanf("%s",s + 1);n = strlen(s + 1); t = read(),k = read(); for(int i=1;i<=n;i++)extend(s[i] - 'a'); build(); if(k > sum[1])puts("-1"); else dfs(root,k); }
bzoj4566 字符串
多次询问 $s[a,b]$ 的所有子串和 $s[c,d]$ 的所有子串的最长公共前缀的最大值
sol:建反串的后缀自动机,这样最长公共前缀就变成了最长公共后缀,对应的就是 LCA 的 mxlen
我们二分答案 $x$ ,先倍增找到 $d$ 在后缀树上的位置,然后维护一下 $d$ 的 $endpos$ 集合里有没有出现 $[a+x-1,b]$ 这一段子串即可
维护 $endpos$ 集合要线段树合并,然后我们发现好像 $c$ 是打酱油的。。。
要注意,线段树合并要新开一个节点,要不然会挂
#include<bits/stdc++.h> #define LL long long using namespace std; inline int read() { int x = 0,f = 1;char ch = getchar(); for(;!isdigit(ch);ch = getchar())if(ch == '-') f = -f; for(;isdigit(ch);ch = getchar())x = 10 * x + ch - '0'; return x * f; } void fre() { freopen("mydata.in","r",stdin); freopen("mydata.out","w",stdout); } const int maxn = 200010; int n,m; char s[maxn]; int tr[maxn][26],fa[maxn],mxlen[maxn],rt,dfn,last; int mp[maxn],pos[maxn]; void extend(int c) { int p = last,np = last = ++dfn; mxlen[np] = mxlen[p] + 1; for(;p && !tr[p][c];p = fa[p])tr[p][c] = np; if(!p)fa[np] = rt; else { int q = tr[p][c]; if(mxlen[q] == mxlen[p] + 1)fa[np] = q; else { int nq = ++dfn;mxlen[nq] = mxlen[p] + 1; memcpy(tr[nq],tr[q],sizeof(tr[nq])); fa[nq] = fa[q]; fa[q] = fa[np] = nq; for(;p && tr[p][c] == q;p = fa[p])tr[p][c] = nq; } } } int root[maxn << 1],ls[maxn << 6],rs[maxn << 6],val[maxn << 6],ToT; inline void Insert(int &x,int l,int r,int pos) { if(!x) x = ++ToT; if(l == r){val[x]++;return;} int mid = (l + r) >> 1; if(pos <= mid)Insert(ls[x],l,mid,pos); else Insert(rs[x],mid + 1,r,pos); val[x] = val[ls[x]] + val[rs[x]]; } inline int merge(int x,int y) { if(!x || !y)return x + y; val[x] += val[y]; //if(!ls[x] && !rs[x])return x; ls[x] = merge(ls[x],ls[y]); rs[x] = merge(rs[x],rs[y]); return x; /*if(!x || !y)return x + y; int np = ++ToT; ls[np] = merge(ls[y],ls[x]); rs[np] = merge(rs[x],rs[y]); val[np] = val[ls[np]] + val[rs[np]]; return np;*/ } inline int query(int x,int l,int r,int L,int R) { if(L <= l && r <= R)return val[x]; int mid = (l + r) >> 1,ans = 0; if(L <= mid)ans += query(ls[x],l,mid,L,R); if(R > mid)ans += query(rs[x],mid + 1,r,L,R); return ans; } int first[maxn],to[maxn << 1],nx[maxn << 1],cnt; int dep[maxn],anc[maxn][23]; inline void add(int u,int v){to[++cnt] = v;nx[cnt] = first[u];first[u] = cnt;} inline void dfs(int x) { for(int i=1;i<=22;i++) anc[x][i] = anc[anc[x][i - 1]][i - 1]; for(int i=first[x];i;i=nx[i]) { if(to[i] == anc[x][0])continue; anc[to[i]][0] = x;dep[to[i]] = dep[x] + 1; dfs(to[i]); root[x] = merge(root[x],root[to[i]]); } } inline int get_anc(int x,int k) // x zuxian { //if(!k)return 1; for(int i=22;~i;i--) if(mxlen[anc[x][i]] >= k) x = anc[x][i]; return x; } bool chk(int mid,int l,int r,int pos) { if(mid == 0)return 1; if(l > r)return 0; pos = get_anc(pos,mid); return query(root[pos],1,n,l,r); } int main() { #ifdef Ez3real fre(); #endif rt = last = ++dfn; n = read(),m = read(); scanf("%s",s + 1); reverse(s + 1,s + n + 1); for(int i=1;i<=n;i++) { extend(s[i] - 'a'); mp[last] = i; pos[i] = last; } for(int i=1;i<=dfn;i++) if(mp[i])Insert(root[i],1,n,mp[i]); for(int i=1;i<=dfn;i++)add(fa[i],i); dep[rt] = 1;dfs(rt); while(m--) { int a = read(),b = read(),c = read(),d = read(); swap(a,b);swap(c,d);a = n - a + 1,b = n - b + 1,c = n - c + 1,d = n - d + 1; int l = 0,r = min(b - a + 1,d - c + 1),ans = 0; //if(a > b || c > d){puts("0");continue;} while(l <= r) { int mid = (l + r) >> 1; if(chk(mid,a + mid - 1,b,pos[d]))l = mid + 1,ans = max(ans,mid); else r = mid - 1; } printf("%d\n",ans); } }
bzoj3926 诸神眷顾的幻想乡
一棵不超过 19 叉的树,每个点有一个颜色,颜色总共只有 10 种,树的大小一共只有 2000
对于任意两个树上的点 $(a,b)$ 我们称 $str_{(a,b)}$ 为从 $a$ 开始沿简单路径走到 $b$ 途径的每个点的颜色组成的序列
求有多少本质不同的 $str_{(a,b)}$
sol:Trie 树的广义后缀自动机
两点间的有向字符串可以视为以每个叶子节点为根构成的 Trie 树上的某条直链(祖先 -> 儿子)
对于这种问题我们可以建立每个 Trie 树的广义后缀自动机
因为叶子不超过 20 个,暴力即可
最后统计一下这个广义后缀自动机上有多少本质不同的子串,这就是模板题啦
#include<bits/stdc++.h> #define LL long long using namespace std; inline int read() { int x = 0,f = 1;char ch = getchar(); for(;!isdigit(ch);ch = getchar())if(ch == '-') f = -f; for(;isdigit(ch);ch = getchar())x = 10 * x + ch - '0'; return x * f; } const int maxn = 2010000; int n,c;LL ans; int col[maxn]; int first[maxn],to[maxn],nx[maxn],cnt; int ind[maxn]; inline void add(int u,int v) { to[++cnt] = v; nx[cnt] = first[u]; first[u] = cnt; } int last,root,dfn; int fa[maxn],mxlen[maxn],tr[maxn][15]; void extend(int c) { int p = last,np = last = ++dfn; mxlen[np] = mxlen[p] + 1; while(p && !tr[p][c])tr[p][c] = np,p = fa[p]; if(!p)fa[np] = root; else { int q = tr[p][c]; if(mxlen[q] == mxlen[p] + 1)fa[np] = q; else { int nq = ++dfn; mxlen[nq] = mxlen[p] + 1;memcpy(tr[nq],tr[q],sizeof(tr[nq]));fa[nq] = fa[q],fa[np] = fa[q] = nq; while(p && tr[p][c] == q)tr[p][c] = nq,p = fa[p]; } } } void dfs(int x,int fa) { extend(col[x]); int tmp = last; for(int i=first[x];i;i=nx[i]) if(to[i] != fa) { last = tmp; dfs(to[i],x); } } int main() { root = last = ++dfn; n = read(),c = read(); for(int i=1;i<=n;i++)col[i] = read(); for(int i=2;i<=n;i++) { int u = read(),v = read(); add(u,v);add(v,u); ind[u]++;ind[v]++; } for(int i=1;i<=n;i++) if(ind[i] == 1) { last = 1; dfs(i,0); } for(int i=1;i<=dfn;i++)ans += (mxlen[i] - mxlen[fa[i]]); cout<<ans; }
loj6401 字符串
有一个字符串 $S$,每个位置可能是好的或者坏的,定义一个子串是好的,当且仅当它包含了不超过 $k$ 个坏的位置,求有多少本质不同的好的子串
$|S| \leq 100000$
sol:每个子串找出最长的合法后缀,沿 parent 更新
#include<bits/stdc++.h> using namespace std; const int N=100010; int K; char s[N],b[N]; struct Suffix_Automaton{ static const int M=N<<1; int son[M][26],par[M],Mxlen[M],SAM_cnt,Deg[M],Q[M],limit[M]; int Extend(int p,int c){ int q=++SAM_cnt; Mxlen[q]=Mxlen[p]+1; while (p>0 && son[p][c]==0){ son[p][c]=q; p=par[p]; } if (p==0){ par[q]=1; }else{ int r=son[p][c]; if (Mxlen[r]==Mxlen[p]+1){ par[q]=r; }else{ int o=++SAM_cnt; par[o]=par[r]; par[q]=par[r]=o; Mxlen[o]=Mxlen[p]+1; memcpy(son[o],son[r],sizeof(son[o])); while (p>0 && son[p][c]==r){ son[p][c]=o; p=par[p]; } } } return q; } void build(){ int i,len=strlen(s),p=SAM_cnt=1,l=0,cnt=0; for (i=0;i<len;i++){ p=Extend(p,s[i]-'a'); cnt+=(b[i]=='0'); while (cnt>K){ cnt-=(b[l]=='0'); l++; } limit[p]=i-l+1; } for (i=2;i<=SAM_cnt;i++){ Deg[par[i]]++; } int L=1,R=0; for (i=1;i<=SAM_cnt;i++){ if (Deg[i]==0){ Q[++R]=i; } } long long Ans=0; while (L<=R){ int x=Q[L++],t=par[x]; Ans+=max(0,min(limit[x],Mxlen[x])-Mxlen[t]); if (t!=0){ limit[t]=max(limit[t],limit[x]); Deg[t]--; if (Deg[t]==0){ Q[++R]=t; } } } printf("%lld\n",Ans); } }SAM; int main(){ scanf("%s%s%d",s,b,&K); SAM.build(); return 0; }
loj6041 事情的相似度
一个 01 串,多次询问一段区间内的前缀的最长公共后缀
$n,q \leq 100000$
sol:
实质上是要求区间内两两 LCA 深度的最大值
离线,按右端点排序
每加入一个字母就在这个字母到根的路径上打标记
查询的时候沿查询节点往根跑,如果跑到一个有旧标记的点,则该点为旧标记和新标记的 LCA
每次贪心地更新更大的标记
用一个以询问左端点为下标的树状数组统计答案
然后发现从一个地方走到根这个事情复杂度不是很显然
写一个 LCT ,access 即可
#include<bits/stdc++.h> #define LL long long using namespace std; inline int read() { int x = 0,f = 1;char ch = getchar(); for(;!isdigit(ch);ch = getchar())if(ch == '-') f = -f; for(;isdigit(ch);ch = getchar())x = 10 * x + ch - '0'; return x * f; } const int maxn = 200010; int n,q,l[maxn],reh[maxn]; char s[maxn]; int root,dfn,last; int tr[maxn][2],par[maxn],mxlen[maxn]; void extend(int c,int id) { int p = last,np = last = ++dfn; reh[id] = np; mxlen[np] = mxlen[p] + 1; for(;p && !tr[p][c];p = par[p])tr[p][c] = np; if(!p)par[np] = root; else { int q = tr[p][c]; if(mxlen[q] == mxlen[p] + 1)par[np] = q; else { int nq = ++dfn; mxlen[nq] = mxlen[p] + 1; memcpy(tr[nq],tr[q],sizeof(tr[nq])); par[nq] = par[q]; par[q] = par[np] = nq; for(;p && tr[p][c] == q;p = par[p])tr[p][c] = nq; } } } vector<int> qs[maxn]; int c[maxn]; inline int lowbit(int x){return x & (-x);} inline int ask(int x){x = n - x + 1;int res = 0;for(;x;x -= lowbit(x))res = max(res,c[x]);return res;} inline void add(int x,int val){x = n - x + 1;for(;x <= n;x += lowbit(x))c[x] = max(c[x],val);} #define ls ch[x][0] #define rs ch[x][1] int ch[maxn][2],fa[maxn],val[maxn],tag[maxn],rev[maxn],st[maxn],top; inline void pushdown(int x) { if(!tag[x])return; if(ls)val[ls] = tag[ls] = tag[x]; if(rs)val[rs] = tag[rs] = tag[x]; tag[x] = 0; } inline int isroot(int x){return (ch[fa[x]][0] != x) && (ch[fa[x]][1] != x);} inline void rotate(int x) { int y = fa[x],z = fa[y]; int l = (ch[y][1] == x),r = l ^ 1; if(!isroot(y))ch[z][ch[z][1] == y] = x; fa[ch[x][r]] = y;fa[x] = z;fa[y] = x; ch[y][l] = ch[x][r];ch[x][r] = y; //pushup(y);pushup(x); } inline void splay(int x) { st[top = 1] = x; for(int i=x;!isroot(i);i=fa[i])st[++top] = fa[i]; for(int i=top;i;i--)pushdown(st[i]); while(!isroot(x)) { int y = fa[x],z = fa[y]; if(!isroot(y)) { if(ch[z][1] == y ^ ch[y][1] == x)rotate(x); else rotate(y); } rotate(x); } } inline void access(int x,int v) { int y; for(y=0;x;y = x,x = fa[x])splay(x),add(val[x],mxlen[x]),rs = y; tag[y] = val[y] = v; } int ans[maxn]; int main() { last = root = ++dfn; n = read(),q = read(); scanf("%s",s + 1); for(int i=1;i<=q;i++) { l[i] = read();int r = read(); qs[r].push_back(i); } for(int i=1;i<=n;i++)extend(s[i] - '0',i); for(int i=1;i<=dfn;i++)fa[i] = par[i]; for(int i=1;i<=n;i++) { access(reh[i],i); int m = qs[i].size(); for(int j=0;j<m;j++) { int now = qs[i][j]; ans[now] = ask(l[now]); } } for(int i=1;i<=q;i++)printf("%d\n",ans[i]); }
upd:有一个好东西叫做序列自动机,也就是兹磁识别一个串的所有子序列的自动机
具体实现的话很简单,用一个数组 $next_{(i,j)}$ 记录第 $i$ 位后出现的第一个字符 $j$ 出现的位置即可
这样很多在串上的题可以强行上序列。。。
luogu P4608 所有公共子序列问题
求两个串所有公共子序列的个数,位置不同算不同
sol:建一个序列自动机然后暴力 dp 即可
#include<bits/stdc++.h> #define MAXN 3002 #define BASE 1e9 using namespace std; inline int get(char s){return s<='Z'? s-'A':s-'a'+26;} inline char code(int x){return x<26? 'A'+x:'a'+x-26;} struct SEGAM { int t[MAXN][52],S,tot,f[52],last[MAXN]; SEGAM() { S=tot=1; for(int i=0;i<52;i++)f[i]=S; } void Insert(int x) { last[++tot]=f[x]; int i,j; for(i=0;i<52;i++) { for(j=f[i];j&&!t[j][x];j=last[j])t[j][x]=tot; } f[x]=tot; } }A,B; struct BigNum { int a[20],n; void print() { printf("%d",a[n-1]); for(int i=n-2;i>=0;i--)printf("%09d",a[i]); } }dp[MAXN][MAXN],one,zero; inline BigNum operator + (BigNum x,BigNum y) { x.n=max(x.n,y.n); for(int i=0;i<x.n;i++) { x.a[i]+=y.a[i]; if(x.a[i]>=BASE)x.a[i+1]++,x.a[i]-=BASE; } if(x.a[x.n])x.n++; return x; } bitset<MAXN> vis[MAXN]; BigNum Solve1(int a,int b) { if(!a||!b)return zero; if(vis[a][b])return dp[a][b]; vis[a][b]=1;dp[a][b]=one; for(int i=0;i<52;i++)dp[a][b]=dp[a][b]+Solve1(A.t[a][i],B.t[b][i]); return dp[a][b]; } char s[MAXN]; void Solve2(int a,int b,int n) { if(!a||!b)return; s[n]=0;printf("%s\n",s); for(int i=0;i<52;i++) { s[n]=code(i); Solve2(A.t[a][i],B.t[b][i],n+1); } } int N,M,K; char sa[MAXN],sb[MAXN]; int main() { scanf("%d%d%s%s%d",&N,&M,sa,sb,&K); one.n=one.a[0]=zero.n=1; for(int i=0;i<N;i++)A.Insert(get(sa[i])); for(int i=0;i<M;i++)B.Insert(get(sb[i])); if(K)Solve2(1,1,0); Solve1(1,1);dp[1][1].print(); return 0; }