[后缀自动机初探]
具体的定义、构造方法及基本应用见 2012 年冬令营陈老师的课件(ppt)。
这篇博文的题目对于刚刚接触的同学有可能偏难,建议可以用后缀自动机做一下以前做过的后缀数组的题目。不过题目都是很好的!
[POJ 2774]Long Long Message
后缀自动机的模式匹配。
失配时像 KMP 的失配指针一样,沿 parent(link)指针往上跳,直到找到有对应转移的状态(跳到根之外则匹配长度清零)。
// [POJ 2774] Long Long Message -- length of the longest common substring of two strings.
// Build a suffix automaton (SAM) of the first string, then run the second string through
// it. On a missing transition, fall back along suffix links (like the KMP failure
// function) until a state with the needed transition is found.
#include <algorithm>
#include <iostream>
#include <cstring>
#include <cstdio>
#define maxn 100010
using namespace std;

char str[maxn];

// One SAM state.
//   len  : length of the longest string in the state
//   link : suffix link (-1 for the root)
//   nxt  : transitions; 0 means "absent" (state 0 is the root, which is never the
//          target of a transition, so 0 can serve as the sentinel)
struct Node {
    int len, link, nxt[26];
} st[maxn << 1];

// Index of the last allocated state. Formerly named `size`, which is ambiguous with
// std::size under `using namespace std;` in C++17 and fails to compile.
int root, sam_size, last;

// Reset to an automaton holding only the root (the empty string).
void init() {
    root = sam_size = last = 0;
    st[root].len = 0;
    st[root].link = -1;
}

// Append character c (0..25) -- standard online SAM construction.
void Extend(int c) {
    int p = last, cur = ++sam_size;
    st[cur].len = st[p].len + 1;
    for (; ~p && st[p].nxt[c] == 0; p = st[p].link) st[p].nxt[c] = cur;
    if (p == -1) {
        st[cur].link = root;
    } else {
        int q = st[p].nxt[c];
        if (st[q].len == st[p].len + 1) {
            st[cur].link = q;
        } else {
            // Split q: the clone keeps q's transitions but takes the shorter length.
            int clone = ++sam_size;
            st[clone] = st[q];
            st[clone].len = st[p].len + 1;
            for (; ~p && st[p].nxt[c] == q; p = st[p].link) st[p].nxt[c] = clone;
            st[q].link = st[cur].link = clone;
        }
    }
    last = cur;
}

int main() {
    init();
    scanf("%s", str + 1);
    int n = strlen(str + 1);
    for (int i = 1; i <= n; i++) Extend(str[i] - 'a');

    // The buffer is reused for the second string.
    scanf("%s", str + 1);
    n = strlen(str + 1);
    // nw: current state; cur: length of the longest match ending at position i.
    int nw = root, cur = 0, ans = 0;
    for (int i = 1; i <= n; i++) {
        int c = str[i] - 'a';
        if (st[nw].nxt[c]) {
            cur++, nw = st[nw].nxt[c];
        } else {
            while (~nw && st[nw].nxt[c] == 0) nw = st[nw].link;
            if (nw == -1) nw = root, cur = 0;
            else cur = st[nw].len + 1, nw = st[nw].nxt[c];
        }
        ans = max(ans, cur);
    }
    printf("%d\n", ans);
    return 0;
}
[BZOJ 3238][AHOI 2013]差异
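给定长度为 $n$ 的字符串 $S$,记 $T_i$ 为从第 $i$ 个字符开始的后缀,求 $\sum_{1 \le i < j \le n} \left( \mathrm{len}(T_i) + \mathrm{len}(T_j) - 2\,\mathrm{lcp}(T_i, T_j) \right)$。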
差异这道题目给初学 SAM 的我很大启发。
求两个后缀的 LCP 的方法:
将原串逆序插入后缀自动机,其 parent 树即为原串的后缀树。
将两个点的LCA求出,LCA对应的len值即为LCP的长度。
// [BZOJ 3238][AHOI 2013] 差异 (Difference)
// Answer = sum_{i<j} (len(T_i) + len(T_j)) - 2 * sum_{i<j} lcp(T_i, T_j), where T_i is the
// suffix starting at position i. Inserting the string in reverse into a SAM makes its
// suffix-link (parent) tree the suffix tree of the original string; the LCP of two
// suffixes is the len of the LCA of their states.
#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#define maxn 1000010
using namespace std;
int n;
typedef long long ll;
char s[maxn];
struct Node {
    int len, link, nxt[26];
} st[maxn];
// Index of the last allocated state (renamed from `size`, which clashes with std::size
// under `using namespace std;` in C++17).
int root, sam_size, last;

void init() {
    root = sam_size = last = 0;
    st[root].link = -1;
}

// Append character ch ('a'..'z') -- standard online SAM construction.
void Extend(char ch) {
    int p = last, cur = ++sam_size, c = ch - 'a';
    st[cur].len = st[last].len + 1;
    for (; ~p && st[p].nxt[c] == 0; p = st[p].link) st[p].nxt[c] = cur;
    if (p == -1) {
        st[cur].link = root;
    } else {
        int q = st[p].nxt[c];
        if (st[p].len + 1 == st[q].len) {
            st[cur].link = q;
        } else {
            int clone = ++sam_size;
            st[clone] = st[q];
            st[clone].len = st[p].len + 1;
            for (; ~p && st[p].nxt[c] == q; p = st[p].link) st[p].nxt[c] = clone;
            st[q].link = st[cur].link = clone;
        }
    }
    last = cur;
}

long long ans;
int mark[maxn];              // number of suffixes of the original string ending at each state
int bucket[maxn], ord[maxn]; // counting sort of states by len
ll dp[maxn];                 // dp[u] = marked suffixes in u's parent-tree subtree
ll sq[maxn];                 // sq[u] = mark[u]^2 + sum over children v of dp[v]^2
ll ret;

// Returns 2 * sum_{i<j} lcp(T_i, T_j), i.e. the lcp summed over ordered pairs.
// This used to be a recursive DFS over the parent tree, whose depth can reach n
// (5*10^5 for "aaaa...") and overflow the stack. A child always has a strictly larger
// len than its parent, so visiting states in decreasing len order handles every child
// before its parent and yields the same sums iteratively.
long long solve() {
    int now = root;
    for (int i = n; i >= 1; i--) {
        now = st[now].nxt[s[i] - 'a'];
        mark[now]++;
    }

    for (int i = 0; i <= sam_size; i++) bucket[st[i].len]++;
    for (int i = 1; i <= n; i++) bucket[i] += bucket[i - 1];
    for (int i = sam_size; i >= 0; i--) ord[--bucket[st[i].len]] = i;

    for (int i = 0; i <= sam_size; i++) {
        dp[i] = mark[i];
        sq[i] = dp[i] * dp[i];
    }
    ret = 0;
    for (int k = sam_size; k >= 0; k--) {
        int u = ord[k];
        // Ordered pairs of marked suffixes whose LCA is exactly u, each with lcp = len[u].
        ret += (dp[u] * dp[u] - sq[u]) * st[u].len;
        int f = st[u].link;
        if (f != -1) {
            dp[f] += dp[u];
            sq[f] += dp[u] * dp[u];
        }
    }
    return ret;
}

ll p[maxn]; // p[i] = 1 + 2 + ... + i

int main() {
    scanf("%s", s + 1);
    n = strlen(s + 1);
    init();
    for (int i = n; i >= 1; i--) Extend(s[i]);
    for (int i = 1; i <= n; i++) p[i] = p[i - 1] + i;
    // Suffix lengths are 1..n; the pairs with larger length i contribute
    // sum_{j<i} (i + j) = i*(i-1) + (1 + ... + (i-1)).
    for (int i = n; i >= 1; i--) ans += 1ll * i * (i - 1) + p[i - 1];
    printf("%lld\n", ans - solve());
    return 0;
}
[BZOJ 3676][APIO 2014]回文串
考虑一个只包含小写拉丁字母的字符串s。我们定义s的一个子串t的“出现值”为t在s中的出现次数乘以t的长度。请你求出s的所有回文子串中的最大出现值。
做法:用 Manacher 枚举所有本质不同的回文子串(至多 n 个),再在后缀自动机的 parent 树上倍增,找到该回文串所在的状态并取其出现次数(SAM 上的倍增是一种常用的技巧)。当然,也可以直接用回文自动机(PAM)。
// [BZOJ 3676][APIO 2014] 回文串 (Palindromes)
// Manacher enumerates every distinct palindrome (only palindromes that extend past the
// current rightmost boundary are new). For each one, binary lifting on the SAM's
// suffix-link tree locates the state containing it; that state's occurrence count times
// the length is a candidate answer.
#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#define maxn 600010
using namespace std;
char s[maxn];
int n;
long long ans;
// occ: number of occurrences (|Right|) of the strings in the state, filled in by build().
struct Node {
    int len, link, nxt[26], occ;
} st[maxn];
// Index of the last allocated state (renamed from `size`, which clashes with std::size
// under `using namespace std;` in C++17).
int root, sam_size, last;

void init() {
    root = sam_size = last = 0;
    st[root].len = 0;
    st[root].link = -1;
}

int anc[maxn][20], pos[maxn]; // anc: binary-lifting table; pos[i]: state of prefix s[1..i]

// Append ch as the part-th character; the state of prefix part is remembered in pos[].
void Extend(char ch, int part) {
    int p, cur = ++sam_size, c = ch - 'a';
    st[cur].len = st[last].len + 1;
    st[cur].occ = 1; // a new (non-clone) state is one occurrence end position
    pos[part] = cur;
    for (p = last; ~p && !st[p].nxt[c]; p = st[p].link) st[p].nxt[c] = cur;
    if (p == -1) {
        st[cur].link = root;
    } else {
        int q = st[p].nxt[c];
        if (st[q].len == st[p].len + 1) {
            st[cur].link = q;
        } else {
            int clone = ++sam_size;
            st[clone] = st[q];
            st[clone].len = st[p].len + 1;
            st[clone].occ = 0; // clones contribute no end position of their own
            for (; ~p && st[p].nxt[c] == q; p = st[p].link) st[p].nxt[c] = clone;
            st[q].link = st[cur].link = clone;
        }
    }
    last = cur;
}

int t[maxn], w[maxn];

// Build the lifting table and accumulate occ over the suffix-link tree
// (counting sort by len, then children before parents).
void build() {
    memset(anc, -1, sizeof anc);
    for (int i = 1; i <= sam_size; i++) anc[i][0] = st[i].link;
    for (int j = 1; 1 << j <= sam_size; j++)
        for (int i = 1; i <= sam_size; i++) {
            int a = anc[i][j - 1];
            if (~a) anc[i][j] = anc[a][j - 1];
        }
    for (int i = 1; i <= sam_size; i++) w[st[i].len]++;
    for (int i = 1; i <= sam_size; i++) w[i] += w[i - 1];
    for (int i = sam_size; i >= 1; i--) t[w[st[i].len]--] = i;
    for (int i = sam_size; i; i--) st[st[t[i]].link].occ += st[t[i]].occ;
}

// s[l..r] is a palindrome: start at the state of prefix s[1..r] and climb to the highest
// ancestor whose len is still >= r-l+1; that state contains s[l..r].
void update(int l, int r) {
    int t = pos[r];
    // Start from the top column (19); entries never filled in stay -1 and are skipped.
    for (int i = 19; i >= 0; i--) {
        if (~anc[t][i]) {
            int to = anc[t][i];
            if (st[to].len >= r - l + 1) t = to;
        }
    }
    ans = max(ans, 1ll * st[t].occ * (r - l + 1));
}

int r[maxn];

void solve() {
    s[0] = '*';     // sentinels that never match a letter
    s[n + 1] = '#';
    init();
    for (int i = 1; i <= n; i++) Extend(s[i], i);
    build();

    // Even-length palindromes: r[i] is the half length of s[i-r[i]+1 .. i+r[i]].
    int mx = 0, p = 0;
    for (int i = 1; i <= n; i++) {
        if (i < mx) r[i] = min(r[2 * p - i - 1], mx - i);
        else r[i] = 0;
        while (s[i + r[i] + 1] == s[i - r[i]]) {
            r[i]++;
            update(i - r[i] + 1, i + r[i]);
        }
        if (r[i] + i > mx) mx = r[i] + i, p = i;
    }

    // Odd-length palindromes: r[i] is the radius of s[i-r[i]+1 .. i+r[i]-1].
    mx = 0, p = 0;
    for (int i = 1; i <= n; i++) {
        if (i < mx) {
            r[i] = min(r[2 * p - i], mx - i - 1);
        } else {
            r[i] = 1;
            update(i, i);
        }
        while (s[i + r[i]] == s[i - r[i]]) {
            r[i]++;
            update(i - r[i] + 1, i + r[i] - 1);
        }
        if (r[i] + i > mx) mx = r[i] + i, p = i;
    }
    printf("%lld\n", ans);
}

int main() {
    scanf("%s", s + 1);
    n = strlen(s + 1);
    solve();
    return 0;
}
其实还有一道题目相关--HEOI2015最短不公共子串
[BZOJ 3998]弦论
对于一个给定长度为 N 的字符串,求它的第 K 小子串是什么。T=0 表示不同位置的相同子串只算一个,T=1 表示不同位置的相同子串算作多个。
第一行是一个仅由小写英文字母构成的字符串S
后缀自动机(大概是子串计数一道很好的题目)
SAM 上一个节点代表所有从根出发走到它的路径所对应的字符串,每条这样的路径对应一个不同的子串。
所以给每个节点赋权值 1 再做路径计数,得到的就是本质不同的子串个数。
right 集合的大小:parent 树上该节点子树中非 clone 节点(即原串各前缀对应的节点)的个数。
然后我们要把right集合累加起来当做位置不同的子串算多个的个数。
然后 DFS:从根开始按字母顺序逐位确定答案(相当于在 26 叉树上“二分”)。
后缀自动机中一个状态的 right 集合大小,等于它在 parent 树中子树内非 clone 节点(即原串各前缀对应的节点)的个数,代表这个状态所代表的子串的出现次数。
// [BZOJ 3998][TJOI 2015] 弦论 -- k-th smallest substring.
// x (= T) == 0: equal substrings at different positions count once;
// x == 1: they count once per occurrence.
// val[u] = how many times the strings ending at state u are counted;
// sum[u] = total count of all strings reachable from u (val[u] plus the children's sums).
#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#define maxn 1000010
using namespace std;
char c[maxn];
struct Node {
    int len, link, nxt[26];
} st[maxn];
// The root is state 1, so state 0 is an always-empty sentinel (sum[0] == 0).
// sam_size was formerly `size`, which clashes with std::size under
// `using namespace std;` in C++17.
int root, last, sam_size;

void init() {
    root = last = sam_size = 1;
    st[root].len = 0;
    st[root].link = -1;
}

int t[maxn], w[maxn], val[maxn];
// long long: with n = 5*10^5 there can be about n(n+1)/2 = 1.25*10^11 substrings,
// which overflowed the former `int` array.
long long sum[maxn];

void Extend(char ch) {
    int p, cur = ++sam_size, c = ch - 'a';
    st[cur].len = st[last].len + 1;
    val[cur] = 1; // non-clone state: one end position
    for (p = last; ~p && !st[p].nxt[c]; p = st[p].link) st[p].nxt[c] = cur;
    if (p == -1) {
        st[cur].link = root;
    } else {
        int q = st[p].nxt[c];
        if (st[q].len == st[p].len + 1) {
            st[cur].link = q;
        } else {
            int clone = ++sam_size;
            st[clone] = st[q];
            st[clone].len = st[p].len + 1;
            for (; ~p && st[p].nxt[c] == q; p = st[p].link) st[p].nxt[c] = clone;
            st[q].link = st[cur].link = clone;
        }
    }
    last = cur;
}

int x, k;

void prepare() {
    // Counting sort of all states (root included, at t[1]) by len.
    for (int i = 1; i <= sam_size; i++) w[st[i].len]++;
    for (int i = 1; i <= sam_size; i++) w[i] += w[i - 1];
    for (int i = sam_size; i >= 1; i--) t[w[st[i].len]--] = i;
    for (int i = sam_size; i >= 1; i--) {
        int now = t[i];
        if (x == 1) {
            // The root has link -1; the old code wrote val[-1] here (out of bounds).
            if (st[now].link != -1) val[st[now].link] += val[now];
        } else {
            val[now] = 1;
        }
    }
    val[root] = 0; // the empty string is not counted
    for (int i = sam_size; i; i--) {
        int now = t[i];
        sum[now] = val[now];
        for (int j = 0; j < 26; j++) sum[now] += sum[st[now].nxt[j]];
    }
}

// Print the k-th string starting from state u (iterative; the recursion it replaces was
// as deep as the answer, up to n). Precondition: k <= sum[u].
void dfs(int u, int k) {
    while (k > val[u]) {
        k -= val[u];
        int to = 0;
        for (int i = 0; i < 26 && !to; i++) {
            int v = st[u].nxt[i];
            if (!v) continue;
            if (k <= sum[v]) {
                putchar(i + 'a');
                to = v;
            } else {
                k -= sum[v];
            }
        }
        if (!to) return; // unreachable while k <= sum[u]; matches the old fall-through
        u = to;
    }
}

int main() {
    init();
    scanf("%s", c + 1);
    int n = strlen(c + 1);
    for (int i = 1; i <= n; i++) Extend(c[i]);
    scanf("%d%d", &x, &k);
    prepare();
    if (sum[root] < k) printf("-1");
    else dfs(root, k);
    return 0;
}
[NOI 2015]品酒大会
给定一个字符串,求出这个字符串中所有长度为i(0<=i<n)两两相等的子串个数和给定value[p]*value[q]的最大值(p,q为左端点)
考虑后缀自动机。
如果将所有节点(包括clone节点)的路径数都赋为1。做路径计数的话应该是所有本质不同的子串的计数。
如果只将原字符串遍历到的节点(不包括clone节点)的值赋为1,做路径计数的话应该是所有节点的right集合。
根据[TJOI 弦论]如果把right集合累加起来,就可以得知子串相同但是位置不同算多个的子串的个数
right集合代表的是什么?
其大小等于 parent 树上这个点的子树中非 clone 节点的个数,也就是这个状态的子串在原串中出现的次数。
这道题目:选择两个子串,它们的LCP等于parent树上的LCA的len值。
我们要统计的是 right 集合的大小(代表这个状态的字符串出现的次数),也就是这棵子树中非 clone 节点的数目。
关于Right集合:
定义:一个子串 str 在母串 S 中所有出现位置的右端点。如子串 str 在 S 中的出现位置为 [l1, r1), [l2, r2), ..., [ln, rn),则 str 的 Right 集合为 {r1, ..., rn}。会有一些子串的 Right 集合相同(它们属于同一个状态 s),其中最长的长度记为 MAX(s),最短的记为 MIN(s)。
// [NOI 2015] 品酒大会
// Build a SAM of the reversed string, so its suffix-link tree is the suffix tree of the
// original string. Two suffixes share a prefix of length r exactly when their LCA has
// len >= r. For every state u we gather, over its subtree, the number of suffixes (siz),
// the two largest values (mx) and the two smallest values (mn, for negative products);
// s[len] counts the pairs whose LCA has that len.
#include <bits/stdc++.h>
#define maxn 600010
using namespace std;
int n;
char str[maxn];
typedef long long ll;
ll ans[maxn], mx[maxn][2], mn[maxn][2], s[maxn], siz[maxn];
int val[maxn];
struct Node {
    int nxt[26], len, link;
} st[maxn];
// sam_size was formerly `size`, which clashes with std::size under
// `using namespace std;` in C++17.
int last, sam_size, root;

void init() {
    root = last = sam_size = 0;
    st[root].link = -1;
    st[root].len = 0;
}

// Append ch; Id is the position in the original string whose suffix this state ends.
void Extend(char ch, int Id) {
    int cur = ++sam_size, p = last, c = ch - 'a';
    st[cur].len = st[p].len + 1;
    for (; ~p && st[p].nxt[c] == 0; p = st[p].link) st[p].nxt[c] = cur;
    if (p == -1) {
        st[cur].link = root;
    } else {
        int q = st[p].nxt[c];
        if (st[q].len == st[p].len + 1) {
            st[cur].link = q;
        } else {
            int clone = ++sam_size;
            st[clone] = st[q];
            st[clone].len = st[p].len + 1;
            for (; ~p && st[p].nxt[c] == q; p = st[p].link) st[p].nxt[c] = clone;
            st[q].link = st[cur].link = clone;
        }
    }
    last = cur;
    mx[cur][0] = mn[cur][0] = val[Id];
    siz[cur] = 1;
}

inline void upd(ll& a, ll b) {
    if (b > a) a = b;
}
const ll inf = 1e9 + 1;

// Merge child v's subtree statistics into its parent u.
void merge_into(int u, int v) {
    s[st[u].len] += siz[u] * siz[v];
    siz[u] += siz[v];
    if (mx[v][0] >= mx[u][0]) {
        mx[u][1] = mx[u][0];
        mx[u][0] = mx[v][0];
    } else {
        mx[u][1] = max(mx[u][1], mx[v][0]);
    }
    if (mn[v][0] <= mn[u][0]) {
        mn[u][1] = mn[u][0];
        mn[u][0] = mn[v][0];
    } else {
        mn[u][1] = min(mn[u][1], mn[v][0]);
    }
    mx[u][1] = max(mx[u][1], mx[v][1]);
    mn[u][1] = min(mn[u][1], mn[v][1]);
}

// Once u's whole subtree is merged, record the best product at len[u].
// (Previously the first check tested mx[u][1] twice; it is meant to guard both maxima.)
void settle(int u) {
    int len = st[u].len;
    if (mx[u][0] > -inf && mx[u][1] > -inf) upd(ans[len], mx[u][0] * mx[u][1]);
    if (mn[u][1] < inf && mn[u][0] < inf) upd(ans[len], mn[u][0] * mn[u][1]);
}

int bucket[maxn], ord[maxn];

// Replaces the recursive tree DFS, which could be n = 3*10^5 frames deep. A child's len
// is strictly greater than its parent's, so processing states in decreasing len order
// merges every child before its parent is settled.
void collect() {
    for (int i = 0; i <= sam_size; i++) bucket[st[i].len]++;
    for (int i = 1; i <= n; i++) bucket[i] += bucket[i - 1];
    for (int i = sam_size; i >= 0; i--) ord[--bucket[st[i].len]] = i;
    for (int k = sam_size; k >= 0; k--) {
        int u = ord[k];
        settle(u);
        if (st[u].link != -1) merge_into(st[u].link, u);
    }
}

int main() {
    init();
    scanf("%d", &n);
    scanf("%s", str + 1);
    for (int i = 1; i <= n; i++) scanf("%d", &val[i]);
    for (int i = 0; i <= 2 * n; i++) {
        ans[i] = -(1ll << 61); // `-1ll << 61` left-shifted a negative value (UB before C++20)
        mx[i][0] = mx[i][1] = -0x7fffffff;
        mn[i][0] = mn[i][1] = 0x7fffffff;
    }
    for (int i = n; i >= 1; i--) Extend(str[i], i);
    collect();
    // A pair that is r-similar is also r'-similar for every r' < r: take suffix max / sum.
    for (int i = n - 2; i >= 0; i--) ans[i] = max(ans[i + 1], ans[i]), s[i] += s[i + 1];
    for (int i = 0; i < n; i++) {
        if (s[i] == 0) puts("0 0");
        else printf("%lld %lld\n", s[i], ans[i]);
    }
    return 0;
}
[BZOJ 4310]跳蚤
题目请点上面的链接
详细题解在这里
如何求字典序第 k 小的子串?
从根开始按字母顺序逐位确定:
依次枚举当前节点的儿子,如果 k 大于从这个儿子出发的子串个数,就减掉并继续看下一个儿子;
否则k-=当前的字符串值(如果是本质不同的字符串减1,否则减掉当前状态代表的值)然后转移now即可
当k=0时停止。
// [BZOJ 4310] 跳蚤
// Binary search on the rank of the answer among all distinct substrings. For a candidate,
// greedily cut the string from the right into the fewest pieces whose substrings are all
// <= the candidate, and check that at most k pieces are needed.
#include <bits/stdc++.h>
#define maxn 200010
using namespace std;
int n, k;
struct Node {
    int len, link, nxt[26];
} st[maxn];
long long s[maxn]; // s[u] = number of paths starting at u, including the empty path
char str[maxn];
// sam_size was formerly `size`, which clashes with std::size under
// `using namespace std;` in C++17.
int root, sam_size, last;

void init() {
    root = last = sam_size = 0;
    st[root].link = -1;
    st[root].len = 0;
}

void Extend(char ch) {
    int cur = ++sam_size, p = last, c = ch - 'a';
    st[cur].len = st[p].len + 1;
    for (; ~p && st[p].nxt[c] == 0; p = st[p].link) st[p].nxt[c] = cur;
    if (p == -1) {
        st[cur].link = root;
    } else {
        int q = st[p].nxt[c];
        if (st[q].len == st[p].len + 1) {
            st[cur].link = q;
        } else {
            int clone = ++sam_size;
            st[clone] = st[q];
            st[clone].len = st[p].len + 1;
            for (; ~p && st[p].nxt[c] == q; p = st[p].link) st[p].nxt[c] = clone;
            st[q].link = st[cur].link = clone;
        }
    }
    last = cur;
}

bool vis[maxn];

// Memoized path count over the automaton DAG.
long long DFS(int u) {
    if (vis[u]) return s[u];
    s[u] = vis[u] = 1;
    for (int i = 0; i < 26; i++)
        if (st[u].nxt[i]) s[u] += DFS(st[u].nxt[i]);
    return s[u];
}

int m;
char ans[maxn];

// Store the k-th smallest distinct substring in ans[1..m].
void Getstring(long long k) {
    int now = root, t;
    m = 0;
    while (true) {
        for (int i = 0; i < 26; i++) {
            if ((t = st[now].nxt[i]) == 0) continue;
            if (k > s[t]) {
                k -= s[t];
            } else {
                now = t, k--; // the string ending here is itself one of the s[t] strings
                ans[++m] = i + 'a';
                if (k == 0) return;
                break;
            }
        }
    }
}

unsigned long long bases[maxn], hash1[maxn], hash2[maxn];
#define base 13131

// Does ans[1..len] equal str[i..i+len-1]? (polynomial hashes of suffixes)
bool pd(int i, int len) {
    if (len == 0) return true;
    return hash1[1] - hash1[1 + len] * bases[len] == hash2[i] - hash2[i + len] * bases[len];
}

// Is str[i..j] <= ans[1..m] lexicographically? Binary search finds the longest common
// prefix, then the first differing character (or the lengths) decide.
bool cmp(int i, int j) {
    if (str[i] < ans[1]) return true;
    if (str[i] > ans[1]) return false;
    int l = 1, r = min(m, j - i + 1);
    while (l < r) {
        int mid = l + (r - l + 1) / 2;
        if (pd(i, mid - 1)) {
            if (str[i + mid - 1] < ans[mid]) return true;
            if (str[i + mid - 1] > ans[mid]) return false;
            l = mid;
        } else {
            r = mid - 1;
        }
    }
    if (str[i + r - 1] < ans[r]) return true;
    if (str[i + r - 1] > ans[r]) return false;
    if (j - i + 1 > m) return false;
    return true;
}

// Can the string be cut into at most k pieces whose substrings are all <= the where-th
// smallest substring?
bool check(long long where) {
    Getstring(where);
    hash1[m + 1] = 0;
    for (int i = m; i >= 1; i--) hash1[i] = hash1[i + 1] * base + ans[i];
    int pos = n, cnt = 0;
    for (int i = n; i; i = pos) {
        while (pos && cmp(pos, i)) pos--;
        cnt++;
        if (cnt > k || pos == i) return false;
    }
    return true;
}

int main() {
    scanf("%d%s", &k, str + 1);
    init();
    n = strlen(str + 1);
    for (int i = 1; i <= n; i++) Extend(str[i]);
    DFS(root);
    bases[0] = 1;
    for (int i = n; i >= 1; i--) hash2[i] = hash2[i + 1] * base + str[i];
    for (int i = 1; i <= n; i++) bases[i] = bases[i - 1] * base;
    long long l = 1, r = s[root];
    while (l < r) {
        long long mid = l + r >> 1;
        if (check(mid)) r = mid;
        else l = mid + 1;
    }
    Getstring(r);
    for (int i = 1; i <= m; i++) putchar(ans[i]);
    return 0;
}
[HAOI 2016]找相同字符
我写了一个非常鬼(ma)畜(fan)的做法。。
建立一个a串的后缀自动机,把b串放上去跑,在跑到的节点上累加答案
注意到后缀自动机一个节点包括的状态有很多(根本没注意)
状态有(len - fa_len)这么多种
而且b串跑的长度并不是当前状态的len
额外记录一下
其实建个广义后缀自动机直接统计就好了啊喂
// [HAOI 2016] 找相同字符 -- number of pairs of equal substrings (one from a, one from b).
// Build a SAM of a. For each state u, sz[u] is its occurrence count in a, and s[u] is the
// sum, over u and its suffix-link ancestors, of occ * (len - parent len). Then run b
// through the automaton; for the current state cur and match length step, the matches
// ending here number s[link(cur)] + sz[cur] * (step - len[link(cur)]).
#define MAXN 500000
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
int n1, n2;
char a[MAXN], b[MAXN];
// sam_size was formerly `size`, which clashes with std::size under
// `using namespace std;` in C++17.
int root, last, sam_size;
struct Node {
    int len, link, nxt[26];
} st[MAXN];

void init() {
    root = last = sam_size = 0;
    st[root].link = -1;
    st[root].len = 0;
}

void Extend(int c) {
    int cur = ++sam_size, p = last;
    st[cur].len = st[p].len + 1;
    for (; ~p && st[p].nxt[c] == 0; p = st[p].link) st[p].nxt[c] = cur;
    if (p == -1) {
        st[cur].link = root;
    } else {
        int q = st[p].nxt[c];
        if (st[q].len == st[p].len + 1) {
            st[cur].link = q;
        } else {
            int clone = ++sam_size;
            st[clone] = st[q];
            st[clone].len = st[p].len + 1;
            for (; ~p && st[p].nxt[c] == q; p = st[p].link) st[p].nxt[c] = clone;
            st[q].link = st[cur].link = clone;
        }
    }
    last = cur;
}

int t[MAXN], w[MAXN];
ll s[MAXN], sz[MAXN];

int main() {
    freopen("find_2016.in", "r", stdin);
    freopen("find_2016.out", "w", stdout);
    scanf("%s%s", a + 1, b + 1);
    n1 = strlen(a + 1), n2 = strlen(b + 1);
    init();
    for (int i = 1; i <= n1; ++i) Extend(a[i] - 'a');

    // Mark the state of every prefix of a, then push counts up the suffix-link tree.
    int cur = root, step = 0;
    for (int i = 1; i <= n1; ++i) cur = st[cur].nxt[a[i] - 'a'], s[cur]++;
    for (int i = 1; i <= sam_size; ++i) w[st[i].len]++;
    for (int i = 1; i <= sam_size; ++i) w[i] += w[i - 1];
    for (int i = 1; i <= sam_size; ++i) t[w[st[i].len]--] = i;
    for (int i = sam_size; i >= 1; --i) s[st[t[i]].link] += s[t[i]];
    for (int i = 1; i <= sam_size; ++i) sz[i] = s[i];
    // Weight each state by the number of distinct lengths it holds, then prefix-sum
    // downward from the root (increasing len order).
    for (int i = 1; i <= sam_size; ++i) s[i] = s[i] * (st[i].len - st[st[i].link].len);
    s[cur = root] = 0;
    for (int i = 1; i <= sam_size; ++i) s[t[i]] += s[st[t[i]].link];

    long long ans = 0;
    for (int i = 1; i <= n2; ++i) {
        int c = b[i] - 'a';
        if (st[cur].nxt[c]) {
            cur = st[cur].nxt[c], step++;
        } else {
            while (~cur && !st[cur].nxt[c]) cur = st[cur].link;
            if (~cur) step = st[cur].len + 1, cur = st[cur].nxt[c];
            else cur = root, step = 0;
        }
        // Nothing matched: the root has no suffix link, and the old code read s[-1] and
        // st[-1].len here (out of bounds). The correct contribution is 0.
        if (cur == root) continue;
        int f = st[cur].link;
        ans += s[f] + sz[cur] * (step - st[f].len);
    }
    printf("%lld\n", ans);
    return 0;
}
[BZOJ 3277] 串
现在给定你n个字符串,询问每个字符串有多少子串(不包括空串)是所有n个字符串中至少k个字符串的子串(注意包括本身)。
Sol:
建立广义后缀自动机,然后对每个串的每个前缀沿 parent 链往上打标记,统计每个状态出现在多少个串中(数据不卡,我就直接暴力往上跳了)。
扫一遍每一个串,统计当前字符结尾的子串>=k的个数,记忆搜即可
// [BZOJ 3277] 串 -- for each string, count its substrings that occur in at least k of
// the n strings. Build a generalized SAM, count in how many strings each state occurs,
// then for each prefix sum the qualifying lengths along its suffix-link chain (memoized).
#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#define maxn 200010
using namespace std;
int n, k;
char s[maxn];           // all strings stored back to back
int in[maxn], out[maxn]; // string i occupies s[in[i] .. out[i]]
struct Node {
    int len, link, nxt[27];
} st[maxn << 1];
// sam_size was formerly `size`, which clashes with std::size under
// `using namespace std;` in C++17.
int root, sam_size, last;

void init() {
    last = root = sam_size = 0;
    st[root].link = -1;
    st[root].len = 0;
}

// Generalized-SAM extension: if the transition already exists (the string was seen via
// an earlier input), reuse that state or split it instead of creating a new one.
void Extend(int c) {
    int p = last, q = st[p].nxt[c];
    if (q) {
        if (st[q].len == st[p].len + 1) {
            last = q;
        } else {
            int clone = ++sam_size;
            st[clone] = st[q];
            st[clone].len = st[p].len + 1;
            for (; ~p && st[p].nxt[c] == q; p = st[p].link) st[p].nxt[c] = clone;
            st[q].link = clone;
            last = clone;
        }
        return;
    }
    int cur = ++sam_size;
    st[cur].len = st[p].len + 1;
    for (; ~p && st[p].nxt[c] == 0; p = st[p].link) st[p].nxt[c] = cur;
    if (p == -1) {
        st[cur].link = root;
    } else {
        int q = st[p].nxt[c];
        if (st[q].len == st[p].len + 1) {
            st[cur].link = q;
        } else {
            int clone = ++sam_size;
            st[clone] = st[q];
            st[clone].len = st[p].len + 1;
            for (; ~p && st[p].nxt[c] == q; p = st[p].link) st[p].nxt[c] = clone;
            st[cur].link = st[q].link = clone;
        }
    }
    last = cur;
}

int vis[maxn];       // phase 1: last string id that marked the state; phase 2: visited flag
int total_vis[maxn]; // number of distinct strings containing the state's substrings

// Walk the suffix-link chain from rt, crediting string y at most once per state.
void update(int rt, int y) {
    while (~rt) {
        if (vis[rt] == y) break;
        total_vis[rt]++;
        vis[rt] = y;
        rt = st[rt].link;
    }
}

long long val[maxn];

// Number of qualifying lengths over rt and all its suffix-link ancestors (memoized).
long long dfs(int rt) {
    if (vis[rt]) return val[rt];
    vis[rt] = true;
    long long ret = 0;
    if (~st[rt].link) ret = dfs(st[rt].link);
    if (total_vis[rt] >= k)
        ret += (long long)st[rt].len - (st[rt].link == -1 ? 0 : st[st[rt].link].len);
    return val[rt] = ret;
}

int main() {
    init();
    scanf("%d%d", &n, &k);
    for (int i = 1; i <= n; i++) {
        in[i] = out[i - 1] + 1;
        scanf("%s", s + in[i]);
        out[i] = strlen(s + in[i]) + in[i] - 1;
        last = root;
        for (int j = in[i]; j <= out[i]; j++) Extend(s[j] - 'a');
    }
    for (int i = 1; i <= n; i++) {
        int now = root;
        for (int j = in[i]; j <= out[i]; j++) now = st[now].nxt[s[j] - 'a'], update(now, i);
    }
    memset(vis, 0, sizeof vis);
    for (int i = 1; i <= n; i++) {
        long long ans = 0;
        int now = root;
        for (int j = in[i]; j <= out[i]; j++) {
            now = st[now].nxt[s[j] - 'a'];
            ans += dfs(now);
        }
        // %lld instead of the MSVC-only %I64d: glibc reads "%I64d" as a width-64 field.
        printf("%lld ", ans);
    }
    return 0;
}
[BZOJ 2806]Cheat
小强和阿米巴是好盆友~~
我们可以预处理出l[i]表示以i结尾的后缀在所有字符串中的最长匹配长度
如何得到呢?广义后缀自动机。
剩下用单调队列优化dp就可以了。
// [BZOJ 2806][CTSC 2012] Cheat
// Build a generalized SAM over the standard essays (alphabet {0,1}). For each query,
// l[i] = longest suffix of str[1..i] that occurs in the essays. Binary search the largest
// L such that a DP can cover >= 90% of the query with matched pieces of length >= L.
#include <bits/stdc++.h>
using namespace std;
#define maxn 1100010
int n, m;
char str[maxn];
// sam_size was formerly `size`, which clashes with std::size under
// `using namespace std;` in C++17.
int last, root, sam_size;
struct Node {
    int link, len, nxt[2];
} st[maxn << 1];

void init() {
    last = root = sam_size = 0;
    st[root].len = 0;
    st[root].link = -1;
}

// Generalized-SAM extension (reuses or splits an existing transition when present).
void Extend(int c) {
    int p = last, q;
    if ((q = st[p].nxt[c])) {
        if (st[q].len == st[p].len + 1) {
            last = q;
        } else {
            int clone = ++sam_size;
            st[clone] = st[q];
            st[clone].len = st[p].len + 1;
            for (; ~p && st[p].nxt[c] == q; p = st[p].link) st[p].nxt[c] = clone;
            st[q].link = clone;
            last = clone;
        }
        return;
    }
    int cur = ++sam_size;
    st[cur].len = st[p].len + 1;
    for (; ~p && st[p].nxt[c] == 0; p = st[p].link) st[p].nxt[c] = cur;
    if (p == -1) {
        st[cur].link = root;
    } else {
        q = st[p].nxt[c];
        if (st[q].len == st[p].len + 1) {
            st[cur].link = q;
        } else {
            int clone = ++sam_size;
            st[clone] = st[q];
            st[clone].len = st[p].len + 1;
            for (; ~p && st[p].nxt[c] == q; p = st[p].link) st[p].nxt[c] = clone;
            st[q].link = st[cur].link = clone;
        }
    }
    last = cur;
}

// dp[i] = max(dp[i-1], max over j in [i - l[i], i - L] of dp[j] + (i - j))
//       = max(dp[i-1], i + max(dp[j] - j)), with a monotonic queue over j.
int dp[maxn], que[maxn], val[maxn], head, tail;
int len, l[maxn];

// Fill l[1..len] for the current query by running it through the automaton.
void match() {
    len = strlen(str + 1);
    int nw = root, cur = 0;
    for (int i = 1; i <= len; i++) {
        int c = str[i] == '1';
        if (st[nw].nxt[c]) {
            cur++, nw = st[nw].nxt[c];
        } else {
            while (~nw && !st[nw].nxt[c]) nw = st[nw].link;
            if (nw == -1) nw = root, cur = 0;
            else cur = st[nw].len + 1, nw = st[nw].nxt[c];
        }
        l[i] = cur;
    }
}

// Can at least 90% of the query be covered by matched pieces of length >= L?
bool check(int L) {
    dp[0] = 0;
    head = tail = 0;
    for (int i = 1; i <= len; i++) {
        dp[i] = dp[i - 1];
        int p = i - L;
        if (p >= 0) {
            int v = dp[p] - p;
            while (head < tail && v > val[tail - 1]) tail--;
            que[tail] = p;
            val[tail] = v;
            tail++;
        }
        while (head < tail && que[head] + l[i] < i) head++;
        if (head < tail) dp[i] = max(dp[i], val[head] + i);
    }
    return 10 * dp[len] >= 9 * len;
}

// Binary search the largest feasible L (locals renamed so they no longer shadow l[]).
void Getans() {
    int lo = 0, hi = len;
    while (lo < hi) {
        int mid = lo + (hi - lo + 1) / 2;
        if (check(mid)) lo = mid;
        else hi = mid - 1;
    }
    printf("%d\n", lo);
}

int main() {
    init();
    scanf("%d%d", &n, &m);
    for (int i = 1; i <= m; i++) {
        scanf("%s", str + 1);
        len = strlen(str + 1), last = root;
        for (int j = 1; j <= len; j++) Extend(str[j] == '1');
    }
    for (int i = 1; i <= n; i++) {
        scanf("%s", str + 1);
        match();
        Getans();
    }
    return 0;
}
[BZOJ 1396] 识别子串
用后缀自动机搞出出现了一次的子串(其实就是求每个节点的Right)
记录每个节点的 r(即它在原串中出现的位置 pos;由于只关心 Right 集合大小为 1 的节点,所以 pos 唯一)
每一个节点的长度区间为[fa[len] + 1, len],即长度最小以及最大的[min, max]
发现有两种更新方式,在长度为[len - fa[len] + 1,len]这段区间要用len-fa[len]+1, len-fa[len]+2,.......,len来更新,在长度为[0, len - fa[len]]这一段区间要用fa[len] + 1来更新
线段树维护一下就可以了
// [BZOJ 1396] 识别子串
// For every position, output the length of the shortest substring that covers it and
// occurs exactly once. States with |Right| == 1 have a single end position r. For start
// positions in [r-len+1, r-linklen], the substring [start, r] is unique, so a position p
// in that range gets r-p+1 (an arithmetic "min" update stored in lazy[]). Positions in
// [r-linklen+1, r] get linklen+1 (a plain min update stored in t[]).
#include <bits/stdc++.h>
#define maxn 100010
using namespace std;
struct Node {
    int len, link, nxt[26], r;
} st[maxn << 1];
// sam_size was formerly `size`, which clashes with std::size under
// `using namespace std;` in C++17.
int root, last, sam_size;

void init() {
    root = last = sam_size = 0;
    st[root].link = -1;
    st[root].len = 0;
}

// Append c (0..25) at position pos; the new state records pos as its end position.
void Extend(int c, int pos) {
    int p = last, cur = ++sam_size;
    st[cur].len = st[p].len + 1;
    for (; ~p && st[p].nxt[c] == 0; p = st[p].link) st[p].nxt[c] = cur;
    if (p == -1) {
        st[cur].link = root;
    } else {
        int q = st[p].nxt[c];
        if (st[q].len == st[p].len + 1) {
            st[cur].link = q;
        } else {
            int clone = ++sam_size;
            st[clone] = st[q];
            st[clone].len = st[p].len + 1;
            for (; ~p && st[p].nxt[c] == q; p = st[p].link) st[p].nxt[c] = clone;
            st[q].link = st[cur].link = clone;
        }
    }
    st[cur].r = pos;
    last = cur;
}

typedef long long ll;
ll dp[maxn << 1]; // occurrence count (|Right|) per state
int d[maxn << 1], w[maxn << 1];
char str[maxn];
int n;
// lazy[id]: value at the leftmost position of the node, decreasing by 1 per position to
// the right (n << 1 on internal nodes means "no tag"); t[id]: plain min tag.
int t[maxn << 2], lazy[maxn << 2];
#define lc (id << 1)
#define rc (id << 1 | 1)

void pushdown(int id, int l, int r) {
    if (lazy[id] <= n) {
        int mid = l + r >> 1;
        lazy[lc] = min(lazy[lc], lazy[id]);
        lazy[rc] = min(lazy[rc], lazy[id] - (mid - l + 1));
        lazy[id] = n << 1;
    }
}

void build(int id, int l, int r) {
    t[id] = n;
    if (l == r) {
        lazy[id] = n;
        return;
    }
    lazy[id] = n << 1;
    int mid = l + r >> 1;
    build(lc, l, mid);
    build(rc, mid + 1, r);
}

// Position p in [x, y] gets min(current, val - (p - x)).
void update(int id, int l, int r, int x, int y, int val) {
    if (l == x && r == y) {
        lazy[id] = min(lazy[id], val);
        return;
    }
    pushdown(id, l, r);
    int mid = l + r >> 1;
    if (y <= mid) update(lc, l, mid, x, y, val);
    else if (x > mid) update(rc, mid + 1, r, x, y, val);
    else update(lc, l, mid, x, mid, val), update(rc, mid + 1, r, mid + 1, y, val - (mid - x + 1));
}

// Position p in [x, y] gets min(current, val).
void update2(int id, int l, int r, int x, int y, int val) {
    if (l == x && r == y) {
        t[id] = min(t[id], val);
        return;
    }
    int mid = l + r >> 1;
    if (y <= mid) update2(lc, l, mid, x, y, val);
    else if (x > mid) update2(rc, mid + 1, r, x, y, val);
    else update2(lc, l, mid, x, mid, val), update2(rc, mid + 1, r, mid + 1, y, val);
}

// Push every tag down to the leaves and print each position's answer in order.
void ask(int id, int l, int r) {
    if (l == r) {
        printf("%d\n", min(lazy[id], t[id]));
        return;
    }
    pushdown(id, l, r);
    t[lc] = min(t[lc], t[id]);
    t[rc] = min(t[rc], t[id]);
    int mid = l + r >> 1;
    ask(lc, l, mid);
    ask(rc, mid + 1, r);
}

int main() {
    init();
    scanf("%s", str + 1);
    n = strlen(str + 1);
    for (int i = 1; i <= n; i++) Extend(str[i] - 'a', i);
    int cur = root;
    for (int i = 1; i <= n; i++) {
        cur = st[cur].nxt[str[i] - 'a'];
        dp[cur] = 1;
    }
    for (int i = 1; i <= sam_size; i++) w[st[i].len]++;
    for (int i = 1; i <= sam_size; i++) w[i] += w[i - 1];
    for (int i = sam_size; i >= 1; i--) d[w[st[i].len]--] = i;
    for (int i = sam_size; i >= 1; i--) dp[st[d[i]].link] += dp[d[i]];
    build(1, 1, n);
    for (int i = 1; i <= sam_size; i++) {
        if (dp[i] == 1) {
            int l = st[i].r - st[i].len + 1, r = st[i].r - st[st[i].link].len;
            update(1, 1, n, l, r, st[i].len);
            if (r + 1 <= st[i].r) update2(1, 1, n, r + 1, st[i].r, st[st[i].link].len + 1);
        }
    }
    ask(1, 1, n);
    return 0;
}
最后对于多串,我们还有广义后缀自动机~
具体可以见这篇博文的E题
题目还有[BZOJ 3926][ZJOI 2015]诸神眷顾的幻想乡
[BZOJ 2780][SPOJ 8093] Sevenk Love Oimaster
abcabcabc --------字符串集合
aaa
aafe
abc --------询问字符串
a
ca
1 3 1
// [BZOJ 2780][SPOJ 8093] Sevenk Love Oimaster
// For each query, count how many of the n strings contain it. Build a generalized SAM and
// attach string ids to the states of their prefixes (graph A). Run a DFS over the
// suffix-link tree (graph B): a query maps to a state whose subtree is the dfn range
// [In, Out]. Then count distinct ids in that range offline, using a BIT and
// next-occurrence lists.
#include <bits/stdc++.h>
#define maxn 500000
using namespace std;
int n, m;
char s[maxn];
struct Edge_ {
    int to, next;
};
int In[maxn], Out[maxn], dfs_clock, dfn[maxn], ans[maxn];
vector<int> nxt[maxn]; // nxt[i]: next dfn positions of the ids found at dfn position i
int vis[maxn];

namespace BIT {
int t[maxn];
// Fully parenthesized: the old `i&(~i+1)` only worked because of how it was used.
#define lowbit(i) ((i) & (~(i) + 1))
void update(int pos, int val) {
    if (!pos) return;
    for (int i = pos; i <= dfs_clock; i += lowbit(i)) t[i] += val;
}
int ask(int pos) {
    if (!pos) return 0;
    int ret = 0;
    for (int i = pos; i; i -= lowbit(i)) ret += t[i];
    return ret;
}
} // namespace BIT

// Adjacency list; A maps state -> string ids, B is the suffix-link tree.
struct Edge {
    Edge_ edge[maxn];
    int h[maxn], cnt;
    void add(int u, int v) {
        cnt++;
        edge[cnt].to = v;
        edge[cnt].next = h[u];
        h[u] = cnt;
    }
    void dfs(int u) {
        In[u] = ++dfs_clock;
        dfn[dfs_clock] = u;
        for (int i = h[u]; i; i = edge[i].next) dfs(edge[i].to);
        Out[u] = dfs_clock;
    }
    // Scan positions right to left to link each id occurrence to the next one, then
    // seed the BIT with each id's first occurrence.
    void solve() {
        for (int i = dfs_clock; i; i--) {
            int now = dfn[i];
            for (int j = h[now]; j; j = edge[j].next) {
                int v = edge[j].to;
                if (vis[v]) nxt[i].push_back(vis[v]);
                vis[v] = i;
            }
        }
        for (int i = 1; i <= n; i++) BIT::update(vis[i], 1);
    }
} A, B;

struct Node {
    int len, link;
    map<int, int> nxt;
} st[maxn];
// sam_size was formerly `size`, which clashes with std::size under
// `using namespace std;` in C++17.
int root, sam_size, last;

void init() {
    root = sam_size = last = 0;
    st[root].len = 0;
    st[root].link = -1;
}

// Generalized-SAM extension; the state reached is tagged with string Id.
void Extend(char ch, int Id) {
    int c = ch - 'a', p = last, q = st[p].nxt[c];
    if (q) {
        if (st[q].len == st[p].len + 1) {
            last = q;
        } else {
            int clone = ++sam_size;
            st[clone] = st[q];
            st[clone].len = st[p].len + 1;
            for (; ~p && st[p].nxt[c] == q; p = st[p].link) st[p].nxt[c] = clone;
            st[q].link = clone;
            last = clone;
        }
    } else {
        int cur = ++sam_size;
        st[cur].len = st[p].len + 1;
        for (; ~p && !st[p].nxt[c]; p = st[p].link) st[p].nxt[c] = cur;
        if (p == -1) {
            st[cur].link = root;
        } else {
            q = st[p].nxt[c];
            if (st[q].len == st[p].len + 1) {
                st[cur].link = q;
            } else {
                int clone = ++sam_size;
                st[clone] = st[q];
                st[clone].len = st[p].len + 1;
                for (; ~p && st[p].nxt[c] == q; p = st[p].link) st[p].nxt[c] = clone;
                st[q].link = st[cur].link = clone;
            }
        }
        last = cur;
    }
    A.add(last, Id);
}

// Offline query over the dfn range [l, r], sorted by left end.
struct opt {
    int l, r, id;
    bool operator<(const opt& k) const {
        if (l != k.l) return l < k.l;
        return r < k.r;
    }
} q[maxn];

int main() {
    init();
    scanf("%d%d", &n, &m);
    for (int i = 1; i <= n; i++) {
        scanf("%s", s + 1);
        int N = strlen(s + 1);
        last = root;
        for (int j = 1; j <= N; j++) Extend(s[j], i);
    }
    for (int i = 1; i <= sam_size; i++) B.add(st[i].link, i);
    B.dfs(root);
    A.solve();
    int tot = 0;
    for (int i = 1; i <= m; i++) {
        scanf("%s", s + 1);
        int N = strlen(s + 1), now = root;
        bool flag = true;
        for (int j = 1; j <= N; j++) {
            int p = s[j] - 'a';
            // find() instead of operator[]: a missing transition must not insert an entry.
            auto it = st[now].nxt.find(p);
            if (it == st[now].nxt.end() || it->second == 0) {
                flag = false;
                break;
            }
            now = it->second;
        }
        if (flag) {
            ++tot;
            q[tot].l = In[now];
            q[tot].r = Out[now];
            q[tot].id = i;
        }
    }
    sort(q + 1, q + 1 + tot);
    int l = 1;
    for (int i = 1; i <= tot; i++) {
        // Moving the left end past position l: each id there now counts at its next
        // occurrence. The stale +1 at l is left in place, because it cancels out of
        // ask(r) - ask(ql - 1) for every later query with ql > l.
        while (l < q[i].l && l < dfs_clock) {
            for (int pos : nxt[l]) BIT::update(pos, 1);
            l++;
        }
        ans[q[i].id] = BIT::ask(q[i].r) - BIT::ask(q[i].l - 1);
    }
    for (int i = 1; i <= m; i++) printf("%d\n", ans[i]);
    return 0;
}
Sam终极boss:[BZOJ 3145][Feyat cup 1.5]Str
具体解题报告在这个博客中有