BZOJ4650 [Noi2016]优秀的拆分
题意:
若干次询问一个串的每个子串的拆分成\(AABB\)式的方案总数。
知识点:
后缀数组,调和思想,差分,思维
解法:
\(a[],b[]\)分别为以当前开始结束的\(AA\)串的方案数,那么所有的\(b_i\times a_{i+1}\)的和就是答案。又因为就这么求有长度的限制不好求,所以化为关键点思想,从\(1\)到\(\frac{n}{2}\)分别设立关键点,就是每个\(len\)个位置一个点,那么相邻\(2\)个关键点对答案的贡献是可以求的。
具体地,对于相邻两个关键点\(i\)和\(j\),\(i\)和\(j\)分别向前的、向后的公共部分求出来(注意前面不超过\(len-1\),后面不超过\(len\),否则会重复计算贡献),如果\(lcs\)和\(lcp\)的和小于\(len\),那没有重复部分肯定无解。否则,把重复那段左移到此时最左边就是其中一段开头,最右边就是一段结尾。差分一下就可以了。
备注:
注意:1,一定要封装SA。2,此题有非常多的细节值得斟酌。3,这个blog的图解释得十分清楚:https://gypsophila.blog.luogu.org/solution-p1117
代码:
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
typedef long long ll;
const int maxn=300010;
ll ans,a[maxn],b[maxn];
int n,T,mlog[maxn];
struct SA
{
char s[maxn];
int m,sa[maxn],rnk[maxn],height[maxn],c[maxn],y[maxn],f[20][maxn];
void clear()
{
m=26;
memset(f,0x3f,sizeof(f));
for (register int i=0;i<=n+5;i++)
sa[i]=rnk[i]=height[i]=y[i]=0;
}
void rsort()
{
int i;
for (i=0;i<=m+1;i++)
c[i]=0;
for (i=1;i<=n;i++)
c[rnk[i]]++;
for (i=2;i<=m;i++)
c[i]+=c[i-1];
for (i=n;i>=1;i--)
sa[c[rnk[y[i]]]--]=y[i],y[i]=0;
}
void getsa()
{
int i,k,p=0;
for (i=1;i<=n;i++)
rnk[i]=s[i]-'a'+1,y[i]=i;
rsort();
for (k=1;k<n&&p<n;k<<=1,m=p)
{
p=0;
for (i=n-k+1;i<=n;i++)
y[++p]=i;
for (i=1;i<=n;i++)
if (sa[i]>k)
y[++p]=sa[i]-k;
rsort();
swap(rnk,y);
rnk[sa[p=1]]=1;
for (i=2;i<=n;i++)
rnk[sa[i]]=(y[sa[i-1]]==y[sa[i]]&&y[sa[i-1]+k]==y[sa[i]+k])?p:(++p);
}
}
void getheight()
{
int i,j,k=0;
for (i=1;i<=n;i++)
{
if (rnk[i]==1)
continue;
if (k)
k--;
j=sa[rnk[i]-1];
while (i+k<=n&&j+k<=n&&s[i+k]==s[j+k])
k++;
height[rnk[i]]=k;
}
}
void init()
{
int i,j;
for (i=1;i<=n;i++)
f[0][i]=height[i];
for (j=1;(1<<j)<=n;j++)
for (i=1;i+(1<<j)-1<=n;i++)
f[j][i]=min(f[j-1][i],f[j-1][i+(1<<(j-1))]);
}
int rmq(int l,int r)
{
int k=mlog[r-l+1];
return min(f[k][l],f[k][r-(1<<k)+1]);
}
}A,B;
int LCP(int i,int j)
{
i=A.rnk[i],j=A.rnk[j];
if (i>j)
swap(i,j);
i++;
return A.rmq(i,j);
}
int LCS(int i,int j)
{
i=n+1-i,j=n+1-j;
i=B.rnk[i],j=B.rnk[j];
if (i>j)
swap(i,j);
i++;
return B.rmq(i,j);
}
int main()
{
int i,j,len,lcs,lcp,Len;
scanf("%d",&T);
for (i=2;i<=300000;i++)
mlog[i]=mlog[i>>1]+1;
while (T--)
{
scanf("%s",A.s+1);
n=strlen(A.s+1);
A.clear(),B.clear();
ans=0;
for (i=0;i<=n+1;i++)
a[i]=b[i]=0;
for (i=1;i<=n;i++)
B.s[i]=A.s[n+1-i];
A.getsa();
B.getsa();
A.getheight();
B.getheight();
A.init();
B.init();
for (len=1;len<=(n>>1);len++)
for (i=len,j=i+len;j<=n;i+=len,j+=len)
if (A.s[i]==A.s[j])
{
lcp=min(len,LCP(i,j)),lcs=min(len-1,LCS(i-1,j-1));
if (lcp+lcs>=len)
{
Len=lcs+lcp-len+1;
a[i-lcs]++,a[i-lcs+Len]--;
b[j+lcp-Len]++,b[j+lcp]--;
}
}
for (i=1;i<=n;i++)
a[i]+=a[i-1],b[i]+=b[i-1];
for (i=1;i<=n-1;i++)
ans+=1ll*b[i]*a[i+1];
printf("%lld\n",ans);
}
return 0;
}