题解 回忆树

传送门

传说中的对联算法?
另一篇题解

  • 当出现给定一些模式串和一些文本串,询问某几个模式串被匹配了多少次时:
    建出AC自动机,建出fail树
    在自动机上跑文本串,每到一个节点就在对应的dfs序上+1
    处理询问时直接查子树区间和即可
    貌似也可以在广义SAM上跑,线段树合并维护right集合,然后查在给定区间内有多少个maxpos
  • 当出现给定一些模式串和一棵trie树,询问某几个模式串被匹配了多少次时:
    建出AC自动机,建出fail树
    在trie树上dfs,同步跑自动机,进入树上的一个节点时就在对应的dfs序上+1,回溯时-1
    处理询问时直接查子树区间和即可
  • 当出现给定一棵trie树和一些文本串,询问某条链上某个文本串被匹配了多少次时:
    先将所有匹配分为三种:在左链匹配,在右链匹配,跨lca匹配
    跨lca的匹配直接暴力是 \(O(|\text{文本串}|)\)
    剩下的情况差分成两个根到点上文本串出现次数的查询
    对文本串建出AC自动机,建出fail树,注意因为左右链上顺序相反文本串要正着插入一次倒着插入一次
    在trie树上dfs,同步跑自动机,进入树上的一个节点时就在对应的dfs序上+1,回溯时-1
    处理询问时直接查子树区间和即可

于是本题就是最后一种情况
复杂度 \(O(nlogn)\)

点击查看代码
#include <bits/stdc++.h>
using namespace std;
#define INF 0x3f3f3f3f
#define N 300010
#define ll long long
#define fir first
#define sec second
#define pb push_back
//#define int long long

int n, m, k;
char c[N], val[N], sta[N], sta2[N];
queue<int> q;
int head[N], nxt[N], size, top, top2, sumlen;
vector<pair<int, int>> ise[N], del[N];
int dep[N], lg[N], fa[25][N], tr[N][26], fail[N], ans[N], tot;
struct edge{int to, next; char val;}e[N<<1];
inline void add(int s, int t, char c) {e[++size].to=t; e[size].next=head[s]; e[size].val=c; head[s]=size;}
struct bit{
	int bit[N];
	inline void upd(int i, int dat) {for (; i<=sumlen; i+=i&-i) bit[i]+=dat;}
	inline int query(int i) {int ans=0; for (; i; i-=i&-i) ans+=bit[i]; return ans;}
}bit;
struct trie{
	int head[N], in[N], siz[N], size, tot;
	struct edge{int to, next;}e[N<<1];
	trie(){memset(head, -1, sizeof(head));}
	inline void add(int s, int t) {e[++size]={t, head[s]}; head[s]=size;}
	void dfs(int u) {
		siz[u]=1;
		in[u]=++tot;
		for (int i=head[u],v; ~i; i=e[i].next) {
			v = e[i].to;
			dfs(v);
			siz[u]+=siz[v];
		}
	}
	void upd(int p, int dat) {bit.upd(in[p], dat);}
	int qsum(int p) {return bit.query(in[p]+siz[p]-1)-bit.query(in[p]-1);}
}trie;
void dfs1(int u, int pa) {
	for (int i=1; i<25; ++i)
		if (dep[u]>=1<<i) fa[i][u]=fa[i-1][fa[i-1][u]];
		else break;
	for (int i=head[u],v; ~i; i=e[i].next) {
		v = e[i].to;
		if (v==pa) continue;
		val[v]=e[i].val;
		dep[v]=dep[u]+1;
		fa[0][v]=u;
		dfs1(v, u);
	}
}
int lca(int a, int b) {
	if (dep[a]<dep[b]) swap(a, b);
	while (dep[a]>dep[b]) a=fa[lg[dep[a]-dep[b]]-1][a];
	if (a==b) return a;
	for (int i=lg[dep[a]]-1; ~i; --i)
		if (fa[i][a]!=fa[i][b])
			a=fa[i][a], b=fa[i][b];
	return fa[0][a];
}
int anc(int a, int t) {
	for (int i=24; ~i; --i) if (t>=1<<i)
		a=fa[i][a], t-=1<<i;
	return a;
}
int kmp() {
	int k=strlen(c+1), ans=0;
	nxt[1]=0;
	for (int i=2,j=0; i<=k; ++i) {
		while (j && c[i]!=c[j+1]) j=nxt[j];
		if (c[i]==c[j+1]) ++j;
		nxt[i]=j;
	}
	for (int i=1,j=0; i<=top; ++i) {
		while (j && (j==k||sta[i]!=c[j+1])) j=nxt[j];
		if (sta[i]==c[j+1]) ++j;
		if (j==k) ++ans;
	}
	return ans;
}
int insert(char* c) {
	// cout<<"ins: "<<endl;
	int u=0, *t;
	for (; *c; ++c,u=*t) {
		t=&tr[u][*c-'a'];
		if (!*t) *t=++tot;
	}
	return u;
}
void build() {
	int u=0;
	for (int i=0; i<26; ++i)
		if (tr[u][i]) q.push(tr[u][i]), fail[tr[u][i]]=u, trie.add(u, tr[u][i]);
		else tr[u][i]=u;
	while (q.size()) {
		u=q.front(); q.pop();
		for (int i=0; i<26; ++i)
			if (tr[u][i]) q.push(tr[u][i]), fail[tr[u][i]]=tr[fail[u]][i], trie.add(tr[fail[u]][i], tr[u][i]);
			else tr[u][i]=tr[fail[u]][i];
	}
}
void dfs2(int u, int p, int fa) {
	// cout<<"u: "<<u<<' '<<p<<endl;
	trie.upd(p, 1);
	for (auto it:ise[u]) ans[it.sec]+=trie.qsum(it.fir); //, cout<<"add: "<<it.fir<<' '<<it.sec<<' '<<trie.qsum(it.fir)<<endl;
	for (auto it:del[u]) ans[it.sec]-=trie.qsum(it.fir); //, cout<<"del: "<<it.fir<<' '<<it.sec<<' '<<trie.qsum(it.fir)<<endl;
	for (int i=head[u],v; ~i; i=e[i].next) {
		v = e[i].to;
		if (v==fa) continue;
		dfs2(v, tr[p][e[i].val-'a'], u);
	}
	trie.upd(p, -1);
}

