【bzoj3796】Mushroom追妹纸 Kmp+二分+Hash

题目描述

给出字符串s1、s2、s3,找出一个字符串w,满足:
1、w是s1的子串;
2、w是s2的子串;
3、s3不是w的子串。
4、w的长度应尽可能大
求w的最大长度。

输入

输入有三行,第一行为一个字符串s1第二行为一个字符串s2, 
第三行为一个字符串s3。输入仅含小写字母,字符中间不含空格。

输出

输出仅有一行,为w的最大可能长度,如w不存在,则输出0。

样例输入

abcdef
abcf
bc

样例输出

2


题解

Kmp+二分+Hash

先使用Kmp处理出s3在s1、s2中出现的所有位置,那么w的选择不能包含这些位置。

然后答案显然满足二分性质,因此二分答案,判断是否有s1和s2的公共长度为mid的子串。

将s1的所有长度为mid且不包含s3的子串的Hash值处理出来,放到哈希表中,然后将s2的所有长度为mid且不包含s3的子串的Hash值放到哈希表里查询即可。

其中判断是否包含s3的子串可以使用前缀后缀和:对于当前的[l,r],如果不合法,相当于在r前面出现过的右端点加上l后面出现过的左端点大于总数目。

Hash的过程可以直接使用自然溢出。

时间复杂度 $O(n\log n)$

#include <cstdio>
#include <cstring>
#include <algorithm>
#define N 50010
#define M 30000000
using namespace std;
typedef unsigned long long ull;
ull base[N];
int n[3] , next[N] , sa[2][N] , sb[2][N];
char s[3][N];
struct data
{
	int head[M] , next[N] , tot;
	ull v[N];
	data() {tot = 0;}
	inline void insert(ull x)
	{
		if(!head[x % M]) head[x % M] = ++tot;
		else
		{
			int i;
			for(i = head[x % M] ; next[i] ; i = next[i]);
			next[i] = ++tot;
		}
		v[tot] = x;
	}
	inline bool count(ull x)
	{
		int i;
		for(i = head[x % M] ; i ; i = next[i])
			if(v[i] == x)
				return 1;
		return 0;
	}
	inline void clear()
	{
		int i;
		for(i = 1 ; i <= tot ; i ++ ) v[i] = next[i] = head[v[i] % M] = 0;
		tot = 0;
	}
}mp;
void kmp(int p)
{
	int i , j;
	for(i = j = 0 ; i < n[p] ; i ++ )
	{
		base[i + 1] = base[i] * 233;
		while(~j && s[p][i] != s[2][j]) j = next[j];
		if(++j == n[2]) sa[p][i - j + 1] ++ , sb[p][i] ++ , j = next[j];
	}
	for(i = n[p] - 2 ; ~i ; i -- ) sa[p][i] += sa[p][i + 1];
	for(i = 1 ; i < n[p] ; i ++ ) sb[p][i] += sb[p][i - 1];
}
bool judge(int mid)
{
	int i;
	ull v = 0;
	mp.clear();
	for(i = 0 ; i < mid - 1 ; i ++ ) v = v * 233 + s[0][i];
	for(i = mid - 1 ; i < n[0] ; i ++ )
	{
		v = v * 233 + s[0][i];
		if(sa[0][i - mid + 1] + sb[0][i] <= sa[0][0]) mp.insert(v);
		v -= s[0][i - mid + 1] * base[mid - 1];
	}
	v = 0;
	for(i = 0 ; i < mid - 1 ; i ++ ) v = v * 233 + s[1][i];
	for(i = mid - 1 ; i < n[1] ; i ++ )
	{
		v = v * 233 + s[1][i];
		if(sa[1][i - mid + 1] + sb[1][i] <= sa[1][0] && mp.count(v)) return 1;
		v -= s[1][i - mid + 1] * base[mid - 1];
	}
	return 0;
}
int main()
{
	int i , j , l , r , mid , ans = 0;
	for(i = 0 ; i < 3 ; i ++ ) scanf("%s" , s[i]) , n[i] = strlen(s[i]);
	next[0] = -1;
	for(i = 1 , j = -1 ; i <= n[2] ; i ++ )
	{
		while(~j && s[2][j] != s[2][i - 1]) j = next[j];
		next[i] = ++j;
	}
	base[0] = 1 , kmp(0) , kmp(1);
	l = 1 , r = min(n[0] , n[1]);
	while(l <= r)
	{
		mid = (l + r) >> 1;
		if(judge(mid)) ans = mid , l = mid + 1;
		else r = mid - 1;
	}
	printf("%d\n" , ans);
	return 0;
}

 

posted @ 2017-12-25 14:32  GXZlegend  阅读(503)  评论(0编辑  收藏  举报