jzoj3291. 【JSOI2013】快乐的JYY

Description

给定两个字符串A和B,表示JYY的两个朋友的名字。我们用A(i,j)表示A字符串中从第i个字母到第j个字母所组成的子串。同样的,我们也可以定义B(x,y)。
JYY发现两个朋友关系的紧密程度,等于同时满足如下条件的四元组(i,j,x,y)的个数:

  1. 1≤i≤j≤|A|
  2. 1≤x≤y≤|B|
    3)A(i,j)=B(x,y)
  3. A(i,j)为回文串
    这里|A|表示字符串A的长度。
    JYY希望你帮助他计算出这两个朋友之间关系的紧密程度。

Input

数据包行两行由大写字母组成的字符串A和B。

Output

输出文件包含一行一个整数,表示紧密程度,也就是满足要求的4元组个数。

Sample Input

输入1:
PUPPY
PUPPUP
输入2:
PUPPY
KFC

Sample Output

输出1:
17
输出2:
0

Data Constraint

对于10%的数据满足|A|,|B|≤50;
对于30%的数据满足|A|,|B| ≤ 1000;
对于40%的数据满足|A|≤1000;
对于60%的数据满足|A|≤4000;
对于额外20%的数据满足A=B;
对于100%的数据满足1 ≤|A|,|B| ≤ 50000。

赛时

这题还是有点意思的。
主要是部分分比较多吧,虽说我玩出的分似乎并不是特别多,打了个水法,滚粗。
其实赛时就往过回文树上面想,但是由于太久没打,都快忘光了。(上次打还是初二QWQ)
稍微看了看构树就懂了,确认过眼神,是道板子题。

题解

话说给的题解说的后缀自动机真的很神奇,但是回文树它还是更香些。
大概就是对于两个字符串都分别造出一颗回文树,然后把cnt数组给标记在树上。
接着就两个指针在两颗树上一起跑,能走就走,边走边把cnt乘在一起,加入答案。
还是很easy的,主要是帮助我复习了一下都快丢掉的回文树。

代码

#include <iostream>
#include <cstdio>
#include <cmath>
#include <cstring>
using namespace std;
const int maxn=50010;

char st[3][maxn],t[maxn];
int c[26],p[maxn];
int nex[3][maxn][26],fail[3][maxn],len[3][maxn],cnt[3][maxn],num[3][maxn],tot[3],s[3][maxn],n[3],las[3];
long long ans;

int insert(int x,int p)
{
	for (int i=0;i<=25;i++)
	{
		nex[p][tot[p]][i]=0;
	}
	cnt[p][tot[p]]=0;
	num[p][tot[p]]=0;
	len[p][tot[p]]=x;
	tot[p]++;
	return tot[p];
}

int get_fail(int x,int p)
{
	while (s[p][n[p]-len[p][x]-1]!=s[p][n[p]]) x=fail[p][x];
	return x;
}

void init()
{
	insert(0,1);
	insert(-1,1);
	insert(0,2);
	insert(-1,2);
	tot[1]=tot[2]=2;
	n[1]=n[2]=0;
	las[1]=las[2]=0;
	s[1][n[1]]=s[2][n[2]]=-1;
	fail[1][0]=fail[2][0]=1;
}

void add(int wz,int p)
{
	s[p][++n[p]]=wz;
	int cur=get_fail(las[p],p);
	if (nex[p][cur][wz]==0)
	{
		insert(len[p][cur]+2,p);
		int now=tot[p]-1;
		fail[p][now]=nex[p][get_fail(fail[p][cur],p)][wz];
		nex[p][cur][wz]=now;
		num[p][now]=num[p][fail[p][now]]+1;
	}
	las[p]=nex[p][cur][wz];
	cnt[p][las[p]]++;
}

void count()
{
	for (int p=1;p<=2;p++)
	for (int i=tot[p]-1;i>=0;i--)
	{
		cnt[p][fail[p][i]]+=cnt[p][i];
	}
}

void dfs(int x,int y)
{
	if (x+y>2)
	{
		long long op=cnt[1][x];
		long long oq=cnt[2][y];
		ans=ans+op*oq;
	}
	for (int i=0;i<=25;i++)
	{
		if (nex[1][x][i]>0 && nex[2][y][i]>0)
		{
			dfs(nex[1][x][i],nex[2][y][i]);
		}
	}
}

int main()
{
	scanf("%s",st[1]+1);
	scanf("%s",st[2]+1);
	init();
	for (int p=1;p<=2;p++)
	{
		for (int i=1;i<=strlen(st[p]+1);i++)
		{
			add(st[p][i]-'A',p);
		}
	}
	count();
	dfs(0,0);
	dfs(1,1);
	printf("%lld\n",ans);
}
posted @ 2020-08-12 20:16  RainbowCrown  阅读(113)  评论(0编辑  收藏  举报