2020牛客暑期多校训练营(第二场) A All with Pairs
思路:首先将所有后缀的hash值求出来,并对每个后缀出现的次数计数, 之后枚举每个串的前缀, 假设串 a 的存在对应后缀的前缀为 s1, s2, s3, |s1| < |s2| < |s3|,
假设 s3 对应串1,串2,串4 的后缀,首先 ans += cnt[s3], 然后看 s2, 若 s2 是 s3 的后缀, 则ans += cnt[s2] - cnt[s3], 因为 s3 是串1,串2,串4 的后缀,
s2 是 s3 的后缀, 那么 s2 也是串1,串2,串4 的后缀,然后 串1,串2,串4 只和 串 a 的 s3 计算贡献, 所以cnt[s2] 要减去 cnt[s3], 若 s2 不是 s3 的后缀,
则 ans += cnt[2], s1 同理。至于判断 s2 是否是 s3 的后缀, KMP 即可。
#include <cstdio>
#include <algorithm>
#include <queue>
#include <stack>
#include <string>
#include <string.h>
#include <map>
#include <iostream>
using namespace std;
typedef long long LL;
typedef unsigned long long ULL;
typedef pair<int, int> pii;
const int maxn = 1e6 + 50;
const ULL mod = 998244353;
const LL p = 233;
const int mod2 = 998244353;
int INF = 1e9;
#define fi first
#define se second
string s[maxn];
string T;
LL Hash[maxn];
LL pp[maxn];
map<ULL, LL> mmap;
LL cnt[maxn];
int Next[maxn];
int slen, tlen;
void getNext()
{
int j, k;
j = 0; k = -1; Next[0] = -1;
while(j < tlen){
if(k == -1 || T[j] == T[k]) Next[++j] = ++k;
else k = Next[k];
}
}
int main(int argc, char const *argv[])
{
int n;
scanf("%d", &n);
pp[0] = 1;
for(int i = 1; i < maxn; i++){
pp[i] = 1LL * pp[i - 1] * p % mod;
}
for(int i = 1; i <= n; i++){
cin >> s[i];
int len = s[i].size();
ULL base = 1;
ULL hs = 0;
for(int j = len - 1; j >= 0; --j){
hs += base * s[i][j];
mmap[hs]++;
base *= p;
}
}
LL ans = 0;
for(int i = 1; i <= n; i++){
T = s[i];
tlen = T.size();
ULL hs = 0;
for(int j = 0; j < tlen; j++){
cnt[j] = 0;
hs = hs * p + s[i][j];
if(!mmap.count(hs)) continue;
cnt[j] += mmap[hs];
}
getNext();
for(int j = 1; j <= tlen; j++){
if(Next[j] == 0) continue;
cnt[Next[j] - 1] -= cnt[j - 1];
}
for(int j = 0; j < tlen; j++){
ans = (ans + 1LL * (j + 1) * (j + 1) % mod2 * cnt[j] % mod2) % mod2;
}
}
cout << ans << endl;
return 0;
}