[NOI2018][洛谷P4770]你的名字(SAM+SA+主席树)

题面

https://www.luogu.com.cn/problem/P4770

题解

前置知识:

题目多次给定T,询问字符串T中,有多少个不同的子串不与S[l..r]的任何一个子串相等。

首先建出S的后缀自动机,并预处理出fail树上的\(2^k\)代祖先。

对于每次询问,如果我们将T在S的后缀自动机上跑,就可以求出\(T[1..1],T[1..2],…,T[1..|T|]\)在S中的最长匹配长度\(lcp[1],lcp[2],…,lcp[|T|]\),以及该匹配对应的SAM节点\(loc[1],loc[2],…loc[|T|]\)。然后求出T的“前缀数组”,枚举题目所求的字符串的右端点\(cur_r\)

考虑T中,以\(cur_r\)为右端点且不与\(S[l..r]\)的任何一个子串相等的子串有哪些。发现它们的左端点一定是从1开始连续的一段(若\(T[lim..cur_r]\)不合法,那么\(T[lim+1..cur_r],T[lim+2..cur_r]…T[cur_r..cur_r]\)都不符合条件)。先不考虑细节(\(T[1..cur_r]\)\(T[cur_r..cur_r]\)全都合法以及全都不合法),一定存在\(lim\)使得\(T[lim+1..cur_r]\)不合法而\(T[lim..cur_r]\)合法,那么我们可以二分查找lim。

先不谈怎么二分查找,先以\(T[cur_r-lcp[cur_r]+1..cur_r]\)为例,说一下怎么判断合法性。我们知道字符串S的SAM的fail树就是\(S^R\)的后缀树。以下只画出SAM的fail树。假设\(S[1..cur_r]\)在SAM上所能匹配到的最长部分即\(T[cur_r-lcp[cur_r]+1..cur_r]\)是图中橙色部分,蓝色点是\(loc[cur_r]\)

如果\(T[cur_r-lcp[cur_r]+1..curr]\)不合法,假设它在S中的\(S[l_0..r_0](l{\leq}l_0{\leq}r_0{\leq}r,r_0-l_0+1=lcp[cur_r])\)出现了。那么\(S[1..r_0]\)对应的节点(粉色箭头的开头)一定在蓝色点的子树中(因为\(S[1..r_0]\)\(S[l_0..r_0]\)的前缀)。

因此可以看出,\(T[cur_r-lcp+1..curr]\)不合法的充要条件就是,在\(loc\)的子树内存在一个点\(r_0\),使得\(l{\leq}r_0-lcp+1{\leq}r_0{\leq}r\)。移项得到\(l+lcp-1{\leq}r_0{\leq}r\)。这就成了一个在线二维数点问题,可以用主席树解决。

二分查找就基于这一算法。设置两个变量,u和len。初始时,将\(u\)赋值为\(loc[cur_r]\)\(len\)赋值为\(lcp[cur_r]\)。接下来,由于我们已预处理了此树上每一个点的\(2^k\)代祖先,可以倍增地尝试:

for(int i = 20;i >= 0;i--){
	if(fail[u][i] == -1)continue;
	if(legal(fail[u][i],S.len[fail[u][i]]))u = fail[u][i],len = S.len[u];
}
  • 这里的S.len[u]表示的是SAM节点u表示的字符串中最长的那个的长度,建SAM时可以一并算出的。
  • legal是判断是否合法的函数

由此算出的u就满足,我们要找的lim(还记得lim是什么吗www)一定在u到\(fail[u][0]\)的线段上(u含,\(fail[u][0]\)不含)

接下来再进行一次二分查找,这次找的是lim具体在u到\(fail[u][0]\)的线段上的哪个位置。

int L = S.len[fail[u][0]] + 1,R = len;
while(L < R){
	int mid = (L + R) >> 1;
	if(legal(u,mid))R = mid;
	else L = mid + 1;
}

最后\(cur_r-L+1\)就是我们要找的lim啦。

注意要对所有\(1{\leq}cur_r{\leq}len_t\)\(cur_r\)都求一遍lim,所以处理一个询问的时间是\(O(|T| \log |T|)\)

统计答案时还需要注意一个地方:我们求的答案是不能重复的。这和用SA求不同子串个数是一样的,height之后的部分是重复的,不能算进去。

for(int i = 1;i <= len_t;i++){
	int cur_r = sa[i];//其实应该叫pa,因为是前缀数组www
    //ans += lim[cur_r];-------------------------------- wrong!!!
	ans += min(lim[cur_r],curr - h[i]);//--------------- right √	
}

还有一些细节见代码吧。

