zoj 2587 kmp的应用

Description

Long long ago, there was a coder named Marlon. One day he picked two string on the street. A problem suddenly crash his brain...

Let Si..j denote the i-th character to the j-th character of string S.

Given two strings S and T. Return the amount of tetrad (a,b,c,d) which satisfy Sa..b + Sc..d = T , ab and cd.

The operator + means concate the two strings into one.

Input

The first line of the data is an integer Tc. Following Tc test cases, each contains two line. The first line is S. The second line is T. The length of S and T are both in range [1,100000]. There are only letters in string S and T.

Output

For each test cases, output a line for the result.

Sample Input

1
aaabbb
ab

Sample Output

9

 

给你两个字符串S,T,求整数对(a,b,c,d)满足Sa..b + Sc..d = T的个数。

思路:先求能组成长度为i的T前缀的S子串个数a1[i],再计算能组成长度为i的T后缀的S子串个数a2[i],然后累计就可以了。关于怎么求a1[],a2[]这两个数组,可以把T和S两个字符串合并,求得next数组,再计算a1数组。同理,把T的翻转与S的翻转合并,计算next和a2数组。

注意:为了避免字符串合并之后T对S产生影响,可以在合并时在中间添加一个无关字符,例如"*",相当于起到隔断作用。

比如

$$T: aaa  $$

$$ S: aaaaaaa$$ 

把它们合并之后:$$aaa*aaaaaaa$$

如果是合并成$aaaaaaaaaa$的话,那么组成T前缀的S子串可能包含T字符串中的字符。

 

#include <iostream>
#include <algorithm>
#include <cstring>
#include <cstdio>
using namespace std;
const int MAXN=1000100 ;
int jump[MAXN];
void PrintJump(char *str)
{
    jump[0] = -1;
    jump[1] = 0;
    int len = strlen(str+1);
    int j=1,k;
    while(j<len)
    {
        k = jump[j];
        while(k!=-1&&str[k+1]!=str[j+1]) k = jump[k];
        jump[j+1] = k+1;
        j++;
    }
}

char s1[MAXN],s2[MAXN];
char s[MAXN*2];
long long a1[MAXN],a2[MAXN];
int main()
{
    int T;
    scanf("%d",&T);
    while(T--)
    {
        memset(a1,0,sizeof(a1));
        memset(a2,0,sizeof(a2));
        scanf("%s%s",s1+1,s2);
        s1[0]='*';
        int l1=strlen(s1),l2=strlen(s2);
        for(int i=0;i<l2;i++) s[i+1]=s2[i];
        for(int i=0;i<=l1;i++) s[i+l2+1]=s1[i];
        int l=l1+l2;
        //puts(s+1);
        PrintJump(s);
        for(int i=l2+2;i<=l;i++) {
            int t=jump[i];
            while(t>l2) t=jump[t];
            jump[i] = t;
            a1[t]++;
        }
        for(int i=l2;i>1;i--) {
            a1[jump[i]]+=a1[i];
        }

        for(int i=0;i<l2;i++) s[i+1]=s2[l2-i-1];
        for(int i=1;i<l1;i++) s[i+l2+1]=s1[l1-i];
        PrintJump(s);
        //puts(s+1);
        for(int i=l2+2;i<=l;i++) {
            int t=jump[i];
            while(t>l2) t=jump[t];
            jump[i] = t;
            a2[t]++;
        }
        for(int i=l2;i>1;i--) {
            a2[jump[i]]+=a2[i];
        }

        long long ans=0;
        for(int i=1;i<l2;i++) {
            ans+=a1[i]*a2[l2-i];
        }
        printf("%lld\n",ans);
    }
    return 0;
}
posted @ 2016-03-24 16:12  ZhMZ  阅读(305)  评论(0编辑  收藏  举报