BZOJ 1921: [Ctsc2010]珠宝商 点分治套SAM

题目链接


首先可以发现两种算法

一. 暴力处理

对"特征串"建SAM

枚举路径的一个端点,\(dfs\)另一个端点,同时维护在SAM上的位置.

每到一个位置会有SAM上对应节点的right集合大小的贡献

复杂度\(O(size^2)\)

二. 处理经过一个点的所有路径

设这个点是\(u\),字符为\(a[u]\)

需要建出正反特征串的后缀树

考虑从点\(u\)出发的所有路径(\(a[u]\)为字符串的开头),统计出以特征串的每一位为起始的这些串的数量

同理将这些串\(reverse\)(\(a[u]\)为串的最后一位),统计出在特征串每一位结束的串的数量

对应位上两组串数量的乘积的和即为贡献,因为某正串在一位起始,某反串在这位结束,即可拼出一个路径

由于\(dfs\)时正串是每次在末尾加字符维护在特征串中起始位置,反串是每次开头加字符维护结束位置,将特征串和路径串翻转后即为同一个问题,我们只考虑\(push\ front\)维护结束位置

后缀自动机的转移只能支持末尾插入,于是需要利用后缀树

后缀树上一条边会对应原串中的一段区间

转移时需要注意从上往下走了不满一条边的情况,此时大概需要走到儿子处,注意判无转移时无解

每次在节点上打标记,最后全部下放到叶子处 统计每个位置的出现次数

复杂度\(O(size+m)\)


然后使用点分治,若分治的大小\(size>\sqrt{m}\)使用方法2, 否则暴力做方法1

这里需要注意同一子树的去重 在去重的时候应使用对应的方法保证复杂度

易知复杂度为\(O((n+m)\sqrt{m})\)

代码如下

#include<cstdio>
#include<algorithm>
#include<ctype.h>
#include<string.h>
#include<math.h>

using namespace std;
#define ll long long
#define rep(i,x,y) for(int i=(x);i<=(y);++i)
#define travel(i,x) for(int i=h[x];i;i=pre[i])

