P5410 【模板】扩展 KMP(Z 函数)
题目链接
P5410 【模板】扩展 KMP(Z 函数)
【模板】扩展 KMP(Z 函数)
题目描述
给定两个字符串 \(a,b\),你要求出两个数组:
- \(b\) 的 \(z\) 函数数组 \(z\),即 \(b\) 与 \(b\) 的每一个后缀的 LCP 长度。
- \(b\) 与 \(a\) 的每一个后缀的 LCP 长度数组 \(p\)。
对于一个长度为 \(n\) 的数组 \(a\),设其权值为 \(\operatorname{xor}_{i=1}^n i \times (a_i + 1)\)。
输入格式
两行两个字符串 \(a,b\)。
输出格式
第一行一个整数,表示 \(z\) 的权值。
第二行一个整数,表示 \(p\) 的权值。
样例 #1
样例输入 #1
aaaabaa
aaaaa
样例输出 #1
6
21
提示
样例解释:
\(z = \{5\ 4\ 3\ 2\ 1\}\),\(p = \{4\ 3\ 2\ 1\ 0\ 2\ 1\}\)。
数据范围:
对于第一个测试点,\(|a|,|b| \le 2 \times 10^3\)。
对于第二个测试点,\(|a|,|b| \le 2 \times 10^5\)。
对于 \(100\%\) 的数据,\(1 \le |a|,|b| \le 2 \times 10^7\),所有字符均为小写字母。
解题思路
z函数
z函数即exkmp
对于一个长度为 \(n\) 的字符串 \(s\),定义函数 \(z[i]\) 表示 \(s\) 和 \(s[i,n-1]\) 的的最长公共前缀长度,注意后面算法求z函数时特定 \(z[0]=0\)
算法流程:主要核心在求 \(z[i]\) 时要利用前面 \(z[0],z[1],\dots,z[i-1]\) 的信息,从 \(i=1\) 开始,对于每个 \(i\) 来说,都有一个匹配段 \([i,r]\),维护其中 \(r\) 最大的匹配段 \([l,r]\),如果 \(i\leq r\),则有 \(s[i,r]=s[i-l,r-l]\),则有 \(z[i]\geq min(z[i-l],r-i+1)\),如果 \(z[i-l]<r-i+1\),则 \(z[i]=z[i-l]\),否则令 \(z[i]=r-i+1\),暴力往后扩展;否则如果 \(i>r\),则暴力求解
这个算法只能求解 \(s\) 本身的z函数,\(\color{red}{如果需要求解s和t的每一个后缀的最长公共长度?}\)不妨令\(s=s+t\),从 \(len(s)\) 的位置开始计算z函数,另外需要注意 \(z[i]\leq len(原s)\)
复杂度分析:每次 while
时,\(r\) 都至少增加一次,即 while
的总的执行次数为 \(O(n)\),则:
- 时间复杂度:\(O(n)\)
代码
// Problem: P5410 【模板】扩展 KMP(Z 函数)
// Contest: Luogu
// URL: https://www.luogu.com.cn/problem/P5410
// Memory Limit: 500 MB
// Time Limit: 1000 ms
//
// Powered by CP Editor (https://cpeditor.org)
// %%%Skyqwq
#include <bits/stdc++.h>
// #define int long long
#define help {cin.tie(NULL); cout.tie(NULL);}
#define pb push_back
#define fi first
#define se second
#define mkp make_pair
using namespace std;
typedef long long LL;
typedef pair<int, int> PII;
typedef pair<LL, LL> PLL;
template <typename T> bool chkMax(T &x, T y) { return (y > x) ? x = y, 1 : 0; }
template <typename T> bool chkMin(T &x, T y) { return (y < x) ? x = y, 1 : 0; }
template <typename T> void inline read(T &x) {
int f = 1; x = 0; char s = getchar();
while (s < '0' || s > '9') { if (s == '-') f = -1; s = getchar(); }
while (s <= '9' && s >= '0') x = x * 10 + (s ^ 48), s = getchar();
x *= f;
}
const int N=4e7+5;
int n,m;
char s[N>>1],t[N];
vector<int> z_function(char *s)
{
int n=strlen(s);
vector<int> z(n);
for(int i=1,l=0,r=0;i<n;i++)
{
if(i<=r&&z[i-l]<r-i+1)z[i]=z[i-l];
else
{
z[i]=max(0,r-i+1);
while(i+z[i]<n&&s[z[i]]==s[i+z[i]])z[i]++;
}
if(i+z[i]-1>r)l=i,r=i+z[i]-1;
}
return z;
}
int main()
{
scanf("%s%s",s,t);
n=strlen(t);
strcat(t+n,t);
vector<int> z=z_function(t);
LL res=0;
for(int i=n;i<2*n;i++)res^=(LL)(i-n+1)*(min(z[i],n)+1);
printf("%lld\n",res);
t[n]='\0';
m=strlen(s);
strcat(t+n,s);
z=z_function(t);
res=0;
for(int i=n;i<n+m;i++)res^=(LL)(i-n+1)*(min(z[i],n)+1);
printf("%lld\n",res);
return 0;
}