P1872 回文串计数(回文树)
题目描述
小a虽然是一名理科生,但他常常称自己是一名真正的文科生。不知为何,他对于背诵总有一种莫名其妙的热爱,这也促使他走向了以记忆量大而闻名的生物竞赛。然而,他很快发现这并不能满足他热爱背诵的心,但是作为一名强大的OIER,他找到了这么一个方法——背诵基因序列。然而这实在是太困难了,小啊感觉有些招架不住。不过他发现,如果他能事先知道这个序列里有多少对互不相交的回文串,他或许可以找到记忆的妙法。为了进一步验证这个想法,小a决定选取一个由小写字母构成的字符串SS来实验。由于互不相关的回文串实在过多,他很快就数晕了。不过他相信,在你的面前这个问题不过是小菜一碟。
(1)对于字符串SS,设其长度为Len,那么下文用Si表示SS中第i个字符(1<=i<=Len)。
(2)S[i,j]表示SS的一个子串,S[i,j]="SiSi+1Si+2...Sj-2Sj-1Sj",比如当SS为"abcgfd"时,S[2,5]="bcgf",S[1,5]="abcgf"。
(3)当一个串被称为一个回文串当且仅当将这个串反写后与原串相同,如“abcba”。
(4)考虑一个四元组(l,r,L,R),当S[l,r]和S[L,R]均为回文串时,且满足1<=l<=r<=L<=R<=Len时,我们称S[l,r]和S[L,R]为一对互不相交的回文串。即本题所求,也即为这种四元组的个数。两个四元组相同当且仅当对应的l,r,L,R都相同。
输入输出格式
输入格式:
输入仅一行,为字符串SS,保证全部由小写字母构成,由换行符标志结束。
50%的数据满足SS的长度不超过200;
100%的数据满足SS的长度不超过2000。
输出格式:
仅一行,为一个整数,表示互不相关的回文串的对数。
输入输出样例
输入样例#1: 复制
aaa
输出样例#1: 复制
5
说明
【样例数据说明】
SS="aaa",SS的任意一个字符串均为回文串,其中总计有5对互不相关的回文串:
(1,1,2,2),(1,1,2,3),(1,1,3,3),(1,2,3,3),(2,2,3,3)。
题解
在这里我们对回文树引入一个新的标记,dep
它表示的是当前这个节点为回文串时,以它结尾的串中包含的回文串的数量。
好,那么接下来的意思就很简单了。
我们只需要依次递推过去,先更新出原本顺序上每一位的回文串的数量,再更新出倒序上每一位的回文串的数量(其实倒序就是求尾部的回文串数量)。
这样我们每更新一位,就加上之前每一位的回文串的数量。再乘以下一位的按倒序处理出来的回文串的数量就ok了。
代码
#include<cstdio>
#include<cstring>
#include<cmath>
#include<iostream>
#include<algorithm>
#define ll long long
using namespace std;
int tot;
ll ans,p1[20001],p2[20001];
struct node{
int fail,ch[26],len,cnt,dep;
}t[200001];
char s[200001];
int read()
{
int x=0,w=1;char ch=getchar();
while(ch>'9'||ch<'0'){if(ch=='-')w=-1;ch=getchar();}
while(ch>='0'&&ch<='9')x=x*10+ch-'0',ch=getchar();
return x*w;
}
void solve1()
{
int len=strlen(s+1),k=0;s[0]='#';
t[0].fail=t[1].fail=1;t[1].len=-1;tot=1;
for(int i=1;i<=len;i++)
{
while(s[i-t[k].len-1]!=s[i])k=t[k].fail;
if(!t[k].ch[s[i]-'a']){
t[++tot].len=t[k].len+2;
int j=t[k].fail;
while(s[i-t[j].len-1]!=s[i])j=t[j].fail;
t[tot].fail=t[j].ch[s[i]-'a'];
t[k].ch[s[i]-'a']=tot;
t[tot].dep=t[t[tot].fail].dep+1;
}
k=t[k].ch[s[i]-'a'];
p1[i]=t[k].dep;
t[k].cnt++;
}
}
void solve2()
{
int len=strlen(s+1),k=0;s[0]='#';
t[0].fail=t[1].fail=1;t[1].len=-1;tot=1;
for(int i=1;i<=len;i++)
{
while(s[i-t[k].len-1]!=s[i])k=t[k].fail;
if(!t[k].ch[s[i]-'a']){
t[++tot].len=t[k].len+2;
int j=t[k].fail;
while(s[i-t[j].len-1]!=s[i])j=t[j].fail;
t[tot].fail=t[j].ch[s[i]-'a'];
t[k].ch[s[i]-'a']=tot;
t[tot].dep=t[t[tot].fail].dep+1;
}
k=t[k].ch[s[i]-'a'];
t[k].cnt++;
p2[len-i+1]=t[k].dep;
}
}
int main()
{
scanf("%s",s+1);
int len=strlen(s+1);
solve1();
reverse(s+1,s+len+1);
memset(t,0,sizeof(t));
solve2();
for(int i=1;i<=len;i++)p1[i]+=p1[i-1];
for(int i=1;i<=len;i++)ans+=p1[i]*p2[i+1];
printf("%lld",ans);
return 0;
}