2020牛客暑期多校训练营(第二场) A All with Pairs

思路:首先将所有后缀的hash值求出来,并对每个后缀出现的次数计数, 之后枚举每个串的前缀, 假设串 a 的存在对应后缀的前缀为 s1, s2, s3, |s1| < |s2| < |s3|,
假设 s3 对应串1,串2,串4 的后缀,首先 ans += cnt[s3], 然后看 s2, 若 s2 是 s3 的后缀, 则ans += cnt[s2] - cnt[s3], 因为 s3 是串1,串2,串4 的后缀,
s2 是 s3 的后缀, 那么 s2 也是串1,串2,串4 的后缀,然后 串1,串2,串4 只和 串 a 的 s3 计算贡献, 所以cnt[s2] 要减去 cnt[s3], 若 s2 不是 s3 的后缀,
则 ans += cnt[2], s1 同理。至于判断 s2 是否是 s3 的后缀, KMP 即可。

#include <cstdio>
#include <algorithm>
#include <queue>
#include <stack>
#include <string>
#include <string.h>
#include <map>
#include <iostream>
using namespace std;
typedef long long LL;
typedef unsigned long long ULL;
typedef pair<int, int> pii;
const int maxn = 1e6 + 50;
const ULL mod = 998244353;
const LL p = 233;
const int mod2 = 998244353;
int INF = 1e9;

#define fi first
#define se second
string s[maxn];
string T;
LL Hash[maxn];
LL pp[maxn];
map<ULL, LL> mmap;
LL cnt[maxn];
int Next[maxn];
int slen, tlen;
void getNext()
{
    int j, k;
    j = 0; k = -1; Next[0] = -1;
    while(j < tlen){
        if(k == -1 || T[j] == T[k]) Next[++j] = ++k;
        else k = Next[k];
    }
}

int main(int argc, char const *argv[])
{
    int n;
    scanf("%d", &n);
    pp[0] = 1;
    for(int i = 1; i < maxn; i++){
        pp[i] = 1LL * pp[i - 1] * p % mod;
    }
    for(int i = 1; i <= n; i++){
        cin >> s[i];
        int len = s[i].size();
        ULL base = 1;
        ULL hs = 0;
        for(int j = len - 1; j >= 0; --j){
            hs += base * s[i][j];
            mmap[hs]++;
            base *= p;
        }
    }

    LL ans = 0;
    for(int i = 1; i <= n; i++){
        T = s[i];
        tlen = T.size();
        ULL hs = 0;
        for(int j = 0; j < tlen; j++){
            cnt[j] = 0;
            hs = hs * p + s[i][j];
            if(!mmap.count(hs)) continue;
            cnt[j] += mmap[hs];
        }
        getNext();
        for(int j = 1; j <= tlen; j++){
            if(Next[j] == 0) continue;
            cnt[Next[j] - 1] -= cnt[j - 1];
        }
        for(int j = 0; j < tlen; j++){
            ans = (ans + 1LL * (j + 1) * (j + 1) % mod2 * cnt[j] % mod2) % mod2;
        }
    }

    cout << ans << endl;
    return 0;
}
posted @ 2020-07-14 17:35  从小学  阅读(159)  评论(0编辑  收藏  举报