题解 回忆树
传说中的对联算法?
另一篇题解
- 当出现给定一些模式串和一些文本串,询问某几个模式串被匹配了多少次时:
建出AC自动机,建出fail树
在自动机上跑文本串,每到一个节点就在对应的dfs序上+1
处理询问时直接查子树区间和即可
貌似也可以在广义SAM上跑,线段树合并维护right集合,然后查在给定区间内有多少个maxpos - 当出现给定一些模式串和一棵trie树,询问某几个模式串被匹配了多少次时:
建出AC自动机,建出fail树
在trie树上dfs,同步跑自动机,进入树上的一个节点时就在对应的dfs序上+1,回溯时-1
处理询问时直接查子树区间和即可 - 当出现给定一棵trie树和一些文本串,询问某条链上某个文本串被匹配了多少次时:
先将所有匹配分为三种:在左链匹配,在右链匹配,跨lca匹配
跨lca的匹配直接暴力是 \(O(|\text{文本串}|)\) 的
剩下的情况差分成两个根到点上文本串出现次数的查询
对文本串建出AC自动机,建出fail树,注意因为左右链上顺序相反文本串要正着插入一次倒着插入一次
在trie树上dfs,同步跑自动机,进入树上的一个节点时就在对应的dfs序上+1,回溯时-1
处理询问时直接查子树区间和即可
于是本题就是最后一种情况
复杂度 \(O(nlogn)\)
点击查看代码
#include <bits/stdc++.h>
using namespace std;
#define INF 0x3f3f3f3f
#define N 300010
#define ll long long
#define fir first
#define sec second
#define pb push_back
//#define int long long
int n, m, k;
char c[N], val[N], sta[N], sta2[N];
queue<int> q;
int head[N], nxt[N], size, top, top2, sumlen;
vector<pair<int, int>> ise[N], del[N];
int dep[N], lg[N], fa[25][N], tr[N][26], fail[N], ans[N], tot;
struct edge{int to, next; char val;}e[N<<1];
inline void add(int s, int t, char c) {e[++size].to=t; e[size].next=head[s]; e[size].val=c; head[s]=size;}
struct bit{
int bit[N];
inline void upd(int i, int dat) {for (; i<=sumlen; i+=i&-i) bit[i]+=dat;}
inline int query(int i) {int ans=0; for (; i; i-=i&-i) ans+=bit[i]; return ans;}
}bit;
struct trie{
int head[N], in[N], siz[N], size, tot;
struct edge{int to, next;}e[N<<1];
trie(){memset(head, -1, sizeof(head));}
inline void add(int s, int t) {e[++size]={t, head[s]}; head[s]=size;}
void dfs(int u) {
siz[u]=1;
in[u]=++tot;
for (int i=head[u],v; ~i; i=e[i].next) {
v = e[i].to;
dfs(v);
siz[u]+=siz[v];
}
}
void upd(int p, int dat) {bit.upd(in[p], dat);}
int qsum(int p) {return bit.query(in[p]+siz[p]-1)-bit.query(in[p]-1);}
}trie;
void dfs1(int u, int pa) {
for (int i=1; i<25; ++i)
if (dep[u]>=1<<i) fa[i][u]=fa[i-1][fa[i-1][u]];
else break;
for (int i=head[u],v; ~i; i=e[i].next) {
v = e[i].to;
if (v==pa) continue;
val[v]=e[i].val;
dep[v]=dep[u]+1;
fa[0][v]=u;
dfs1(v, u);
}
}
int lca(int a, int b) {
if (dep[a]<dep[b]) swap(a, b);
while (dep[a]>dep[b]) a=fa[lg[dep[a]-dep[b]]-1][a];
if (a==b) return a;
for (int i=lg[dep[a]]-1; ~i; --i)
if (fa[i][a]!=fa[i][b])
a=fa[i][a], b=fa[i][b];
return fa[0][a];
}
int anc(int a, int t) {
for (int i=24; ~i; --i) if (t>=1<<i)
a=fa[i][a], t-=1<<i;
return a;
}
int kmp() {
int k=strlen(c+1), ans=0;
nxt[1]=0;
for (int i=2,j=0; i<=k; ++i) {
while (j && c[i]!=c[j+1]) j=nxt[j];
if (c[i]==c[j+1]) ++j;
nxt[i]=j;
}
for (int i=1,j=0; i<=top; ++i) {
while (j && (j==k||sta[i]!=c[j+1])) j=nxt[j];
if (sta[i]==c[j+1]) ++j;
if (j==k) ++ans;
}
return ans;
}
int insert(char* c) {
// cout<<"ins: "<<endl;
int u=0, *t;
for (; *c; ++c,u=*t) {
t=&tr[u][*c-'a'];
if (!*t) *t=++tot;
}
return u;
}
void build() {
int u=0;
for (int i=0; i<26; ++i)
if (tr[u][i]) q.push(tr[u][i]), fail[tr[u][i]]=u, trie.add(u, tr[u][i]);
else tr[u][i]=u;
while (q.size()) {
u=q.front(); q.pop();
for (int i=0; i<26; ++i)
if (tr[u][i]) q.push(tr[u][i]), fail[tr[u][i]]=tr[fail[u]][i], trie.add(tr[fail[u]][i], tr[u][i]);
else tr[u][i]=tr[fail[u]][i];
}
}
void dfs2(int u, int p, int fa) {
// cout<<"u: "<<u<<' '<<p<<endl;
trie.upd(p, 1);
for (auto it:ise[u]) ans[it.sec]+=trie.qsum(it.fir); //, cout<<"add: "<<it.fir<<' '<<it.sec<<' '<<trie.qsum(it.fir)<<endl;
for (auto it:del[u]) ans[it.sec]-=trie.qsum(it.fir); //, cout<<"del: "<<it.fir<<' '<<it.sec<<' '<<trie.qsum(it.fir)<<endl;
for (int i=head[u],v; ~i; i=e[i].next) {
v = e[i].to;
if (v==fa) continue;
dfs2(v, tr[p][e[i].val-'a'], u);
}
trie.upd(p, -1);
}
signed main()
{
scanf("%d%d", &n, &m);
memset(head, -1, sizeof(head));
for (int i=1,u,v; i<n; ++i) {
scanf("%d%d%s", &u, &v, c);
add(u, v, *c); add(v, u, *c);
}
for (int i=1; i<=n; ++i) lg[i]=lg[i-1]+(1<<lg[i-1]==i);
dep[1]=1; dfs1(1, 0);
for (int i=1,u,v,k,t1,t2,s,t,g; i<=m; ++i) {
scanf("%d%d%s", &u, &v, c+1);
// cout<<"uv: "<<u<<' '<<v<<endl;
k=strlen(c+1);
t1=insert(c+1); reverse(c+1, c+k+1);
t2=insert(c+1); reverse(c+1, c+k+1);
g=lca(u, v);
s=anc(u, max(dep[u]-dep[g]-(k-1), 0));
t=anc(v, max(dep[v]-dep[g]-(k-1), 0));
// cout<<"st: "<<s<<' '<<t<<endl;
if (u!=g && v!=g) {
top=top2=0;
for (int now=s; now!=g; now=fa[0][now]) sta[++top]=val[now];
for (int now=t; now!=g; now=fa[0][now]) sta2[++top2]=val[now];
while (top2) sta[++top]=sta2[top2--];
// cout<<"sta: "; for (int i=1; i<=top; ++i) cout<<sta[i]<<' '; cout<<endl;
ans[i]+=kmp();
// cout<<"add: "<<kmp()<<endl;
}
ise[u].pb({t2, i}); del[s].pb({t2, i});
ise[v].pb({t1, i}); del[t].pb({t1, i});
}
build(); trie.dfs(0); sumlen=trie.tot; dfs2(1, 0, 0);
for (int i=1; i<=m; ++i) printf("%d\n", ans[i]);
return 0;
}