[USACO17DEC]Standing Out from the Herd P

XVII.[USACO17DEC]Standing Out from the Herd P

一个naive的思路就是将所有串拼一起然后后缀排序,找出所有连续的来自同一个串的后缀。考虑结合I.不同子串个数思考,则如果该区间是\([l,r]\)的话,它的贡献应该是\(\sum\limits_{l\leq i\leq r}len_{sa_i}-\sum\limits_{l\leq i\leq r+1}ht_i\),其中\(len_{sa_i}\)\(i\)位置的后缀的长度。注意因为这里要求在所有串中都没有出现过,所以\(ht_i\)的区间应是\([l,r+1]\),包括了前一个串和后一个串中的出现。

但是你会发现它有问题。假如你考虑极端情况,即\(l=r\)时,这是有重复部分的。

比如说下面举一个例子:

aaaaaa...

aaaab...

aaac...

这是后缀数组中三个连续的串。我们这时想要考虑中间那一个串在其他串中出现过的长度。显然,应该是len(aaaa)=4

但是,两边的\(ht\),一个是\(4\),一个是\(3\)。你要两个都算的话,就算重复了。

因此我们对于上面的式子,还应该减去一个\(\operatorname{LCP}(l-1,r+1)\),即正确的式子应该是\(\sum\limits_{l\leq i\leq r}len_{sa_i}-\sum\limits_{l\leq i\leq r+1}ht_i+\min\limits_{l\leq i\leq r+1}ht_i\),这样才不会重复计算。

代码:

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
int n,m,all,id[300100],len[300100];
int x[300100],y[300100],sa[300100],ht[300100],rk[300100],buc[300100],s[300100];
ll res[300100];
char str[300100];
bool mat(int a,int b,int k){
	if(y[a]!=y[b])return false;
	if((a+k<n)^(b+k<n))return false;
	if((a+k<n)&&(b+k<n))return y[a+k]==y[b+k];
	return true;
}
void SA(){
	for(int i=0;i<n;i++)buc[x[i]=s[i]]++;
	for(int i=1;i<=m;i++)buc[i]+=buc[i-1];
	for(int i=n-1;i>=0;i--)sa[--buc[x[i]]]=i;
	for(int k=1;k<n;k<<=1){
		int num=0;
		for(int i=n-k;i<n;i++)y[num++]=i;
		for(int i=0;i<n;i++)if(sa[i]>=k)y[num++]=sa[i]-k;
		for(int i=0;i<=m;i++)buc[i]=0;
		for(int i=0;i<n;i++)buc[x[y[i]]]++;
		for(int i=1;i<=m;i++)buc[i]+=buc[i-1];
		for(int i=n-1;i>=0;i--)sa[--buc[x[y[i]]]]=y[i];
		swap(x,y);
		x[sa[0]]=num=0;
		for(int i=1;i<n;i++)x[sa[i]]=mat(sa[i],sa[i-1],k)?num:++num;
		m=num;
	}
	for(int i=0;i<n;i++)rk[sa[i]]=i;
	for(int i=0,k=0;i<n;i++){
		if(!rk[i])continue;
		if(k)k--;
		int j=sa[rk[i]-1];
		while(i+k<n&&j+k<n&&s[i+k]==s[j+k])k++;
		ht[rk[i]]=k;
	}
}
int main(){
	scanf("%d",&all);
	for(int i=1;i<=all;i++){
		scanf("%s",str);
		m=strlen(str);
		for(int j=0;j<m;j++)len[n]=m-j,id[n]=i,s[n]=str[j]-'a'+1,n++;
		s[n++]=i+26;
	}
	m=all+26;
	SA();
//	for(int i=0;i<n;i++)printf("%d ",id[sa[i]]);puts("");
//	for(int i=0;i<n;i++)printf("%d ",len[sa[i]]);puts("");
//	for(int i=0;i<n;i++)printf("%d ",ht[i]);puts("");
	for(int i=0,LCP=0;id[sa[i]];i++){
		LCP=min(LCP,ht[i]),res[id[sa[i]]]+=len[sa[i]]-ht[i];
		if(id[sa[i]]!=id[sa[i-1]])res[id[sa[i-1]]]-=ht[i],res[id[sa[i-1]]]+=LCP,LCP=ht[i];
	}
	for(int i=1;i<=all;i++)printf("%lld\n",res[i]);
	return 0;
}

posted @ 2021-04-01 10:47  Troverld  阅读(51)  评论(0编辑  收藏  举报