[SPOJ11482][BZOJ2787]Count on a trie(广义SA+长链剖分+BIT)

题面

https://darkbzoj.tk/problem/2787

题解

前置知识:

操作2可以当做3的特殊情况,因为2可以看做将一个T中的字符串与一个长度为1的字符串连接起来。初始时将'a'~'z'这26个字符都放入T就可以了。另外,我们把此题中的S,T都颠倒一下,即题目中的1操作变为在\(S_i\)开头加字符,2 0操作变为在\(T_i\)结尾加字符,等等……这是为了方便后续的处理。以及,本题中未保证\(S_i\)的两两不同,对于每一个位置存一个出现次数即可,因此下文默认\(S\)是不可重集。

简化过后,相当于我们有字符串集合S,T,初始时S中只有空串,T中有空串和26个字符。有三种操作需要维护:

  1. 在S的某一个串Si前添加一个字符c,加入S
  2. 将T的两个串Ti,Tj首尾相接形成一个新串TiTj,加入T
  3. 询问T中的某个串Ti在S中某个串Si中的出现次数

所有S中的字符串的开头形成一棵Trie,因此可以通过离线所有的操作1建出这棵Trie,然后通过广义SA对于所有的S进行排序。下定义:

  • \(sa[i]\)表示将所有S中的字符串,第i小的是哪一个。
  • \(rnk[i]\)表示\(S_i\)\(S\)中的大小排名。

这二者均随SA求出。

对于操作2,可以对于每一个T中的字符串\(T_i\),维护\(Tlen[i]\)表示\(T_i\)的长度(这个很好做);以及\(l_i,r_i\)表示\(T_i\)恰是\(S_{sa[l_i]}\)\(S_{sa[r_i]}\)的前缀。其中那么现在关键的问题是怎么求出新串的l和r值。

设由\(T_i+T_j\)形成的新串是\(T_{id}\),一定有\([l_{id},r_{id}] \subseteq [l_i,r_i]\)。因此我们可以在\(l_i\)\(r_i\)之间二分\(l_{id}\)\(r_{id}\),这样就转化为比较\(S_{sa[mid]}\)\(T_{id}\)的大小,也就是\(S_{sa[mid]}-T_i\)\(T_j\)的大小(这里对于字符串A,B,A+B表示拼接,A-B表示从B开头截去A所得字符串)

由于所有的S都在一棵Trie树上,所以\(S_{sa[mid]}-T_i\)其实就是\(S_{sa[mid]的|T_i|代祖先}\)。这里需要一个长链剖分的优化,经过\(O(n \log n)\)的预处理后,能够\(O(1)\)地求出树上一个点的\(k\)代祖先。设\(sa[mid]\)\(|T_i|\)代祖先是p,只需比较\(S_p\)\(T_j\)的大小。

\(T_j\)\(S_{sa[l_j]}\)\(S_{sa[r_j]}\)的前缀。这样只需判断\(rnk[p]\)\(l_j\)\(r_j\)的大小关系即可。具体是\(l_j\)还是\(r_j\)要看当前二分求的是\(l_{id}\)还是\(r_{id}\)

对于操作3,求\(T_i\)\(S_j\)中的出现次数,等价于求j到根路径上,有多少个点的rnk值是\(\in [l_i,r_i]\)的。将询问挂在点j上,最后统一进行DFS计算答案,用一个BIT来维护即可。

总时间复杂度\(O(q \log q)\)

代码

#include<bits/stdc++.h>

using namespace std;

#define rg register
#define In inline

const int N = 3e5;
const int TN = 3e5 + 26;

In int read(){
	int s = 0,ww = 1;
	char ch = getchar();
	while(ch < '0' || ch > '9'){if(ch == '-')ww = -1;ch = getchar();}
	while('0' <= ch && ch <= '9'){s = 10 * s + ch - '0';ch = getchar();}
	return s * ww;
}

In void write(int x){
	if(x < 0)putchar('-');
	if(x > 9)write(x / 10);
	putchar('0' + x % 10);
}

//main
int loc[N+5],lg[N+5],Sn,Qn,Tn,q,ans[N+5],l[TN+5],r[TN+5],Tlen[TN+5];
//LCD
int fa[N+5][20],len[N+5],son[N+5],top[N+5],dep[N+5]; 
vector<int>up[N+5],down[N+5];
//SA
int rnk[N+5];

struct BIT{
	int b[N+5];
	In int lowbit(int x){
		return x & -x;
	}
	void ud(int x,int dx){
		for(rg int i = x;i <= Sn;i += lowbit(i)){
			b[i] += dx;
		}
	}
	int query(int x){
		int rt = 0;
		for(rg int i = x;i;i -= lowbit(i))rt += b[i];
		return rt;
	}
	int sum(int l,int r){
		return query(r) - query(l - 1);
	}
}B;

