2020牛客多校第二场 A题All with Pairs(字符串hash+KMP)

2020牛客多校第二场 A题All with Pairs(字符串hash+KMP)

All with Pairs

题意:给n个字符串,每个字符串与所有字符串比较一次,ans加上每次比较时,比较字符串前缀与被比较字符串的后缀的最长相同长度的平方,输出ans%998244353。

题解:思路源于一个大佬的博客,感谢大佬

这个题给我拿到手上,只能想出把所有字符串的next数组求出来,然后n方枚举,kmp比较时间约为1e6,但会在在枚举时超时。换个思路,现在题目保证总长度不超过1e6,所以我可以预处理出所有后缀,存入map,但直接存字符串会导致1e6方的空间而超空间,所以我们直接hash字符串,这样就解决了时间问题和空间问题了,现在就是要用字符串后缀hash的map算出答案,我们枚举所有前缀的hash值,加上长度*map中与hash值相等(相等意味着前缀等于后缀)的个数既可。这样算出的结果,还不正确,举个栗子:

字符串:ababa 与 aba 比较时,应该答案是3*3=9。

现在将后面字符串后缀拆开:aba,ba,a;

枚举前面字符串的前缀:

a,后缀中存在,ans+=1;

ab,不存在;

aba,存在,ans+=9;

abab,不存在;

ababa,不存在;

答案是10而不是9,因为我重复计算了,那现在只要考虑去重就行了,aba如何减去前面的a,现在aba与后缀加了答案,后缀也是aba,然后前缀的aba,第一个位置与第三个位置相等,就说明第一个位置的a也与后缀的第3个位置相等也是a,这是不是很明显了,就是kmp的next数组,所以计算答案时,第i位加上了x个答案,是不是第next[i]位就要减去x个答案就行了。

以下是代码部分

#include<iostream>
#include<vector>
#define ll long long
#define PII pair<ll,ll>
#define fr first
#define sc second
#define mp make_pair
using namespace std;
const ll mod=1e9+7;
const int maxsz=3e5+7;
template<typename key,typename val>
class hash_map{public:
  struct node{key u;val v;int next;};
  vector<node> e;
  int head[maxsz],nume,numk,id[maxsz];
  int geths(PII &u){
	int x=(1ll*u.fr*mod+u.sc)%maxsz;
	if(x<0) return x+maxsz;
	return x;
  }
  val& operator[](key u){
	int hs=geths(u);
	for(int i=head[hs];i;i=e[i].next)if(e[i].u==u) return e[i].v;
	if(!head[hs])id[++numk]=hs;
	if(++nume>=e.size())e.resize(nume<<1);
	return e[nume]=(node){u,0,head[hs]},head[hs]=nume,e[nume].v;
  }
  void clear(){
	for(int i=0;i<=numk;i++) head[id[i]]=0;
	numk=nume=0;
  }
};
hash_map<PII,int> ma;
ll n;
ll mo[2]={1000000007,998244353};
ll base[2]={43,47};
ll has[3][1000007],po[2][1000007];
ll nex[1000007];
ll vis[1000007];
string s[100007];
pair<ll,ll>tmp,qry;
void init(){
	po[0][0]=1,po[1][0]=1;
	for(int i=1;i<1000001;i++){
		po[0][i]=(po[0][i-1]*base[0])%mo[0];
		po[1][i]=(po[1][i-1]*base[1])%mo[1];
	}
}
void gethash(string s){
	int len=s.length(),x;
	for(int i=0;i<len;i++){
		x=s[i];
		has[0][i+1]=(has[0][i]*base[0]%mo[0]+x)%mo[0];
		has[1][i+1]=(has[1][i]*base[1]%mo[1]+x)%mo[1];
	}
}
ll getv(int l,int r,int k){
	return (has[k][r]-has[k][l-1]*po[k][r-l+1]%mo[k]+mo[k])%mo[k];
}
void kmp(string s){
	int len=s.length();
	int l=0,r=1;
	nex[l]=-1;
	while(r<len){
		if(l==-1||s[l]==s[r]){
			r++;l++;
			nex[r]=l;
		}
		else l=nex[l];
	}
}
int main(){
	init();
	scanf("%lld",&n);
	for(int i=1;i<=n;i++){
		cin>>s[i];
		gethash(s[i]);
		int len=s[i].length();
		for(int j=0;j<len;j++){
			ma[{getv(j+1,len,0),getv(j+1,len,1)}]++;
		}
	}
	ll ans=0;
	for(int i=1;i<=n;i++){
		int len=s[i].length();
		gethash(s[i]);
		kmp(s[i]);
		for(int j=0;j<=len;j++)vis[j]=0;
		for(int j=len;j>=1;j--){
			ans=(ans+(ma[{getv(1,j,0),getv(1,j,1)}]-vis[j]+mo[1])%mo[1]*j%mo[1]*j%mo[1])%mo[1];
			vis[nex[j]]=(vis[nex[j]]+ma[{getv(1,j,0),getv(1,j,1)}])%mo[1];
		}
	}
	printf("%lld\n",ans);
}

posted @ 2020-07-15 14:58  ccsu_madoka  阅读(178)  评论(0编辑  收藏  举报