总时间复杂度\(O(|S| \log |S|+\sum|T| \log |T| )\)

代码

P.S.对于这类\(\sum|T|=1e6\)之类的题目,每次询问千万不能随手memset到底,时间复杂度立刻就不对了

#include<bits/stdc++.h>

using namespace std;

#define rg register
#define ll long long
#define In inline

const int LT = 1e6;
const int LS = 5e5;
const int TN = 2 * LS * 20;

In int read(){
	int s = 0,ww = 1;
	char ch = getchar();
	while(ch < '0' || ch > '9'){if(ch == '-')ww = -1;ch = getchar();}
	while('0' <= ch && ch <= '9'){s = 10 * s + ch - '0';ch = getchar();}
	return s * ww;
}

In void write(ll x){
	if(x < 0)putchar('-'),x = -x;
	if(x > 9)write(x / 10);
	putchar('0' + x % 10);
}

int ls;

struct CMTree{ //按dfn排序,按编号询问的主席树
	int rt[2*LS+5],c[TN+5][2],num[TN+5];
	int cnt,rn;
	In void pushup(int u){
		num[u] = num[c[u][0]] + num[c[u][1]];
	}
	void ud(int u1,int u2,int l,int r,int x){
		if(l == r){
			num[u2]++;
			return;
		}
		int m = (l + r) >> 1;
		if(x <= m){
			int v = ++cnt;
			num[v] = num[c[u1][0]];
			c[u2][0] = v,c[u2][1] = c[u1][1];
			ud(c[u1][0],c[u2][0],l,m,x);
		}
		else{
			int v = ++cnt;
			num[v] = num[c[u1][1]];
			c[u2][1] = v,c[u2][0] = c[u1][0];
			ud(c[u1][1],c[u2][1],m + 1,r,x);
		}
		pushup(u2);
	}
	In void insert(int x){
		rt[++rn] = ++cnt;
		ud(rt[rn-1],rt[rn],0,ls,x);
	}
	int query(int u1,int u2,int l,int r,int ql,int qr){
		if(l == ql && r == qr)return num[u2] - num[u1];
		int m = (l + r) >> 1;
		if(qr <= m)return query(c[u1][0],c[u2][0],l,m,ql,qr);
		else if(ql > m)return query(c[u1][1],c[u2][1],m + 1,r,ql,qr);
		else return query(c[u1][0],c[u2][0],l,m,ql,m) + query(c[u1][1],c[u2][1],m + 1,r,m + 1,qr);
	}
	In int sum(int dfnl,int dfnr,int ql,int qr){ //查询dfn在[dfnl,dfnr],数值在[ql,qr]中的数u有多少个
		return query(rt[dfnl-1],rt[dfnr],0,ls,ql,qr);
	}
}T;

int loc[LT+5],lcp[LT+5]; //lcp[i]表示t[1..i]在SAM中匹配到的最长长度,loc表示最长匹配对应的SAM节点
char s[LS+5],t[LT+5];
int lt,l,r;
int dfn[2*LS+5],sz[2*LS+5],dn;

struct SAM{
	int nx[2*LS+5][26],fail[2*LS+5][21],len[2*LS+5],flag[2*LS+5];
	int cnt,last;
	void init(){
		fail[0][0] = -1;
	}
	void extend(char c,int n){
		int id = c - 'a';
		int cur = ++cnt,p;
		flag[cur] = n;
		len[cur] = len[last] + 1;
		for(p = last;p != -1 && !nx[p][id];p = fail[p][0])nx[p][id] = cur;
		if(p == -1)fail[cur][0] = 0;
		else{
			int q = nx[p][id];
			if(len[q] == len[p] + 1)fail[cur][0] = q;
			else{
				int clone = ++cnt;
				len[clone] = len[p] + 1;
				fail[clone][0] = fail[q][0];
				memcpy(nx[clone],nx[q],sizeof(nx[clone]));
				fail[cur][0] = fail[q][0] = clone;
				for(;p != -1 && nx[p][id] == q;p = fail[p][0])nx[p][id] = clone;
			}
		}
		last = cur;
	}
	struct edge{
		int next,des;
	}e[4*LS+5];
	int head[2*LS+5],Cnt;
	In void addedge(int a,int b){
		Cnt++;
		e[Cnt].des = b;
		e[Cnt].next = head[a];
		head[a] = Cnt;
	}
	void dfs(int u){
		dfn[u] = ++dn;
		sz[u] = 1;
		T.insert(flag[u]);
		for(rg int i = head[u];i;i = e[i].next){
			int v = e[i].des;
			dfs(v);
			sz[u] += sz[v];
		}
	}
	void build(){ 
		for(rg int i = 1;i <= cnt;i++)addedge(fail[i][0],i);		
		for(rg int j = 1;j <= 20;j++)	
			for(rg int i = 1;i <= cnt;i++)
				if(fail[i][j-1] == -1)fail[i][j] = -1;
				else fail[i][j] = fail[fail[i][j-1]][j-1];
		dfs(0);
	}
	void prepro(){
		for(rg int i = 1,u = 0,l = 0;i <= lt;i++){
			int id = t[i] - 'a';
			while(u && !nx[u][id])u = fail[u][0],l = len[u];
			if(nx[u][id])u = nx[u][id],l++;
			lcp[i] = l,loc[i] = u; 
		}
	}
}S;

