[SPOJ11482][BZOJ2787]Count on a trie(广义SA+长链剖分+BIT)
题面
https://darkbzoj.tk/problem/2787
题解
前置知识:
操作2可以当做3的特殊情况,因为2可以看做将一个T中的字符串与一个长度为1的字符串连接起来。初始时将'a'~'z'这26个字符都放入T就可以了。另外,我们把此题中的S,T都颠倒一下,即题目中的1操作变为在\(S_i\)开头加字符,2 0操作变为在\(T_i\)结尾加字符,等等……这是为了方便后续的处理。以及,本题中未保证\(S_i\)的两两不同,对于每一个位置存一个出现次数即可,因此下文默认\(S\)是不可重集。
简化过后,相当于我们有字符串集合S,T,初始时S中只有空串,T中有空串和26个字符。有三种操作需要维护:
- 在S的某一个串Si前添加一个字符c,加入S
- 将T的两个串Ti,Tj首尾相接形成一个新串TiTj,加入T
- 询问T中的某个串Ti在S中某个串Si中的出现次数
所有S中的字符串的开头形成一棵Trie,因此可以通过离线所有的操作1建出这棵Trie,然后通过广义SA对于所有的S进行排序。下定义:
- \(sa[i]\)表示将所有S中的字符串,第i小的是哪一个。
- \(rnk[i]\)表示\(S_i\)在\(S\)中的大小排名。
这二者均随SA求出。
对于操作2,可以对于每一个T中的字符串\(T_i\),维护\(Tlen[i]\)表示\(T_i\)的长度(这个很好做);以及\(l_i,r_i\)表示\(T_i\)恰是\(S_{sa[l_i]}\)到\(S_{sa[r_i]}\)的前缀。其中那么现在关键的问题是怎么求出新串的l和r值。
设由\(T_i+T_j\)形成的新串是\(T_{id}\),一定有\([l_{id},r_{id}] \subseteq [l_i,r_i]\)。因此我们可以在\(l_i\)和\(r_i\)之间二分\(l_{id}\)和\(r_{id}\),这样就转化为比较\(S_{sa[mid]}\)和\(T_{id}\)的大小,也就是\(S_{sa[mid]}-T_i\)和\(T_j\)的大小(这里对于字符串A,B,A+B表示拼接,A-B表示从B开头截去A所得字符串)
由于所有的S都在一棵Trie树上,所以\(S_{sa[mid]}-T_i\)其实就是\(S_{sa[mid]的|T_i|代祖先}\)。这里需要一个长链剖分的优化,经过\(O(n \log n)\)的预处理后,能够\(O(1)\)地求出树上一个点的\(k\)代祖先。设\(sa[mid]\)的\(|T_i|\)代祖先是p,只需比较\(S_p\)和\(T_j\)的大小。
而\(T_j\)是\(S_{sa[l_j]}\)到\(S_{sa[r_j]}\)的前缀。这样只需判断\(rnk[p]\)与\(l_j\)或\(r_j\)的大小关系即可。具体是\(l_j\)还是\(r_j\)要看当前二分求的是\(l_{id}\)还是\(r_{id}\)。
对于操作3,求\(T_i\)在\(S_j\)中的出现次数,等价于求j到根路径上,有多少个点的rnk值是\(\in [l_i,r_i]\)的。将询问挂在点j上,最后统一进行DFS计算答案,用一个BIT来维护即可。
总时间复杂度\(O(q \log q)\)。
代码
#include<bits/stdc++.h>
using namespace std;
#define rg register
#define In inline
const int N = 3e5;
const int TN = 3e5 + 26;
In int read(){
int s = 0,ww = 1;
char ch = getchar();
while(ch < '0' || ch > '9'){if(ch == '-')ww = -1;ch = getchar();}
while('0' <= ch && ch <= '9'){s = 10 * s + ch - '0';ch = getchar();}
return s * ww;
}
In void write(int x){
if(x < 0)putchar('-');
if(x > 9)write(x / 10);
putchar('0' + x % 10);
}
//main
int loc[N+5],lg[N+5],Sn,Qn,Tn,q,ans[N+5],l[TN+5],r[TN+5],Tlen[TN+5];
//LCD
int fa[N+5][20],len[N+5],son[N+5],top[N+5],dep[N+5];
vector<int>up[N+5],down[N+5];
//SA
int rnk[N+5];
struct BIT{
int b[N+5];
In int lowbit(int x){
return x & -x;
}
void ud(int x,int dx){
for(rg int i = x;i <= Sn;i += lowbit(i)){
b[i] += dx;
}
}
int query(int x){
int rt = 0;
for(rg int i = x;i;i -= lowbit(i))rt += b[i];
return rt;
}
int sum(int l,int r){
return query(r) - query(l - 1);
}
}B;
struct qnode{
int l,r,id;
qnode(){l = r = id = 0;};
qnode(int _l,int _r,int _id){l = _l,r = _r,id = _id;}
};
struct Trie{
int nx[N+5][26],num[N+5],cnt,w[N+5];
vector<qnode>que[N+5];
void init(){
loc[1] = 0;
num[0] = 1;
w[0] = -1;
}
int insert(int last,int id){
if(!nx[last][id]){
nx[last][id] = ++cnt;
fa[cnt][0] = last;
w[cnt] = id;
}
num[nx[last][id]]++;
return nx[last][id];
}
void dfs(int u){ //统计答案
if(rnk[u])B.ud(rnk[u],num[u]);
for(rg int i = 0;i < que[u].size();i++){
ans[que[u][i].id] = B.sum(que[u][i].l,que[u][i].r);
}
for(rg int i = 0;i < 26;i++)if(nx[u][i])dfs(nx[u][i]);
if(rnk[u])B.ud(rnk[u],-num[u]);
}
}T;
namespace LCD{ //长链剖分
void dfs1(int u){
dep[u] = dep[fa[u][0]] + 1;
for(rg int i = 0;i < 26;i++)if(T.nx[u][i]){
int v = T.nx[u][i];
dfs1(v);
if(len[v] > len[u])len[u] = len[v],son[u] = v;
}
len[u]++;
}
void dfs2(int u,int t){
top[u] = t;
down[t].push_back(u);
if(son[u])dfs2(son[u],t);
for(rg int i = 0;i < 26;i++)if(T.nx[u][i]){
int v = T.nx[u][i];
if(v == son[u])continue;
dfs2(v,v);
}
if(top[u] == u){
for(rg int i = 0,v = u;i < down[u].size();i++,v = fa[v][0])
up[u].push_back(v);
}
}
void prepro(){
for(rg int j = 1;j <= 19;j++)
for(rg int i = 1;i <= T.cnt;i++)fa[i][j] = fa[fa[i][j-1]][j-1];
dfs1(0);
dfs2(0,0);
}
In int query(int u,int k){ //O(1)求u的k级祖先
if(!k)return u;
int v = top[fa[u][lg[k]]];
k -= dep[u] - dep[v];
if(k > 0)return up[v][k];
else return down[v][-k];
}
}
using namespace LCD;
struct SA{
int temp[N+5],sa[N+5],rk[N+5][20],num[N+5],h[N+5];
vector<int>c[N+5];
int m;
void qsort(int cur){
memset(num,0,sizeof(int) * (m+5));
for(rg int i = 1;i <= T.cnt;i++)num[rk[i][cur]]++;
for(rg int i = 2;i <= m;i++)num[i] += num[i-1];
for(rg int i = T.cnt;i >= 1;i--)sa[num[rk[temp[i]][cur]]--] = temp[i];
}
void calch(){
h[1] = 0;
for(rg int i = 2;i <= T.cnt;i++){
int u = sa[i-1],v = sa[i];
for(rg int j = 19;j >= 0;j--){
if((1<<j) >= dep[u] || (1<<j) >= dep[v])continue;
if(rk[u][j] == rk[v][j])
u = fa[u][j],v = fa[v][j],h[i] += (1<<j);
}
}
}
void init(){
for(rg int i = 1;i <= T.cnt;i++)rk[i][0] = T.w[i] + 1,temp[i] = i;
m = 26;
qsort(0);
for(rg int d = 1,cur = 0;d <= T.cnt;d <<= 1,cur++){
int cnt = 0;
for(rg int i = 1;i <= T.cnt;i++)c[i].resize(0);
for(rg int i = 1;i <= T.cnt;i++)if(dep[i] <= d + 1)temp[++cnt] = i;
else c[fa[i][cur]].push_back(i);
for(rg int i = 1;i <= T.cnt;i++){
for(rg int j = 0;j < c[sa[i]].size();j++)temp[++cnt] = c[sa[i]][j];
}
qsort(cur);
cnt = 1;
rk[sa[1]][cur+1] = 1;
for(rg int i = 2;i <= T.cnt;i++){
if(rk[sa[i]][cur] != rk[sa[i-1]][cur] || rk[fa[sa[i]][cur]][cur] != rk[fa[sa[i-1]][cur]][cur])cnt++;
rk[sa[i]][cur+1] = cnt;
}
if(cnt == T.cnt){
for(rg int i = 1;i <= T.cnt;i++)rnk[i] = rk[i][cur+1];
}
m = cnt;
}
calch();
}
}S;
In int cmp(int x,int k,int y){ //suf_x去掉前k位后,和suf_y比大小;-1为<,0为=,1为>
if(dep[x] <= k + 1)return -1;
int z = query(x,k);
return rnk[z] < rnk[y] ? -1 : rnk[z] > rnk[y];
}
In bool empty(int i){
return !l[i] && !r[i];
}
void merge(int i,int j,int id){ //T[id]是T[i]+T[j],计算它的l,r
Tlen[id] = Tlen[i] + Tlen[j];
if(empty(i))l[id] = l[j],r[id] = r[j];
else if(empty(j))l[id] = l[i],r[id] = r[i];
else{
if(cmp(S.sa[r[i]],Tlen[i],S.sa[l[j]]) < 0)l[id] = r[i] + 1;
else{
int L = l[i],R = r[i];
while(L < R){
int mid = (L + R) >> 1;
if(cmp(S.sa[mid],Tlen[i],S.sa[l[j]]) < 0)L = mid + 1;
else R = mid;
}
l[id] = L;
}
if(cmp(S.sa[l[i]],Tlen[i],S.sa[r[j]]) > 0)r[id] = l[i] - 1;
else{
int L = l[i],R = r[i];
while(L < R){
int mid = (L + R + 1) >> 1;
if(cmp(S.sa[mid],Tlen[i],S.sa[r[j]]) > 0)R = mid - 1;
else L = mid;
}
r[id] = L;
}
}
}
struct inst{
int opt,x,y;
}I[N+5];
int main(){
freopen("SP11482.in","r",stdin);
freopen("SP11482.out","w",stdout);
for(rg int i = 2;i <= N;i++)lg[i] = lg[i>>1] + 1;
q = read();
Sn = Tn = 1;
T.init();
for(rg int i = 1;i <= q;i++){
int opt = read();
if(opt <= 2){
if(opt == 1){
I[i].opt = 1;
I[i].x = read();
I[i].y = getchar() - 'a';
loc[++Sn] = T.insert(loc[I[i].x],I[i].y);
}
else{
int dir = read();
I[i].opt = 2;
I[i].x = read();
I[i].y = getchar() - 'a' + N + 2;
if(dir)swap(I[i].x,I[i].y);
}
}
else{
if(opt == 3){
I[i].opt = 2;
I[i].x = read();
I[i].y = read();
swap(I[i].x,I[i].y);
}
else{
I[i].opt = 3;
I[i].x = read();
I[i].y = read();
}
}
}
prepro();
T.print();
S.init();
for(rg int i = 0;i < 26;i++){ //计算字符'a'~'z'的l,r值,存放在l,r[N+2 ~ N+27]中
Tlen[N+2+i] = 1;
if(T.w[S.sa[T.cnt]] < i)l[N+2+i] = T.cnt + 1;
else{
int L = 1,R = T.cnt;
while(L < R){
int mid = (L + R) >> 1;
if(T.w[S.sa[mid]] < i)L = mid + 1;
else R = mid;
}
l[N+2+i] = L;
}
if(T.w[S.sa[1]] > i)r[N+2+i] = 0;
else{
int L = 1,R = T.cnt;
while(L < R){
int mid = (L + R + 1) >> 1;
if(T.w[S.sa[mid]] > i)R = mid - 1;
else L = mid;
}
r[N+2+i] = L;
}
}
for(rg int i = 1;i <= q;i++){
if(I[i].opt == 1)continue;
else{
if(I[i].opt == 2)merge(I[i].x,I[i].y,++Tn);
else{
int x = I[i].x,y = I[i].y;
if(empty(x) || y == 1)ans[++Qn] = 0;
else{
T.que[loc[y]].push_back(qnode(l[x],r[x],++Qn)); //离线
}
}
}
}
T.dfs(0);
for(rg int i = 1;i <= Qn;i++)write(ans[i]),putchar('\n');
return 0;
}