【HAOI2016】 找相同字符-后缀数组+单调栈

【HAOI2016】 找相同字符

子串之类的问题,容易想到后缀数组。最后的问题是,在 \(A,B\) 中各找一个子串,有多少种情况这两个子串相同。

我们分别对 \(A\),\(B\),\(A+B\) 三个串求其后缀数组和 \(height\) 数组。我们可以求出在各自串中相同的子串数量,然后用容斥原理算出答案即可。

那么如何求各自串中相同的子串数量呢?容易想到单调栈。我们令 \(f_i\)\([sa_i,\cdots,n]\) 子串与其后缀有多少个前缀相等。结合单调栈,容易算出答案。

//Don't act like a loser.
//This code is written by huayucaiji
//You can only use the code for studying or finding mistakes
//Or,you'll be punished by Sakyamuni!!!
//#pragma GCC optimize("Ofast","-funroll-loops","-fdelete-null-pointer-checks")
//#pragma GCC target("ssse3","sse3","sse2","sse","avx2","avx")
#include<bits/stdc++.h>
#define int long long
using namespace std;

int read() {
	char ch=getchar();
	int f=1,x=0;
	while(ch<'0'||ch>'9') {
		if(ch=='-')
			f=-1;
		ch=getchar();
	}
	while(ch>='0'&&ch<='9') {
		x=x*10+ch-'0';
		ch=getchar();
	}
	return f*x;
}

const int MAXN=4e5+10;

int sa[MAXN],rnk[MAXN],height[MAXN],cnt[MAXN],f[MAXN];
char a[MAXN>>1],b[MAXN>>1],ab[MAXN];
priority_queue<pair<int,int> > q;

struct strpr {
	int x,y,id;
	
	bool operator <(const strpr q)const {
		if(x!=q.x) {
			return x<q.x;
		}
		return y<q.y;
	}
}p[MAXN];

void combine() {
	int n=strlen(a+1);
	int m=strlen(b+1);
	for(int i=1;i<=n;i++) {
		ab[i]=a[i];
	}
	ab[n+1]='$';
	for(int i=1;i<=m;i++) {
		ab[i+n+1]=b[i];
	}
}
void get_sa(char s[]) {
	int n=strlen(s+1);
	for(int i=1;i<=n;i++) {
		rnk[i]=s[i];
	}
	for(int l=1;l<n;l<<=1) {
		for(int i=1;i<=n;i++) {
			p[i].x=rnk[i];
			p[i].y=(i+l<=n? rnk[i+l]:0);
			p[i].id=i;
		}
		sort(p+1,p+n+1);
		int m=0;
		for(int i=1;i<=n;i++) {
			if(p[i].x!=p[i-1].x||p[i].y!=p[i-1].y) {
				m++;
			}
			rnk[p[i].id]=m;
		}
	}
	for(int i=1;i<=n;i++) {
		sa[rnk[i]]=i;
	}
	return ;
}
void get_height(char s[]) {
	int h=0;
	int n=strlen(s+1);
	for(int i=1;i<=n;i++) {
		if(h) {
			h--;
		}
		if(rnk[i]==1) {
			continue;
		}
		
		int p=i+h;
		int q=sa[rnk[i]-1]+h;
		
		while(p<=n&&q<=n&&s[p]==s[q]) {
			p++;
			q++;
			h++;
		}
		height[rnk[i]]=h;
	}
}

int calc(char s[]) {
	int n=strlen(s+1);
	fill(sa,sa+n+1,0);
	fill(rnk,rnk+n+1,0);
	fill(height,height+n+1,0);
	fill(f,f+n+1,0);
	
	get_sa(s);
	get_height(s);
	
	stack<int> stk;
	height[n+1]=0;
	int sum=0;
	for(int i=2;i<=n;++i)
	{
		while(stk.size()&&height[stk.top()]>=height[i])
			stk.pop();
		if(stk.empty())
			f[i]=height[i]*(i-1);
		else 
			f[i]=f[stk.top()]+height[i]*(i-stk.top());
		stk.push(i);
		sum+=f[i];
	}
	return sum;
}

signed main() {
	//freopen(".in","r",stdin);
	//freopen(".out","w",stdout);
	
	scanf("%s%s",a+1,b+1);
	combine();
	
	printf("%lld\n",calc(ab)-calc(a)-calc(b));

	//fclose(stdin);
	//fclose(stdout);
	return 0;
}
/*
abba
bbba
12
*/
posted @ 2021-02-04 17:34  huayucaiji  阅读(48)  评论(0编辑  收藏  举报