int lim[LT+5]; //lim[i]表示以i作为结尾的、最短的合法子串的左端点

struct SA{ //其实是前缀数组
	int sa[LT+5],rk[LT+5],temp[LT+5],num[LT+5],h[LT+5],m;
	void qsort(){
		memset(num,0,sizeof(int) * (m+1)); //要是sizeof(num)就爆掉了,后面的几个也是类似
		for(rg int i = 1;i <= lt;i++)num[rk[i]]++;
		for(rg int i = 2;i <= m;i++)num[i] += num[i-1];
		for(rg int i = lt;i >= 1;i--)sa[num[rk[temp[i]]]--] = temp[i];
	}
	void calch(){
		int k = 0;
		for(rg int i = lt;i >= 1;i--){
			if(rk[i] == 1)h[rk[i]] = k = 0;
			else{
				if(k)k--;
				int j = sa[rk[i]-1];
				while(t[i-k] == t[j-k])k++;
				h[rk[i]] = k;
			}
		}
	}
	void init(){
		for(rg int i = 1;i <= lt;i++)rk[i] = t[i] - 'a' + 1;
		for(rg int i = 1;i <= lt;i++)temp[i] = i;
		m = 26;
		qsort();
		for(rg int n = 1;n <= lt;n <<= 1){
			int cnt = 0;
			for(rg int i = 1;i <= n;i++)temp[++cnt] = i;
			for(rg int i = 1;i <= lt;i++)if(sa[i] + n <= lt)temp[++cnt] = sa[i] + n;
			qsort();
			memcpy(temp,rk,sizeof(int) * (lt+1));
			cnt = rk[sa[1]] = 1;
			for(rg int i = 2;i <= lt;i++){
				if(temp[sa[i-1]] != temp[sa[i]] || temp[sa[i-1]-n] != temp[sa[i]-n])cnt++;
				rk[sa[i]] = cnt;
			}
			if(cnt == lt)break;
			m = cnt;
		}
		calch();
	}
	In bool legal(int u,int len){
		if(!len)return 0;
		if(len > r - l + 1)return 1;
		return !T.sum(dfn[u],dfn[u] + sz[u] - 1,l + len - 1,r);
	}
	int calclim(int curr){
		int u = loc[curr],len = lcp[curr];
		if(!len)return curr;
		if(!legal(u,len))return curr - len;
		for(rg int i = 20;i >= 0;i--){
			if(S.fail[u][i] == -1)continue;
			if(legal(S.fail[u][i],S.len[S.fail[u][i]]))u = S.fail[u][i],len = S.len[u];
		}
		int L = S.len[S.fail[u][0]] + 1,R = len;
		while(L < R){
			int mid = (L + R) >> 1;
			if(legal(u,mid))R = mid;
			else L = mid + 1;
		}
		return curr - L + 1;
	}
	void count(){
		for(rg int i = 1;i <= lt;i++)lim[i] = calclim(i);
		ll ans = 0;
		for(rg int i = 1;i <= lt;i++){
			int curr = sa[i];
			ans += (ll)min(lim[curr],curr - h[i]);	
		}
		write(ans);putchar('\n');
		memset(rk,0,sizeof(int) * (lt+1));
		memset(temp,0,sizeof(int) * (lt+1));
		memset(sa,0,sizeof(int) * (lt+1));
	}
}A;

int main(){
	scanf("%s",s + 1);
	ls = strlen(s + 1);
	S.init();
	for(rg int i = 1;i <= ls;i++)S.extend(s[i],i);
	S.build();
	int q = read();
	while(q--){
		scanf("%s",t + 1);
		lt = strlen(t + 1);
		l = read(),r = read();
		S.prepro();
		A.init();
		A.count();
	}
	return 0;
}
posted @ 2020-10-05 18:58  coder66  阅读(266)  评论(0编辑  收藏  举报