NOI2018 你的名字(SAM + 可持久化线段树合并)

题目链接: https://www.luogu.com.cn/problem/P4770

SAM好题.

(I)首先我们考虑l = 1,r = |S|的情况怎么做

我们要求的是本质不同的子串str的数量,满足str是T的子串,且str不是$S_{l,r}$的子串

容易用补集转化成T本质不同的子串数减去S和T本质不同子串数

第一个问题很平凡,我们考虑第二个问题

我们对S,T分别建自动机,令T在S上面跑匹配,同时按着S的跑法在T自己上面跑匹配(因为T的每个子串都为$SAM_{T}$所接受,所以一定能跑)

对于每个前缀我们都可以求出它和S的最长公共后缀l,及在T上的节点,容易发现这个节点以上的长度<=l的都是本质不同的公共子串,因为可能算重所以先打标记然后Treedp统计(这也是为什么要在T上跑的原因,

因为在S上面跑,每次都要遍历S的parent tree时间复杂度不对)

(II)接下来才是难点,如果l,r任意怎么做

显然对于每个子串都建后缀自动机是不可能的,我们思考我们这个后缀自动机到底干了什么呢?

1.判断有没有tran(p,c)的转移边.

2.判断p这个节点的maxlen和minlen

我们可以发现,只要用线段树合并维护出endpos集合,就可以完成区间的上诉两个问题.

		int u = get(sam[p].ch[c],l + len,r);
		if(u){
			len++;
			p = sam[p].ch[c];
			x = sam[x].ch[c];
		}
		else{
			while(len != -1 && !get(sam[p].ch[c],l + len,r)){
				len--;
				if(len == sam[sam[p].fa].len)	p = sam[p].fa;
			}

  其中get(p,l,r)表示p这个节点的endpos集合在[l,r]范围内的最大值

设正在匹配的最长公共子串为s

我们发现我们原本要做的事情是判断s在p这个节点上能不能添上'c'这个字符,即判断 if(sam[p].ch[c] != 0),但是因为有区间限制我们应判断是否存在一个位置x可以接上s+'c',即在[l,r]区间内,是否存在一个endpos(x)满足x - len(s+'c')  + 1>= l,即x >= l + len(s+'c') - 1也即x >= len(s) + l,于是只要判断[l+len,r]区间内是否存在endpos集合的元素即可

注意我们若失配此时不应该直接跳fa,而应该先让len自减,要记住这个后缀自动机只是一个框架,是$S_{1,n}$而不是$S_{l,r}$的SAM.

有人可能会问:怎么暴力while怎么能过?

因为数据水?  其实这个时间复杂度是正确的,我们考虑势能分析法,容易发现每次while,len最多减少1,外面for循环每次最多增加1,所以单次匹配时间复杂度是O(|T|logn)的

有很多细节,看代码吧

/*NOI2018[你的名字]*/
#include<bits/stdc++.h>
using namespace std;
#define ll long long
int read(){
	char c = getchar();
	int x = 0;
	while(c < '0' || c > '9')		c = getchar();
	while(c >= '0' && c <= '9')		x = x * 10 + c - 48,c = getchar();
	return x;
}
const int N = 2e6 + 10;
struct SegmentTree{
	int lc,rc;
	int mx;
}t[N<<4];/*线段树维护endpos集合*/
int Rt[N],num,n;
void pushup(int p){
	t[p].mx = max(t[t[p].lc].mx,t[t[p].rc].mx);
}
void Insert(int &p,int l,int r,int pos){
	if(!p)	p = ++num;
	if(l == r){
		t[p].mx = max(t[p].mx,pos);
		return;
	}
	int mid = (l + r) >> 1;
	if(pos <= mid)	Insert(t[p].lc,l,mid,pos);
	else	Insert(t[p].rc,mid+1,r,pos);
	pushup(p);
}
int merge(int p,int q,int l,int r){
	if(!p || !q)	return p | q;
	int u = ++num;
	int mid = (l + r) >> 1;
	t[u].lc = merge(t[p].lc,t[q].lc,l,mid);
	t[u].rc = merge(t[p].rc,t[q].rc,mid + 1,r);	
	pushup(u);
	return u;
}
int query(int p,int l,int r,int a,int b){
	if(a <= l && b >= r)	return t[p].mx;
	int mid = (l + r) >> 1;
	int ans = 0;
	if(a <= mid)	ans = max(ans,query(t[p].lc,l,mid,a,b));
	if(b > mid)		ans = max(ans,query(t[p].rc,mid+1,r,a,b));
	return ans;
}
struct SAM{
	int ch[26],len,fa;
}sam[N<<1];
int lst = 1,cnt = 1;
void ins(int c,int rt){
	int p = lst,np = ++cnt;lst = np;
	sam[np].len = sam[p].len + 1;
	for(; !sam[p].ch[c]; p = sam[p].fa)		sam[p].ch[c] = np;
	if(!p)	sam[np].fa = rt;
	else{
		int q = sam[p].ch[c];
		if(sam[q].len == sam[p].len + 1)	sam[np].fa = q;
		else{
			int nq = ++cnt;
			sam[nq] = sam[q];
			sam[nq].len = sam[p].len + 1;
			sam[np].fa = sam[q].fa = nq;
			for(; sam[p].ch[c] == q; p = sam[p].fa)		sam[p].ch[c] = nq;
		}
	}
}
int head[N<<1];
int f[N<<1],tot;
struct Edge{
	int nxt,point;
}edge[N<<1];
void add_edge(int u,int v){
	edge[++tot].nxt = head[u];
	edge[tot].point = v;
	head[u] = tot;
}
char S[N],T[N];
void dfs(int u){
	for(int i = head[u]; i ; i = edge[i].nxt){
		int v = edge[i].point;
		dfs(v);
		f[u] = max(f[u],f[v]);
	}
	f[u] = min(f[u],sam[u].len);
}
void getpos(int u){
	for(int i = head[u]; i ; i = edge[i].nxt){
		int v = edge[i].point;
		getpos(v);
		Rt[u] = merge(Rt[u],Rt[v],1,n);
	}
}
bool valid(int u,int len){
	return len >= sam[sam[u].fa].len + 1 && len <= sam[u].len;
}
int get(int u,int l,int r){
	if(l > r || !u)	return 0;
	return query(Rt[u],1,n,l,r);
}
int getlen(int u,int l,int r){
	int x = get(u,l,r);
	return min(sam[u].len,x - l + 1);
} 
ll work(char *s,int rt,int l,int r){
	int m = strlen(s+1);
	int p = 1,len = 0,x = rt;
	for(int i = rt + 1; i <= cnt; ++i){
		add_edge(sam[i].fa,i);
	}
	for(int i = 1; i <= m; ++i){
		int c = s[i] - 'a';
		int u = get(sam[p].ch[c],l + len,r);
		if(u){
			len++;
			p = sam[p].ch[c];
			x = sam[x].ch[c];
		}
		else{
			while(len != -1 && !get(sam[p].ch[c],l + len,r)){
				len--;
				if(len == sam[sam[p].fa].len)	p = sam[p].fa;
			}
			if(len == -1){
				p = 1;
				len = 0;
				x = rt;
			}
			else{
				len++;
				p = sam[p].ch[c];	
				while((!sam[x].ch[c] || !valid(sam[x].ch[c],len)) && x)		x = sam[x].fa;
				if(!x)	x = rt;
				x = sam[x].ch[c];
			}
		}
//		cout<<i<<' '<<len<<endl;
		f[x] = max(f[x],len);
	}
	dfs(rt);
	ll ans = 0;
	for(int i = rt + 1; i <= cnt; ++i){/*!!!attention*/
		if(f[i] > sam[sam[i].fa].len){
//			assert(f[i] > sam[sam[i].fa].len);
			ans += f[i] - sam[sam[i].fa].len;
		}
	}
	for(int i = rt; i <= cnt; ++i)		f[i] = 0;
	return ans;
}
int main(){
	freopen("name.in","r",stdin);
	freopen("name.out","w",stdout);
	scanf("%s",S+1);
	n = strlen(S+1);
	for(int i = 1; i <= n; ++i){
		ins(S[i]-'a',1);
		Insert(Rt[lst],1,n,i);
	}
	for(int i = 2; i <= cnt; ++i){
		add_edge(sam[i].fa,i);
	}
	getpos(1);	
	int q = read();
	while(q--){
		scanf("%s",T+1);
		int l = read(),r = read();
		int m = strlen(T+1);
		int rt = ++cnt;
		lst = rt;
		for(int i = 1; i <= m; ++i){
			ins(T[i]-'a',rt);
		}
		ll ans = 0;
		for(int i = rt + 1; i <= cnt; ++i){
			ans += sam[i].len - sam[sam[i].fa].len;
		}
		ans -= work(T,rt,l,r);
		printf("%lld\n",ans);
	}
	return 0;
}

  

 

posted @ 2021-01-07 09:33  y_dove  阅读(194)  评论(0编辑  收藏  举报