后缀树节点数

后缀树节点数

给定字符串 \(s\),每次查询区间 \([l,r]\),询问 \(l,r\) 内的字符串 \(s\) 构成的后缀树的节点数量。

\(n\) 为串长,则 \(n\le 10^5,m\le 3\cdot 10^5\)

Solution

我不会 \(\mathcal O(n\log^2 n +m\log n)\) 的算法,所以下面简述 \(\mathcal O(n\log^2 n + m\log^2 n)\) 的做法:

后缀树等价于后缀 Trie 树的虚树,考虑虚树上仅有两类基础节点:

  1. 所有后缀节点构成的节点,此类数量为 \(r-l+1\)
  2. 所有虚树中分裂产生的节点。

可能算重,需要减去同时满足两类条件的节点。

对于第二类,使用 LCT + 染色处理,分类节数等价于 \(s+'a',s+'b'\) 存在数,维护最小值和次小值即可,复杂度 \(\mathcal O(n\log^2 n)\)(可以注意到只有第一个被访问的儿子需要更新,这是因为其 lst 会改变)

对于去重的部分,可以证明,若 \(s\) 产生了分裂,则 \(s\) 的后缀产生了分裂,于是可以二分一个后缀并判断其是否分裂即可,这个部分我们仍然使用 LCT 预处理的信息,显然是可以判断的。

LCT 的部分存在较多细节,类似于树点涂色,维护的 lst 需要是上一棵子树中深度最小的节点,调试得让人比较崩。

  • 需要维护深度最小的节点。
  • 维护子树内被访问过的最小值。
  • 一开始没想清楚写自闭了。

复杂度 \(\mathcal O(n\log^2 n+m\log^2 n)\)

我佛了,写了 1h,调了 1.5h

\(Code:\)