signed main()
{
	scanf("%d%d", &n, &m);
	memset(head, -1, sizeof(head));
	for (int i=1,u,v; i<n; ++i) {
		scanf("%d%d%s", &u, &v, c);
		add(u, v, *c); add(v, u, *c);
	}
	for (int i=1; i<=n; ++i) lg[i]=lg[i-1]+(1<<lg[i-1]==i);
	dep[1]=1; dfs1(1, 0);
	for (int i=1,u,v,k,t1,t2,s,t,g; i<=m; ++i) {
		scanf("%d%d%s", &u, &v, c+1);
		// cout<<"uv: "<<u<<' '<<v<<endl;
		k=strlen(c+1);
		t1=insert(c+1); reverse(c+1, c+k+1);
		t2=insert(c+1); reverse(c+1, c+k+1);
		g=lca(u, v);
		s=anc(u, max(dep[u]-dep[g]-(k-1), 0));
		t=anc(v, max(dep[v]-dep[g]-(k-1), 0));
		// cout<<"st: "<<s<<' '<<t<<endl;
		if (u!=g && v!=g) {
			top=top2=0;
			for (int now=s; now!=g; now=fa[0][now]) sta[++top]=val[now];
			for (int now=t; now!=g; now=fa[0][now]) sta2[++top2]=val[now];
			while (top2) sta[++top]=sta2[top2--];
			// cout<<"sta: "; for (int i=1; i<=top; ++i) cout<<sta[i]<<' '; cout<<endl;
			ans[i]+=kmp();
			// cout<<"add: "<<kmp()<<endl;
		}
		ise[u].pb({t2, i}); del[s].pb({t2, i});
		ise[v].pb({t1, i}); del[t].pb({t1, i});
	}
	build(); trie.dfs(0); sumlen=trie.tot; dfs2(1, 0, 0);
	for (int i=1; i<=m; ++i) printf("%d\n", ans[i]);

	return 0;
}
posted @ 2021-12-12 19:47  Administrator-09  阅读(3)  评论(0编辑  收藏  举报