后缀树节点数
后缀树节点数
给定字符串 \(s\),每次查询区间 \([l,r]\),询问 \(l,r\) 内的字符串 \(s\) 构成的后缀树的节点数量。
设 \(n\) 为串长,则 \(n\le 10^5,m\le 3\cdot 10^5\)
Solution
我不会 \(\mathcal O(n\log^2 n +m\log n)\) 的算法,所以下面简述 \(\mathcal O(n\log^2 n + m\log^2 n)\) 的做法:
后缀树等价于后缀 Trie 树的虚树,考虑虚树上仅有两类基础节点:
- 所有后缀节点构成的节点,此类数量为 \(r-l+1\)
- 所有虚树中分裂产生的节点。
可能算重,需要减去同时满足两类条件的节点。
对于第二类,使用 LCT + 染色处理,分类节数等价于 \(s+'a',s+'b'\) 存在数,维护最小值和次小值即可,复杂度 \(\mathcal O(n\log^2 n)\)(可以注意到只有第一个被访问的儿子需要更新,这是因为其 lst 会改变)
对于去重的部分,可以证明,若 \(s\) 产生了分裂,则 \(s\) 的后缀产生了分裂,于是可以二分一个后缀并判断其是否分裂即可,这个部分我们仍然使用 LCT 预处理的信息,显然是可以判断的。
LCT 的部分存在较多细节,类似于树点涂色,维护的 lst 需要是上一棵子树中深度最小的节点,调试得让人比较崩。
- 需要维护深度最小的节点。
- 维护子树内被访问过的最小值。
- 一开始没想清楚写自闭了。
复杂度 \(\mathcal O(n\log^2 n+m\log^2 n)\)
我佛了,写了 1h,调了 1.5h
\(Code:\)
#include<bits/stdc++.h>
using namespace std ;
#define Next( i, x ) for( register int i = head[x]; i; i = e[i].next )
#define rep( i, s, t ) for( register int i = (s); i <= (t); ++ i )
#define drep( i, s, t ) for( register int i = (t); i >= (s); -- i )
#define re register
#define mp make_pair
#define pi pair<int, int>
#define pb push_back
#define vi vector<int>
#define ls(x) tr[x].ch[0]
#define rs(x) tr[x].ch[1]
#define se second
#define fi first
int gi() {
char cc = getchar() ; int cn = 0, flus = 1 ;
while( cc < '0' || cc > '9' ) { if( cc == '-' ) flus = - flus ; cc = getchar() ; }
while( cc >= '0' && cc <= '9' ) cn = cn * 10 + cc - '0', cc = getchar() ;
return cn * flus ;
}
const int N = 4e5 + 5 ;
const int inf = 1e9 ;
const int Inf = 1e8 ;
int n, q, m, rem, Idnex[N], Ans[N], bef = 1, cnt = 1 ;
int s[N], fa[N][20], tree[N], dep[N] ;
struct SuffixTree {
map<int, int> ch ;
int lef, len, lk ;
} t[N] ;
struct Q {
int r, id, ans ;
} ; vector<Q> p[N] ;
int Get(int l, int r) {
int u = Idnex[l], len = r - l + 1 ;
drep( i, 0, 17 ) if(dep[fa[u][i]] >= len) u = fa[u][i] ;
return u ;
}
int node(int l, int len) { t[++ cnt].lef = l, t[cnt].len = len, t[cnt].lk = 1 ; return cnt ; }
void insert(int x) {
++ m, ++ rem ; int u = s[x], lst = 1 ;
while(rem) {
while( rem > t[t[bef].ch[s[m - rem + 1]]].len )
rem -= t[bef = t[bef].ch[s[m - rem + 1]]].len ;
int &d = t[bef].ch[s[m - rem + 1]], c = s[t[d].lef + rem - 1] ;
if(!d || u == c) {
t[lst].lk = bef, lst = bef ;
if(!d) d = node(m - rem + 1, inf) ;
else break ;
}
else {
int np = node(t[d].lef, rem - 1) ;
t[np].ch[c] = d, t[np].ch[u] = node(m, inf),
t[d].lef += (rem - 1), t[d].len -= (rem - 1),
t[lst].lk = d = np, lst = np ;
} (bef == 1) ? -- rem : bef = t[bef].lk ;
}
}
struct LCT {
int fa, ch[2], bef, lst, mi, w ;
} tr[N] ;
#define lw(x) (x & (-x))
void add(int x, int k) { for(int i = x; i <= n; i += lw(i)) tree[i] += k ; }
int qry(int x) { int sum = 0 ; for(int i = x; i; i -= lw(i)) sum += tree[i] ; return sum ; }
void pushup(int x) { tr[x].mi = min(min(tr[ls(x)].mi, tr[rs(x)].mi), tr[x].w) ; }
bool isroot(int x) { return (ls(tr[x].fa) != x) && (rs(tr[x].fa) != x) ; }
void rotate(int x) {
int f = tr[x].fa, ff = tr[f].fa, z = ( rs(f) == x ), c = tr[x].ch[z ^ 1] ;
tr[x].fa = ff ; if( !isroot(f) ) tr[ff].ch[rs(ff) == f] = x ;
tr[f].fa = x, tr[x].ch[z ^ 1] = f, tr[c].fa = f, tr[f].ch[z] = c,
pushup(f), pushup(x) ;
}
void Splay(int x) {
while(!isroot(x)) {
int f = tr[x].fa, ff = tr[f].fa ;
if(!isroot(f)) ((rs(f) == x) ^ (rs(ff) == f)) ? rotate(x) : rotate(f) ;
rotate(x) ;
}
}
int Find(int x) { Splay(x) ; return tr[x].mi ; }
int Ls(int x) { while(ls(x)) x = ls(x) ; return x ; }
void access(int u, int l) {
for(int x = u, y = 0; x; y = x, x = tr[x].fa) {
Splay(x) ; if(x != 1 && tr[x].lst) add(dep[x] + Find(tr[x].lst), -1) ;
} int las = u ;
for(int x = u, y = 0; x; y = x, x = tr[x].fa) {
Splay(x) ; int ut = Ls(x) ;
tr[x].lst = tr[x].bef, tr[x].bef = las, las = ut, rs(x) = y, pushup(x) ;
if(x != 1 && tr[x].lst) add(dep[x] + Find(tr[x].lst), 1) ;
}
}
void dfs( int x, int Fa, int l ) {
if( t[x].len >= Inf ) t[x].len = n + 1 - t[x].lef ;
dep[x] = l + t[x].len, fa[x][0] = Fa ; int fl = 0 ;
rep( i, 1, 18 ) fa[x][i] = fa[fa[x][i - 1]][i - 1] ;
for(auto v : t[x].ch)
dfs( v.se, x, l + t[x].len ), fl = 1 ;
if( !fl ) Idnex[n - dep[x] + 1] = x ;
}
bool check(int mid, int L, int R) {
int u = Get(mid, R) ; if(dep[u] != R - mid + 1) return 0 ;
if(t[u].len == 0) u = fa[u][0] ; int v = tr[u].lst ;
return (v && (Find(v) < mid)) ;
}
void solve() {
rep( i, 0, cnt ) {
tr[i].fa = fa[i][0] ;
tr[i].mi = tr[i].w = inf ;
} tr[1].fa = 0 ;
for(int i = n; i >= 1; -- i) {
tr[Idnex[i]].w = i, access(Idnex[i], i) ;
for(auto &x : p[i]) x.ans += qry(x.r) ;
for(auto &x : p[i]) {
int L = i, R = x.r, l = i, r = x.r, ans = r + 1 ;
while(l <= r) {
int mid = (l + r) >> 1 ;
if(check(mid, L, R)) ans = mid, r = mid - 1 ;
else l = mid + 1 ;
} x.ans -= (R - ans + 1) ;
}
}
}
signed main()
{
n = gi(), q = gi(), t[0].len = inf ;
rep( i, 1, n ) s[i] = gi() ; s[++ n] = n ;
rep( i, 1, n ) insert(i) ;
-- n, dfs(1, 1, 0) ;
rep( i, 1, q ) {
int l = gi(), r = gi() ;
p[l].pb((Q){r, i, r - l + 1}) ;
}
solve() ;
rep( i, 1, n ) for(auto x : p[i]) Ans[x.id] = x.ans ;
rep( i, 1, q ) printf("%lld\n", Ans[i] ) ;
return 0 ;
}