P3181 [HAOI2016]找相同字符

思路

广义SAM
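把两个字符串建成同一个广义SAM,并对每个节点分别统计它在两个串中的endpos大小;答案为所有非根节点 $v$ 上 $(\mathrm{maxlen}(v)-\mathrm{minlen}(v)+1)\cdot|\mathrm{endpos}_1(v)|\cdot|\mathrm{endpos}_2(v)|$ 之和(节点中每个子串都贡献一次两串出现次数的乘积)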
记得开long long(例如两串都为 $a^n$ 时答案为 $\sum_{L=1}^{n}(n-L+1)^2 = n(n+1)(2n+1)/6$,远超int范围)

代码

#include <cstdio>
#include <algorithm>
#include <cstring>
#include <queue>
using namespace std;
// Capacity for SAM states. A generalized SAM over a total input length L has
// fewer than 2L states; presumably each string is at most 200000 characters
// (problem limit, TODO confirm), so 800100 leaves headroom. Index 0 is the null
// sentinel and state 1 is the root.
const int MAXN = 800100;
// endpos[v][k]: occurrence counter of state v for string k (k = 0 or 1).
// add_len marks prefix states; get_sz then sums counts up the suffix-link tree.
int endpos[MAXN][2];
// trans[v][c]: transition of state v on letter 'a'+c; 0 means "no transition".
int trans[MAXN][26];
int suflink[MAXN];  // suffix link; the root's link stays 0
int maxlen[MAXN];   // length of the longest string represented by the state
int minlen[MAXN];   // length of the shortest string represented by the state
int Nodecnt;        // number of allocated states (valid indices 1..Nodecnt)
int in[MAXN];       // per-state child count in the suffix-link tree (get_sz)
int n;              // length of the string currently being inserted
char s[MAXN];       // input buffer, read 1-indexed at s+1
// Allocate the next SAM state and initialise its fields.
//   _maxlen  : length of the longest string in the state
//   _minlen  : length of the shortest string (callers in this file pass 0 and
//              fix minlen afterwards)
//   _trans   : if non-NULL, a 26-entry transition row to copy (used when
//              cloning); if NULL the row is left as-is, which is all zeros
//              because indices are only ever handed out once
//   _suflink : suffix link of the new state
// Returns the index of the new state.
int New_state(int _maxlen,int _minlen,int *_trans,int _suflink){
	const int id=++Nodecnt;
	maxlen[id]=_maxlen;
	minlen[id]=_minlen;
	suflink[id]=_suflink;
	if(_trans!=NULL)
		memcpy(trans[id],_trans,sizeof(trans[id]));
	return id;
}
// Clone state v so that the clone's longest string has length maxlen[u]+1.
// v keeps its transitions but now hangs below the clone in the suffix-link
// tree. Every c-transition on u's suffix-link chain that pointed at v is
// redirected to the clone. minlen of both v and the clone is recomputed.
// Returns the clone; the caller does any remaining endpos/suflink bookkeeping.
static int split_state(int u,int c,int v){
	const int clone=New_state(maxlen[u]+1,0,trans[v],suflink[v]);
	suflink[v]=clone;
	minlen[v]=maxlen[clone]+1;
	for(;u&&trans[u][c]==v;u=suflink[u])
		trans[u][c]=clone;
	minlen[clone]=maxlen[suflink[clone]]+1;
	return clone;
}

// Extend the generalized SAM from state u (the state of the previous prefix,
// or the root 1 at the start of a string) with letter c (0..25).
// inq selects which string's counter (0 or 1) is bumped for the new prefix.
// Returns the state for the extended prefix; main uses it as the next `last`.
int add_len(int u,int c,int inq){
	int v=trans[u][c];
	if(v){
		// The extended prefix is already reachable (it occurred in an earlier
		// string). Reuse v if it is "solid", otherwise split off an exact-length clone.
		if(maxlen[v]!=maxlen[u]+1)
			v=split_state(u,c,v);
		endpos[v][inq]++;
		return v;
	}

	// Ordinary SAM extension: create a fresh state for the new prefix.
	const int cur=New_state(maxlen[u]+1,0,NULL,0);
	endpos[cur][inq]++;
	for(;u&&!trans[u][c];u=suflink[u])
		trans[u][c]=cur;
	if(!u){
		// Walked past the root: the only proper suffix is the empty string.
		suflink[cur]=1;
		minlen[cur]=1;
		return cur;
	}
	v=trans[u][c];
	if(maxlen[v]!=maxlen[u]+1)
		v=split_state(u,c,v);
	suflink[cur]=v;
	minlen[cur]=maxlen[v]+1;
	return cur;
}
queue<int> q;
// Turn the per-prefix marks in endpos[][] into full endpos sizes for both
// strings. Each state's counts are added into its suffix-link parent. States
// are processed children-before-parents using Kahn's algorithm on the
// suffix-link tree; in[v] is the number of suffix-link children of v.
// Only real states 1..Nodecnt are visited. The null sentinel 0 (the root's
// suffix link) is never enqueued or written, so in[0] and endpos[0] stay clean.
void get_sz(void){
	for(int i=2;i<=Nodecnt;i++)
		in[suflink[i]]++;
	for(int i=1;i<=Nodecnt;i++)
		if(!in[i])
			q.push(i);
	while(!q.empty()){
		int x=q.front();
		q.pop();
		int p=suflink[x];
		if(!p) // x is the root: nothing above it to propagate into
			continue;
		endpos[p][0]+=endpos[x][0];
		endpos[p][1]+=endpos[x][1];
		if(--in[p]==0)
			q.push(p);
	}
}
long long ans=0;
// Read two strings, insert both into one generalized SAM (string 0, then
// string 1), and print the number of pairs of equal substrings: the sum over
// non-root states of (#distinct lengths in the state) * |endpos0| * |endpos1|.
// Assumes input consists only of 'a'..'z' (not checked here).
int main(){
	Nodecnt=1; // state 1 is the root
	for(int which=0;which<2;which++){
		scanf("%s",s+1);
		n=strlen(s+1);
		int last=1; // every string starts from the root
		for(int i=1;i<=n;i++)
			last=add_len(last,s[i]-'a',which);
	}
	get_sz();
	for(int v=2;v<=Nodecnt;v++){
		const long long lengths=maxlen[v]-minlen[v]+1;
		const long long pairs=1LL*endpos[v][0]*endpos[v][1];
		ans+=lengths*pairs;
	}
	printf("%lld\n",ans);
	return 0;
}
posted @ 2019-04-01 19:11  dreagonm  阅读(167)  评论(0编辑  收藏  举报