#include<bits/stdc++.h>
using namespace std ;
#define Next( i, x ) for( register int i = head[x]; i; i = e[i].next )
#define rep( i, s, t ) for( register int i = (s); i <= (t); ++ i )
#define drep( i, s, t ) for( register int i = (t); i >= (s); -- i )
#define re register
#define mp make_pair
#define pi pair<int, int>
#define pb push_back
#define vi vector<int>
#define ls(x) tr[x].ch[0]
#define rs(x) tr[x].ch[1]
#define se second
#define fi first
int gi() {
	char cc = getchar() ; int cn = 0, flus = 1 ;
	while( cc < '0' || cc > '9' ) {  if( cc == '-' ) flus = - flus ; cc = getchar() ; }
	while( cc >= '0' && cc <= '9' )  cn = cn * 10 + cc - '0', cc = getchar() ;
	return cn * flus ;
}
const int N = 4e5 + 5 ; 
const int inf = 1e9 ; 
const int Inf = 1e8 ; 
int n, q, m, rem, Idnex[N], Ans[N], bef = 1, cnt = 1 ; 
int s[N], fa[N][20], tree[N], dep[N] ; 
struct SuffixTree {
	map<int, int> ch ; 
	int lef, len, lk ; 
} t[N] ; 
struct Q {
	int r, id, ans ; 
} ; vector<Q> p[N] ;
int Get(int l, int r) {
	int u = Idnex[l], len = r - l + 1 ; 
	drep( i, 0, 17 ) if(dep[fa[u][i]] >= len) u = fa[u][i] ;
	return u ; 
}
int node(int l, int len) { t[++ cnt].lef = l, t[cnt].len = len, t[cnt].lk = 1 ; return cnt ; }
void insert(int x) {
	++ m, ++ rem ; int u = s[x], lst = 1 ;
	while(rem) {
		while( rem > t[t[bef].ch[s[m - rem + 1]]].len )
		rem -= t[bef = t[bef].ch[s[m - rem + 1]]].len ; 
		int &d = t[bef].ch[s[m - rem + 1]], c = s[t[d].lef + rem - 1] ; 
		if(!d || u == c) {
			t[lst].lk = bef, lst = bef ; 
			if(!d) d = node(m - rem + 1, inf) ; 
			else break ; 
		}
		else {
			int np = node(t[d].lef, rem - 1) ; 
			t[np].ch[c] = d, t[np].ch[u] = node(m, inf),
			t[d].lef += (rem - 1), t[d].len -= (rem - 1),
			t[lst].lk = d = np, lst = np ; 
		} (bef == 1) ? -- rem : bef = t[bef].lk ; 
	} 
}
struct LCT {
	int fa, ch[2], bef, lst, mi, w ; 
} tr[N] ;
#define lw(x) (x & (-x))
void add(int x, int k) { for(int i = x; i <= n; i += lw(i)) tree[i] += k ; }
int qry(int x) { int sum = 0 ; for(int i = x; i; i -= lw(i)) sum += tree[i] ; return sum ; }
void pushup(int x) { tr[x].mi = min(min(tr[ls(x)].mi, tr[rs(x)].mi), tr[x].w) ; }
bool isroot(int x) { return (ls(tr[x].fa) != x) && (rs(tr[x].fa) != x) ; }
void rotate(int x) {
	int f = tr[x].fa, ff = tr[f].fa, z = ( rs(f) == x ), c = tr[x].ch[z ^ 1] ; 
	tr[x].fa = ff ; if( !isroot(f) ) tr[ff].ch[rs(ff) == f] = x ; 
	tr[f].fa = x, tr[x].ch[z ^ 1] = f, tr[c].fa = f, tr[f].ch[z] = c,
	pushup(f), pushup(x) ; 
}
void Splay(int x) {
	while(!isroot(x)) {
		int f = tr[x].fa, ff = tr[f].fa ; 
		if(!isroot(f)) ((rs(f) == x) ^ (rs(ff) == f)) ? rotate(x) : rotate(f) ; 
		rotate(x) ; 
	} 
} 
int Find(int x) { Splay(x) ; return tr[x].mi ; } 
int Ls(int x) { while(ls(x)) x = ls(x) ; return x ; }
void access(int u, int l) { 
	for(int x = u, y = 0; x; y = x, x = tr[x].fa) {
		Splay(x) ; if(x != 1 && tr[x].lst) add(dep[x] + Find(tr[x].lst), -1) ;
	} int las = u ; 
	for(int x = u, y = 0; x; y = x, x = tr[x].fa) {
		Splay(x) ; int ut = Ls(x) ; 
		tr[x].lst = tr[x].bef, tr[x].bef = las, las = ut, rs(x) = y, pushup(x) ; 
		if(x != 1 && tr[x].lst) add(dep[x] + Find(tr[x].lst), 1) ; 
	} 
} 
void dfs( int x, int Fa, int l ) {
	if( t[x].len >= Inf ) t[x].len = n + 1 - t[x].lef ; 
	dep[x] = l + t[x].len, fa[x][0] = Fa ; int fl = 0 ; 
	rep( i, 1, 18 ) fa[x][i] = fa[fa[x][i - 1]][i - 1] ; 
	for(auto v : t[x].ch) 
		dfs( v.se, x, l + t[x].len ), fl = 1 ; 
	if( !fl ) Idnex[n - dep[x] + 1] = x ;
} 
bool check(int mid, int L, int R) {
	int u = Get(mid, R) ; if(dep[u] != R - mid + 1) return 0 ; 
	if(t[u].len == 0) u = fa[u][0] ; int v = tr[u].lst ; 
	return (v && (Find(v) < mid)) ;  
}
void solve() {
	rep( i, 0, cnt ) {
		tr[i].fa = fa[i][0] ; 
		tr[i].mi = tr[i].w = inf ;
	} tr[1].fa = 0 ; 
	for(int i = n; i >= 1; -- i) {
		tr[Idnex[i]].w = i, access(Idnex[i], i) ; 
		for(auto &x : p[i]) x.ans += qry(x.r) ; 
		for(auto &x : p[i]) {
			int L = i, R = x.r, l = i, r = x.r, ans = r + 1 ; 
			while(l <= r) {
				int mid = (l + r) >> 1 ; 
				if(check(mid, L, R)) ans = mid, r = mid - 1 ; 
				else l = mid + 1 ; 
			} x.ans -= (R - ans + 1) ; 
		}
	}
}
signed main()
{
	n = gi(), q = gi(), t[0].len = inf ; 
	rep( i, 1, n ) s[i] = gi() ; s[++ n] = n ;
	rep( i, 1, n ) insert(i) ; 
	-- n, dfs(1, 1, 0) ; 
	rep( i, 1, q ) {
		int l = gi(), r = gi() ; 
		p[l].pb((Q){r, i, r - l + 1}) ; 
	}
	solve() ; 
	rep( i, 1, n ) for(auto x : p[i]) Ans[x.id] = x.ans ; 
	rep( i, 1, q ) printf("%lld\n", Ans[i] ) ; 
	return 0 ;
}
posted @ 2021-01-18 11:02  Soulist  阅读(257)  评论(0编辑  收藏  举报