HDU 6153 KMP 模式串后缀在目标串中的出现次数
HDU6153:http://acm.hdu.edu.cn/showproblem.php?pid=6153
Orz 果然字符串很神奇……(神奇到人哭出来啊TAT)果然算法理解的不够透彻的话,是无法灵活运用的QAQ
其实只要简单的改一点点KMP的地方……
KMP的next[i]存储的是模式串下标0至i处相同的前缀和后缀的最长长度,每次与目标串匹配的时候,会略过前后缀相同的部分以减少回溯,提高效率。
由于每次匹配都是从前向后匹配的,因此对于模式串每次的匹配子串都是[0……i],是前缀,但是根据题目的要求,我们要找的是s2的子串[i……n],是后缀。
其实只要把模式串reverse一下,就可以把后缀转化为前缀,然后在进行处理就方便了。
对kmp理解的比较透彻的就不要听我瞎扯了……写的比较……emmmmm……不过还是很欢迎给窝这个连门槛都摸不到的小白提提建议的(^∇^*)
可以用一个ll数组记录长度为i的前缀(也就是转化前的后缀)出现的次数(这里数组用crt表示)。(用ll……不要用int……会TLE)
因为next数组本身记录的就是相同子串的长度,因此可以在每个循环内部对匹配过的字符的下标进行计数
先码个代码:
while (i < len1) { if(j==-1||s1[i]==s2[j]) { i++; j++; } else j = nextt[j]; num[j]++; if (j == len2) j = nextt[j]; }
下面拿题目的例子来解释一下出现次数的计数:
b a b a b a b a
next -1 0 0 1 0 0 0 0
a ——模式串目标串匹配过程
crt 0 0 0 0 匹配失败,j=next[j]=-1;这里crt数组不更新
crt 1 0 0 0
a
crt 1 1 0 0
a b
crt 1 1 1
a b a ——j==s2长度,但由于我们要在整个目标串中匹配,这里不能跳出,要继续匹配
crt 1 1 1 1 因此又加了 if(j==len2) j=next[j];
a b
crt 1 1 2 1
a b a
crt 1 1 2 2
…………
a b a
crt 1 1 3 3
可以发现,把crt计数放在每次j处理过之后,下标刚好就为目前子串的长度,所以就可以安心的记录啦;
前面有提到,kmp的算法为了减少回溯,省略了相同部分的匹配,因此我们每次记录的只是当前相同前缀长度的个数,从上面的例子不难发现,
除了第一次‘a'以外的其他的'a'的匹配都被省略了,其他都是从’ab'开始匹配的,每次匹配失败再次匹配时,省略了next[i]之前部分的计数,因此还需要进一步对计数数组进行处理。
for (int i = len2; i >0; i--) crt[nextt[i]] += crt[i];
既然少了加回来就好了,匹配失败以后再次匹配时,都是从next[i]开始匹配的,因此,crt[next[i]]未计数的部分=crt[i],循环把len~1的部分加回来就OK了
(记得是从len倒着往前加)
下面看完整的代码。
#include<iostream> #include<algorithm> using namespace std; const int MAX = 1e6; const int mod = 1e9 + 7; int nextt[MAX + 10]; char s1[MAX + 10]; char s2[MAX + 10]; long long num[MAX + 10]; int t,len1,len2; void getnext() { int p=0, k = -1; nextt[0] = -1; while (p < len2) { if (k == -1 || s2[p] == s2[k]) { p++;k++; nextt[p] =k; } else k = nextt[k]; } } void kmp() { getnext(); int i = 0, j = 0; while (i < len1) { if(j==-1||s1[i]==s2[j]) {i++; j++;} else j = nextt[j]; num[j]++; if (j == len2) j = nextt[j]; } } int main() { cin >> t; while (t--) { memset(num,0, sizeof(num)); cin >> s1; cin >> s2; len1 = strlen(s1); len2 = strlen(s2); reverse(s1, s1 + len1); reverse(s2, s2 + len2); kmp(); long long ans = 0; for (int i = len2; i > 0; i--) { num[nextt[i]] += num[i]; ans = (ans + i*num[i]) % mod; } printf("%d\n", ans); } return 0; }
有什么不对的地方或者解释的不好的地方还请大家指出来( ̄▽ ̄)"