[HEOI/TJOI2016][洛谷P4094]字符串(SA+主席树)

题面

https://www.luogu.com.cn/problem/P4094

题解

前置知识:后缀数组(SA)与 height 数组、ST 表(区间最小值查询)、主席树(可持久化线段树)。

题目给出字符串S,每次询问给出a,b,c,d,问S[a..b]的所有子串与S[c..d]的lcp的最大值。

先简化问题:只需考虑 S[a..b] 的所有后缀,而不必考虑它的全部子串。对于以 x 开头的子串,把右端点向右延伸到 b 不会使它与 S[c..d] 的 lcp 变小。

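首先预处理出后缀数组、height 数组以及 height 的 ST 表。答案具有单调性:若长度 A 可行,则 A-1 也可行。因此以 0 为下界、\(\min(d-c+1, b-a+1)\) 为上界二分答案。长度 A 可行,当且仅当 \(\exists x \in [a, b-A+1]\ \text{s.t.}\ \operatorname{lcp}(suf_x, suf_c) \geq A\)。限制 \(x \le b-A+1\) 保证该后缀在 S[a..b] 内至少还有 A 个字符;上界保证 \(A \le d-c+1\)。满足 \(\operatorname{lcp}(suf_x, suf_c) \geq A\) 的所有 x,其 rank(即该后缀在后缀数组中的排名)构成一个包含 rank[c] 的连续区间。在 rank[c] 两侧分别二分,即可求出该区间的左右端点 l、r。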

接下来的任务是:判断 \(rank[a], rank[a+1], \dots, rank[b-A+1]\) 中是否存在落在 \([l, r]\) 内的值。这可以用主席树解决:按位置顺序依次插入 rank[1..n],再用第 b-A+1 个版本减去第 a-1 个版本,查询值域 [l, r] 内的元素个数。

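总时间复杂度为 \(O(n \log n + m \log^2 n)\)。构建后缀数组、ST 表和主席树共需 \(O(n \log n)\)。每次询问要二分答案 \(O(\log n)\) 次,而每次 check 包含:两次 \(O(\log n)\) 的二分(每步是一次 O(1) 的 ST 表查询),以及一次 \(O(\log n)\) 的主席树查询。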

代码

#include<bits/stdc++.h>

using namespace std;

// `rg` used to expand to `register`. That storage-class specifier was
// deprecated in C++11 and removed in C++17 (clang rejects it under
// -std=c++17), and modern compilers ignore the hint anyway. Expanding it to
// nothing keeps every `rg int i` in the file valid with identical behaviour.
#define rg
#define In inline

// Upper bound on the string length n; per-position arrays are sized N+5.
const int N = 1e5;
// Node budget for the persistent segment tree. There is one version per
// position over the value range [0, n], and each insert allocates one node
// per level (18 for n = 1e5), so about 1.8e6 nodes are needed.
const int TN = 2e6;

/**
 * Reads one decimal integer from stdin.
 * Characters before the first digit are skipped, and any '-' among them makes
 * the result negative. Reading stops at the first non-digit after the number.
 * Returns 0 if EOF is reached before any digit. The previous version looped
 * forever in that case because EOF never passes the digit test.
 */
inline int read(){
	int val = 0,sign = 1;
	int ch = getchar();  // int, not char, so EOF stays distinguishable
	while(ch != EOF && (ch < '0' || ch > '9')){
		if(ch == '-')sign = -1;
		ch = getchar();
	}
	while('0' <= ch && ch <= '9'){
		val = 10 * val + (ch - '0');
		ch = getchar();
	}
	return val * sign;
}

/**
 * Writes x to stdout in decimal, with a leading '-' if negative and no
 * trailing newline.
 * The magnitude is computed in unsigned arithmetic. Negating INT_MIN as an
 * int overflows (undefined behaviour); the unsigned form is well defined.
 */
inline void write(int x){
	unsigned int mag = static_cast<unsigned int>(x);
	if(x < 0){
		putchar('-');
		mag = 0u - mag;  // |x|, well defined even for INT_MIN
	}
	char buf[12];  // up to 10 digits for a 32-bit unsigned value
	int len = 0;
	do{
		buf[len++] = static_cast<char>('0' + mag % 10);
		mag /= 10;
	}while(mag);
	while(len)putchar(buf[--len]);  // digits were collected least-significant first
}

// n: string length; m: number of queries (main counts it down with while(m--)).
int n,m;
// The input string, stored 1-indexed in s[1..n] (read with scanf into s + 1).
// scanf writes '\0' at s[n+1]; SA::calch's character comparison relies on it.
char s[N+5];
// lg[i] = floor(log2(i)) for i in [1, N], filled in main; used by ST::query.
int lg[N+5];

struct CMTree{
	int rt[N+5],lc[TN+5],rc[TN+5],num[TN+5];
	int cnt,rn;
	In void pushup(int u){
		num[u] = num[lc[u]] + num[rc[u]];
	}
	void ud(int u1,int u2,int l,int r,int d){
		if(l == r){
			num[u2]++;
			return;
		}	
		int m = (l + r) >> 1;
		if(d <= m){
			int v = ++cnt;
			lc[u2] = v,rc[u2] = rc[u1];
			ud(lc[u1],lc[u2],l,m,d);
		}
		else{
			int v = ++cnt;
			rc[u2] = v,lc[u2] = lc[u1];
			ud(rc[u1],rc[u2],m + 1,r,d);
		}
		pushup(u2);
	}
	int query(int u1,int u2,int l,int r,int ql,int qr){
		if(l == ql && r == qr)return num[u2] - num[u1];
		int m = (l + r) >> 1;
		if(qr <= m)return query(lc[u1],lc[u2],l,m,ql,qr);
		else if(ql > m)return query(rc[u1],rc[u2],m + 1,r,ql,qr);
		else return query(lc[u1],lc[u2],l,m,ql,m) + query(rc[u1],rc[u2],m + 1,r,m + 1,qr);
	}
	void insert(int x){
		rt[++rn] = ++cnt;
		ud(rt[rn-1],rt[rn],0,n,x);
	}
	int sum(int a,int b,int ql,int qr){
		return query(rt[a-1],rt[b],0,n,ql,qr);
	}
}T;

// Sparse table for range-minimum queries over a 1-indexed array a[1..n]
// (global n). minn[i][k] holds min(a[i], ..., a[i + 2^k - 1]) for k in 0..17.
struct ST{
	int minn[N+5][18];
	// Builds every level from the previous one: two halves of length 2^(k-1).
	void prepro(int a[]){
		for(int i = 1;i <= n;i++)minn[i][0] = a[i];
		for(int lvl = 1;lvl < 18;lvl++){
			const int half = 1 << (lvl - 1);
			const int span = half << 1;
			for(int i = 1;i + span <= n + 1;i++){
				minn[i][lvl] = std::min(minn[i][lvl - 1],minn[i + half][lvl - 1]);
			}
		}
	}
	// Minimum of a[l..r] (1 <= l <= r <= n) via two overlapping power-of-two
	// windows; needs the global lg[] table to be filled.
	int query(int l,int r){
		const int k = lg[r - l + 1];
		const int front = minn[l][k];
		const int back = minn[r - (1 << k) + 1][k];
		return front < back ? front : back;
	}
};

// Suffix array of the global string s[1..n], built by prefix doubling with
// radix sort, plus the height (LCP) array and a range-minimum table over it.
// After init():
//   sa[i] = start position of the i-th smallest suffix (ranks are 1-indexed),
//   rk[p] = rank of the suffix starting at p (the inverse of sa),
//   h[i]  = lcp(suffix sa[i-1], suffix sa[i]) for i >= 2, with h[1] = 0.
// init() also inserts rk[1..n], in position order, into the global tree T.
struct SA{
	int sa[N+5],temp[N+5],rk[N+5],h[N+5],num[N+5];
	int m;  // number of radix-sort buckets (current count of distinct ranks)
	ST H;   // range-minimum table over h[]
	// Counting sort: stably orders the positions listed in temp[1..n], which
	// are already sorted by the secondary key, by primary key rk[]. The
	// result goes into sa[1..n].
	void qsort(){
		memset(num,0,sizeof(int) * (m+1));
		for(rg int i = 1;i <= n;i++)num[rk[i]]++;
		for(rg int i = 2;i <= m;i++)num[i] += num[i-1];
		// Walk temp backwards and fill each bucket from its end, so equal keys keep temp's order.
		for(rg int i = n;i >= 1;i--)sa[num[rk[temp[i]]]--] = temp[i];
	}
	// Kasai's algorithm: visits suffixes in position order and reuses the
	// previous lcp minus one as the starting match length k.
	// The character loop stops at the '\0' stored at s[n+1].
	void calch(){
		int k = 0;
		for(rg int i = 1;i <= n;i++){
			if(rk[i] == 1)h[1] = k = 0;
			else{
				if(k)k--;
				int j = sa[rk[i]-1];  // suffix ranked just before suffix i
				while(s[i+k] == s[j+k])k++;
				h[rk[i]] = k;
			}
		}
	}
	// Builds sa/rk, then h, the RMQ table and the persistent tree.
	// Initial ranks are s[i] - 'a' + 1 with 26 buckets, so s is assumed to
	// contain only 'a'..'z'.
	void init(){
		m = 26;
		for(rg int i = 1;i <= n;i++)temp[i] = i,rk[i] = s[i] - 'a' + 1;
		qsort();
		// Round d: the ranks cover d characters; sort by (rank[p], rank[p+d]) to cover 2d.
		for(rg int d = 1;d <= n;d <<= 1){
			int cnt = 0;
			// Secondary-key order: suffixes with no second half (p > n-d) come first...
			for(rg int i = n - d + 1;i <= n;i++)temp[++cnt] = i;
			// ...then p = sa[i]-d, in the order of the rank of their second half.
			for(rg int i = 1;i <= n;i++)if(sa[i] > d)temp[++cnt] = sa[i] - d;
			qsort();
			// Keep the old ranks in temp and re-rank. Neighbours share a rank only
			// if both d-length halves had equal ranks. temp[sa[i]+d] is read only
			// when the first halves tie; then sa[i]+d <= n+1, and temp[n+1] is
			// never written (still 0).
			memcpy(temp,rk,sizeof(int) * (n+1));
			cnt = rk[sa[1]] = 1;
			for(rg int i = 2;i <= n;i++){
				if(temp[sa[i]] != temp[sa[i-1]] || temp[sa[i]+d] != temp[sa[i-1]+d])cnt++;
				rk[sa[i]] = cnt;
			}
			if(cnt == n)break;  // all ranks distinct: sa and rk are final
			m = cnt;
		}	
		calch();
		H.prepro(h);
		// Version i of T holds {rk[1], ..., rk[i]}.
		for(rg int i = 1;i <= n;i++)T.insert(rk[i]);
	}
	// Length of the longest common prefix of the suffixes starting at i and j:
	// the minimum of h over the ranks strictly after the smaller rank, up to the larger.
	int lcp(int i,int j){
		if(i == j)return n - i + 1;
		int x = rk[i],y = rk[j];
		if(x > y)swap(x,y);
		return H.query(x + 1,y);
	}
}S;

In bool check(int A,int a,int b,int c){
	if(!A)return 1;
	int L,R;
	L = 1,R = S.rk[c];
	while(L < R){
		int mid = (L + R) >> 1;
		if(S.lcp(S.sa[mid],c) >= A)R = mid;
		else L = mid + 1;
	}
	int l = L;
	L = S.rk[c],R = n;
	while(L < R){
		int mid = (L + R + 1) >> 1;
		if(S.lcp(c,S.sa[mid]) >= A)L = mid;
		else R = mid - 1;
	}
	int r = R;
	return T.sum(a,b - A + 1,l,r);
}

int main(){
	for(rg int i = 2;i <= N;i++)lg[i] = lg[i>>1] + 1;
	n = read(); m = read();
	scanf("%s",s + 1);
	S.init();
	while(m--){
		int a = read(),b = read(),c = read(),d = read();
		int L = 0,R = min(b - a + 1,d - c + 1);
		while(L < R){
			int mid = (L + R + 1) >> 1;
			if(check(mid,a,b,c))L = mid;
			else R = mid - 1;
		}
		write(L),putchar('\n');
	}
	return 0;
}

[]

posted @ 2020-10-05 20:09  coder66  阅读(170)  评论(0编辑  收藏  举报