inline char read() {
	static const int IN_LEN = 1000000;
	static char buf[IN_LEN], *s, *t;
	return (s == t ? t = (s = buf) + fread(buf, 1, IN_LEN, stdin), (s == t ? -1 : *s++) : *s++);
}
template<class T>
inline void read(T &x) {
	static bool iosig;
	static char c;
	for (iosig = false, c = read(); !isdigit(c); c = read()) {
		if (c == '-') iosig = true;
		if (c == -1) return;
	}
	for (x = 0; isdigit(c); c = read()) x = ((x + (x << 2)) << 1) + (c ^ '0');
	if (iosig) x = -x;
}
const int OUT_LEN = 10000000;
char obuf[OUT_LEN], *ooh = obuf;
inline void print(char c) {
	if (ooh == obuf + OUT_LEN) fwrite(obuf, 1, OUT_LEN, stdout), ooh = obuf;
	*ooh++ = c;
}
template<class T>
inline void print(T x) {
	static int buf[30], cnt;
	if (x == 0) print('0');
	else {
		if (x < 0) print('-'), x = -x;
		for (cnt = 0; x; x /= 10) buf[++cnt] = x % 10 + 48;
		while (cnt) print((char)buf[cnt--]);
	}
}
inline void flush() { fwrite(obuf, 1, ooh - obuf, stdout); }
const int N = 50005;
int n, m, num, tot1[N], tot2[N], h[N], pre[N<<1], e[N<<1];
bool vis[N];
char a[N], s[N];
inline void add(int x, int y){ e[++num]=y, pre[num]=h[x], h[x]=num;}
struct sam{
	int last, cnt, b[N], str[N], t[N<<1], lazy[N<<1], q[N<<1], siz[N<<1], len[N<<1], fa[N<<1], ch[N<<1][26], son[N<<1][26];
	bool isl[N<<1];
	inline sam(){ last=cnt=1;}
	inline void ins(int c){
		int p=last, np=++cnt;
		last=np, str[len[np]=len[p]+1]=c, t[np]=len[np];
		while(p && !ch[p][c]) ch[p][c]=np, p=fa[p];
		if(!p) fa[np]=1;
		else{
			int q=ch[p][c];
			if(len[q]==len[p]+1) fa[np]=q;
			else{
				int nq=++cnt;
				len[nq]=len[p]+1, memcpy(ch[nq], ch[q], sizeof ch[0]);
				t[nq]=t[q], fa[nq]=fa[q], fa[q]=fa[np]=nq;
				while(ch[p][c]==q) ch[p][c]=nq, p=fa[p];
			}
		}
		siz[np]=isl[np]=1;
	}
	inline void init(){
		rep(i, 1, cnt) ++b[len[i]];
		rep(i, 1, m) b[i]+=b[i-1];
		rep(i, 1, cnt) q[b[len[i]]--]=i;
		for(int i=cnt; i>1; --i) son[fa[q[i]]][str[t[q[i]]-len[fa[q[i]]]]]=q[i], siz[fa[q[i]]]+=siz[q[i]];
	}
	inline void trans(int &p, int c){ p=ch[p][c];}
	void dfs5(int u, int fa, int p, int l){
		if(!p) return;
		if(l==len[p]) p=son[p][a[u]];
		else if(str[t[p]-l]!=a[u]) p=0;
		if(!p) return;
		++lazy[p];
		travel(i, u) if(e[i]!=fa && !vis[e[i]]) dfs5(e[i], u, p, l+1);
	}
	inline void work(int *tot){
		rep(i, 2, cnt) lazy[q[i]]+=lazy[fa[q[i]]];
		rep(i, 1, cnt) if(isl[i]) tot[len[i]]=lazy[i];
		memset(lazy, 0, sizeof lazy);
	}
}sam1, sam2;
ll ans;
int root, ctr, Siz, mn, top, lim, siz[N], stk[N];
void dfs1(int u, int fa=0){
	siz[u]=1;
	int mx=0;
	travel(i, u) if(!vis[e[i]] && e[i]!=fa) dfs1(e[i], u), siz[u]+=siz[e[i]], mx=max(mx, siz[e[i]]);
	mx=max(mx, Siz-siz[u]);
	if(mx<mn) mn=mx, ctr=u;
}
inline int getctr(int u, int size){ return Siz=mn=size, dfs1(u), ctr;}
void dfs3(int u, int p, int W, int fa=0){
	sam1.trans(p, a[u]);
	if(p){
		ans+=W*sam1.siz[p];
		travel(i, u) if(!vis[e[i]] && e[i]!=fa) dfs3(e[i], p, W, u);
	}
}
void dfs2(int u, int fa=0){
	stk[++top]=a[u];

	int p=1;
	for(int i=top; i; --i) sam1.trans(p, stk[i]);
	dfs3(root, p, -1);
	travel(i, u) if(e[i]!=fa && !vis[e[i]]) dfs2(e[i], u);
	--top;
}
void dfs4(int u, int fa=0){
	dfs3(u, 1, 1);
	travel(i, u) if(!vis[e[i]] && e[i]!=fa) dfs4(e[i], u);
}
void solve(int u, int fa=0, int v=0){
	int size=Siz;
	if(size<=lim){
		if(fa){
			stk[top=1]=a[fa];
			root=v, dfs2(v);
		}
		dfs4(u);
	}
	else{
		if(fa){
			sam1.dfs5(v, fa, sam1.son[1][a[fa]], 1), sam2.dfs5(v, fa, sam2.son[1][a[fa]], 1);
			sam1.work(tot1), sam2.work(tot2);
			rep(i, 1, m) ans-=tot1[i]*tot2[m-i+1];
		}
		sam1.dfs5(u, 0, 1, 0), sam2.dfs5(u, 0, 1, 0);
		sam1.work(tot1), sam2.work(tot2);
		rep(i, 1, m) ans+=tot1[i]*tot2[m-i+1];
		vis[u]=1;
		travel(i, u) if(!vis[e[i]]) solve(getctr(e[i], siz[e[i]]<siz[u]?siz[e[i]]:size-siz[u]), u, e[i]);
	}
}
int main() {
	read(n), read(m);
	lim=sqrt(m);
	rep(i, 2, n){
		static int x, y;
		read(x), read(y);
		add(x, y), add(y, x);
	}
	while(isspace(a[1]=read()));
	rep(i, 2, n) a[i]=read();
	rep(i, 1, n) a[i]-='a';
	while(isspace(s[1]=read()));
	rep(i, 2, m) s[i]=read();
	rep(i, 1, m) sam1.ins(s[i]-='a');
	for(int i=m; i; --i) sam2.ins(s[i]);

	sam1.init(), sam2.init();
	solve(getctr(1, n));
	return printf("%lld", ans), 0;
}
posted @ 2018-07-30 20:59  CMXRYNP  阅读(467)  评论(3编辑  收藏  举报