struct qnode{
	int l,r,id;
	qnode(){l = r = id = 0;};
	qnode(int _l,int _r,int _id){l = _l,r = _r,id = _id;}
};

struct Trie{ 
	int nx[N+5][26],num[N+5],cnt,w[N+5];
	vector<qnode>que[N+5];
	void init(){
		loc[1] = 0;
		num[0] = 1;
		w[0] = -1;
	}
	int insert(int last,int id){	
		if(!nx[last][id]){
			nx[last][id] = ++cnt;
			fa[cnt][0] = last;
			w[cnt] = id;
		}
		num[nx[last][id]]++;
		return nx[last][id];	
	}
	void dfs(int u){ //统计答案
		if(rnk[u])B.ud(rnk[u],num[u]);
		for(rg int i = 0;i < que[u].size();i++){
			ans[que[u][i].id] = B.sum(que[u][i].l,que[u][i].r);
		}
		for(rg int i = 0;i < 26;i++)if(nx[u][i])dfs(nx[u][i]);
		if(rnk[u])B.ud(rnk[u],-num[u]);
	}
}T;	

namespace LCD{ //长链剖分
	void dfs1(int u){
		dep[u] = dep[fa[u][0]] + 1;
		for(rg int i = 0;i < 26;i++)if(T.nx[u][i]){
			int v = T.nx[u][i];
			dfs1(v);
			if(len[v] > len[u])len[u] = len[v],son[u] = v;
		}
		len[u]++;
	}
	void dfs2(int u,int t){
		top[u] = t;
		down[t].push_back(u);
		if(son[u])dfs2(son[u],t);
		for(rg int i = 0;i < 26;i++)if(T.nx[u][i]){
			int v = T.nx[u][i];
			if(v == son[u])continue;
			dfs2(v,v);
		}
		if(top[u] == u){
			for(rg int i = 0,v = u;i < down[u].size();i++,v = fa[v][0])
				up[u].push_back(v);
		}
	}
	void prepro(){
		for(rg int j = 1;j <= 19;j++)
			for(rg int i = 1;i <= T.cnt;i++)fa[i][j] = fa[fa[i][j-1]][j-1];
		dfs1(0);
		dfs2(0,0);
	}
	In int query(int u,int k){ //O(1)求u的k级祖先
		if(!k)return u;
		int v = top[fa[u][lg[k]]];
		k -= dep[u] - dep[v];
		if(k > 0)return up[v][k];
		else return down[v][-k];
	}
}
using namespace LCD;

struct SA{ 
	int temp[N+5],sa[N+5],rk[N+5][20],num[N+5],h[N+5];
	vector<int>c[N+5];
	int m;
	void qsort(int cur){
		memset(num,0,sizeof(int) * (m+5));
		for(rg int i = 1;i <= T.cnt;i++)num[rk[i][cur]]++;
		for(rg int i = 2;i <= m;i++)num[i] += num[i-1];
		for(rg int i = T.cnt;i >= 1;i--)sa[num[rk[temp[i]][cur]]--] = temp[i];
	}
	void calch(){
		h[1] = 0;
		for(rg int i = 2;i <= T.cnt;i++){
			int u = sa[i-1],v = sa[i];
			for(rg int j = 19;j >= 0;j--){
				if((1<<j) >= dep[u] || (1<<j) >= dep[v])continue;
				if(rk[u][j] == rk[v][j])
					u = fa[u][j],v = fa[v][j],h[i] += (1<<j);	
			}
		}
	}
	void init(){
		for(rg int i = 1;i <= T.cnt;i++)rk[i][0] = T.w[i] + 1,temp[i] = i;
		m = 26;
		qsort(0);
		for(rg int d = 1,cur = 0;d <= T.cnt;d <<= 1,cur++){
			int cnt = 0;
			for(rg int i = 1;i <= T.cnt;i++)c[i].resize(0);
			for(rg int i = 1;i <= T.cnt;i++)if(dep[i] <= d + 1)temp[++cnt] = i;
			else c[fa[i][cur]].push_back(i);
			for(rg int i = 1;i <= T.cnt;i++){
				for(rg int j = 0;j < c[sa[i]].size();j++)temp[++cnt] = c[sa[i]][j];
			}
			qsort(cur);
			cnt = 1;
			rk[sa[1]][cur+1] = 1;
			for(rg int i = 2;i <= T.cnt;i++){
				if(rk[sa[i]][cur] != rk[sa[i-1]][cur] || rk[fa[sa[i]][cur]][cur] != rk[fa[sa[i-1]][cur]][cur])cnt++;
				rk[sa[i]][cur+1] = cnt;
			}
			if(cnt == T.cnt){
				for(rg int i = 1;i <= T.cnt;i++)rnk[i] = rk[i][cur+1];
			}
			m = cnt;
		}
		calch();
	}
}S;

