洛谷P3435 [POI2006]OKR-Periods of Words
题目
这题意不是一般人能读懂的,为了读懂题目,我还特意去翻了题解[手动笑哭]
题目大意:
给定一个字符串s
对于\(s\)的每一个前缀子串\(s1\),规定一个字符串\(Q\),\(Q\)满足:\(Q\)是\(s1\)的前缀子串且Q不等于\(s1\)且\(s1\)是字符串\(Q+Q\)的前缀.设\(siz\)为所有满足条件的\(Q\)中\(Q\)的最大长度(注意这里仅仅针对\(s1\)而不是\(s\),即一个\(siz\)的值对应一个\(s1\))
求出所有\(siz\)的和
不要被这句话误导了:
求给定字符串所有前缀的最大周期长度之和
正确断句:求给定字符串 所有/前缀的最大周期长度/之和
我就想了半天:既然是"最大周期长度",那不是唯一的吗?为什么还要求和呢?
思路
其实这题要AC并不难(看通过率就知道)
看图
要满足\(Q\)是\(s1\)的前缀,则\(Q\)的\(1\)~\(5\)位和\(s1\)的1~5位是一样的,又因为\(s1\)是\(Q+Q\)的前缀,所以又要满足\(s1\)的6~8位和\(Q+Q\)的6~8位一样,即\(s1\)的6~8位和Q的1~3位相等,回到\(s1\),标蓝色的两个位置相等.
回顾下KMP中\(next\)数组的定义:next[i]
表示对于某个字符串a,"a中长度为next[i]
的前缀子串"与"a中以第i为结尾,长度为next[i]
的非前缀子串"相等,且next[i]
取最大值
是不是悟到了什么,是不是感觉这题和\(next\)数组冥冥之中有某种相似之处?
但是,这仅仅只是开始
按照题目的意思,我们要让\(Q\)的长度最大,也就是图中蓝色部分长度最小,但是\(next\)中存的是蓝色部分的最大值,显然,两者相违背,难道我们要改造\(next\)数组吗?明显不行,若\(next\)存储的改为最小值,则原来求\(next\)的方法行不通.考虑换一种思路(一定要对KMP中\(next\)的求法理解透彻,不然下面看不懂,不行的复习一下),我们知道对于next[i],next[next[i-1]],next[next[next[i]]]...
都能满足"前缀等于以\(i\)结尾的子串"这个条件,且越往后,值越小,所以,我们的目标就定在上面序列中从后往前第一个不为0的\(next\)值
极端条件下,暴力跑可以去到\(O(n^2)\),理论上会超时(我没试过)
两种优化:
- 记忆化,时间效率应该是O(n)这里不详细讲,可以去到洛谷题解查看
- 倍增(我第一时间想到并AC的做法):
我们将j=next[j]
这一语句称作"j跳了一次"(感觉怪怪的),将next拓展为2维,next[i][k]
表示结尾为i,j跳了2^k的前缀字符长度(也就是next[i][0]
等价于原来的next[i]
)
借助倍增LCA的思想(没学没关系,现学现用),这里不做赘述,上代码
int tmp = i;
for(rr int j = siz[i] ; j >= 0 ; --j)//siz[i]是next[i][j]中第一个为0的小标j,注意倒序枚举
if(next[tmp][j] != 0)//如果不为0则跳
tmp = next[tmp][j];
倍增方法在字符串长度去到\(10^6\)时是非常危险的,带个\(\log\)理论是\(2\cdot 10^7\)左右,常数再大那么一丢丢就TLE了,还好数据比较水,但是作为倍增和KMP的练习做一下也是不错的
最后,记得开longlong(不然我就一次AC了)
完整代码
#include <iostream>
#include <cmath>
#include <cstdio>
#define nn 1000010
#define rr register
#define ll long long
using namespace std;
int next[nn][30] ;
int siz[nn];
char s[nn];
int n;
int main() {
// freopen("P3435_3.in" , "r" , stdin);
cin >> n;
do
s[1] = getchar();
while(s[1] < 'a' || s[1] > 'z');
for(rr int i = 2 ; i <= n ; i++)
s[i] = getchar();
next[1][0] = 0;
for(rr int i = 2 , j = 0 ; i <= n ; i++) {
while(j != 0 && s[i] != s[j + 1])
j = next[j][0];
if(s[j + 1] == s[i])
++j;
next[i][0] = j;
}
rr int k = log(n) / log(2) + 1;
for(rr int j = 1 ; j <= k ; j++)
for(rr int i = 1 ; i <= n ; i++) {
next[i][j] = next[next[i][j - 1]][j - 1];
if(next[i][j] == 0)
siz[i] = j;
}
ll ans = 0;
for(rr int i = 1 ; i <= n ; ++i) {
int tmp = i;
for(rr int j = siz[i] ; j >= 0 ; --j)
if(next[tmp][j] != 0)
tmp = next[tmp][j];
if(2 * (i - tmp) >= i && tmp != i)
ans += (ll)i - tmp;
}
cout << ans;
return 0;
}