In int cmp(int x,int k,int y){ //suf_x去掉前k位后,和suf_y比大小;-1为<,0为=,1为>
	if(dep[x] <= k + 1)return -1;
	int z = query(x,k);
	return rnk[z] < rnk[y] ? -1 : rnk[z] > rnk[y];
}

In bool empty(int i){
	return !l[i] && !r[i];
}

void merge(int i,int j,int id){ //T[id]是T[i]+T[j],计算它的l,r
	Tlen[id] = Tlen[i] + Tlen[j];
	if(empty(i))l[id] = l[j],r[id] = r[j];
	else if(empty(j))l[id] = l[i],r[id] = r[i];
	else{
		if(cmp(S.sa[r[i]],Tlen[i],S.sa[l[j]]) < 0)l[id] = r[i] + 1;
		else{
			int L = l[i],R = r[i];
			while(L < R){
				int mid = (L + R) >> 1;
				if(cmp(S.sa[mid],Tlen[i],S.sa[l[j]]) < 0)L = mid + 1;
				else R = mid;
			}
			l[id] = L;
		} 
		if(cmp(S.sa[l[i]],Tlen[i],S.sa[r[j]]) > 0)r[id] = l[i] - 1;
		else{
			int L = l[i],R = r[i];
			while(L < R){
				int mid = (L + R + 1) >> 1;
				if(cmp(S.sa[mid],Tlen[i],S.sa[r[j]]) > 0)R = mid - 1;
				else L = mid;
			}
			r[id] = L;
		}
	}
}

struct inst{
	int opt,x,y;
}I[N+5];

int main(){
	freopen("SP11482.in","r",stdin);
	freopen("SP11482.out","w",stdout);
	for(rg int i = 2;i <= N;i++)lg[i] = lg[i>>1] + 1;
	q = read();
	Sn = Tn = 1;
	T.init();
	for(rg int i = 1;i <= q;i++){
		int opt = read();
		if(opt <= 2){
			if(opt == 1){
				I[i].opt = 1;
				I[i].x = read();
				I[i].y = getchar() - 'a';
				loc[++Sn] = T.insert(loc[I[i].x],I[i].y);				
			}
			else{
				int dir = read();
				I[i].opt = 2;
				I[i].x = read();
				I[i].y = getchar() - 'a' + N + 2;
				if(dir)swap(I[i].x,I[i].y);
			}
		}
		else{
			if(opt == 3){
				I[i].opt = 2;
				I[i].x = read();
				I[i].y = read();
				swap(I[i].x,I[i].y);
			}
			else{
				I[i].opt = 3;
				I[i].x = read();
				I[i].y = read();
			}
		}
	}
	prepro();
	T.print();
	S.init();
	for(rg int i = 0;i < 26;i++){ //计算字符'a'~'z'的l,r值,存放在l,r[N+2 ~ N+27]中
		Tlen[N+2+i] = 1;
		if(T.w[S.sa[T.cnt]] < i)l[N+2+i] = T.cnt + 1;
		else{
			int L = 1,R = T.cnt;
			while(L < R){
				int mid = (L + R) >> 1;
				if(T.w[S.sa[mid]] < i)L = mid + 1;
				else R = mid; 
			}
			l[N+2+i] = L;
		}
		if(T.w[S.sa[1]] > i)r[N+2+i] = 0;
		else{
			int L = 1,R = T.cnt;
			while(L < R){
				int mid = (L + R + 1) >> 1;
				if(T.w[S.sa[mid]] > i)R = mid - 1;
				else L = mid;
			}
			r[N+2+i] = L;
		}
	}
	for(rg int i = 1;i <= q;i++){
		if(I[i].opt == 1)continue;
		else{
			if(I[i].opt == 2)merge(I[i].x,I[i].y,++Tn);
			else{
				int x = I[i].x,y = I[i].y;
				if(empty(x) || y == 1)ans[++Qn] = 0;
				else{
					T.que[loc[y]].push_back(qnode(l[x],r[x],++Qn)); //离线
				}
			}
		}		
	}	
	T.dfs(0);
	for(rg int i = 1;i <= Qn;i++)write(ans[i]),putchar('\n');
	return 0;
}
posted @ 2020-10-05 20:17  coder66  阅读(246)  评论(0编辑  收藏  举报