题意
给定n个字符串,求所有字符串前缀与后缀相等的个数与前后缀的长度的平方的和。如样例,匹配长度为1,2,3的分别有4,4,1个,所以答案为
4 *1 ^2 +4 *2 ^2 +1 *3 ^2=29
思路
可以枚举每一个字符的前缀看该前缀能和那些后缀组成f(t, s),然后在减去枚举时多出的部分即可。构建AC自动机,利用fail数组构建成
一个树(fail的每条边反向就可以构成一颗树),利用这棵树来对cnt数组进行累加。每次遍历一个字符串的前缀时构建一个KMP的ne数组,
答案累加cnt[p]*len*len(cnt[p]为该前缀能与多少后缀匹配,len是该前缀的长度)在减去cnt[p]*ne[j]*ne[j](表示ne[j]这个前缀和
自己的后缀(j结尾)匹配成功的情况)
代码
#pragma GCC optimize(2)
#include<unordered_map>
#include<algorithm>
#include<iostream>
#include<cstring>
#include<cstdio>
#include<string>
#include<vector>
#include<queue>
#include<stack>
#include<cmath>
#include<map>
#include<set>
#define Buff ios::sync_with_stdio(false)
#define rush() int Case = 0; int T; cin >> T; while(T--)
#define rep(i, a, b) for(int i = a; i <= b; i ++)
#define per(i, a, b) for(int i = a; i >= b; i --)
#define reps(i, a, b) for(int i = a; b; i ++)
#define clc(a, b) memset(a, b, sizeof(a))
#define Buff ios::sync_with_stdio(false)
#define readl(a) scanf("%lld", &a)
#define readd(a) scanf("%lf", &a)
#define readc(a) scanf("%c", &a)
#define reads(a) scanf("%s", a)
#define read(a) scanf("%d", &a)
#define lowbit(n) (n&(-n))
#define pb push_back
#define lson rt<<1
#define rson rt<<1|1
#define ls lson, l, mid
#define rs rson, mid+1, r
#define y second
#define x first
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int, int>PII;
const int mod = 998244353;
const double eps = 1e-6;
const int N = 1e6+1;
const int M = 1e5+7;
string s[M];
ll tr[N][26], ne[N], q[N], idx;
ll cnt[N], nxt[N], id[N];
vector<int> G[N];
void insert(int x)
{
int p = 0;
reps(i, 1, s[x][i])
{
int t = s[x][i] - 'a';
if(!tr[p][t]) tr[p][t] = ++idx;
p = tr[p][t];
}
cnt[p] ++;
id[x] = p;
}
void build()
{
int hh = 0, tt = -1;
rep(i, 0, 25) if(tr[0][i]) q[++ tt] = tr[0][i];
while(hh <= tt)
{
int t = q[hh ++];
rep(i, 0, 25)
{
ll& p = tr[t][i];
// printf("tr[%d][%d]: %lld ----- %lld\n", t, i, tr[t][i], tr[ne[t]][i]);
if(!p) p = tr[ne[t]][i];
else
{
ne[p] = tr[ne[t]][i];
q[++ tt] = p;
}
}
}
}
void dfs(int x)
{
sort(G[x].begin(), G[x].end());
G[x].erase(unique(G[x].begin(), G[x].end()), G[x].end());
for(int t:G[x])
{
dfs(t);
cnt[x] += cnt[t];
}
}
void kmp(int x)
{
for(int i = 2, j = 0; s[x][i]; i ++)
{
while(j && s[x][i] != s[x][j+1]) j = nxt[j];
if(s[x][i] == s[x][j+1]) j ++;
nxt[i] = j;
}
}
int main()
{
int n;
cin >> n;
rep(i, 1, n)
{
cin >> s[i]; s[i] = " " + s[i];
insert(i);
}
build();
for(int i = 1; i <= n; i ++)
{
int p = id[i];
while(p)
{
G[ne[p]].push_back(p);
p = ne[p];
}
}
dfs(0);
ll res = 0;
// for(int i = 1; i <= idx; i ++)
// cout << cnt[i] <<"\n";
for(int i = 1; i <= n; i ++)
{
kmp(i);
int p = 0;
for(int j = 1; s[i][j]; j ++)
{
p = tr[p][s[i][j]-'a'];
// cout <<"i, j: " << i <<" "<< j <<" "<< (cnt[p] * j * j) % mod - (cnt[p] * nxt[j] * nxt[j]) <<endl;
res = (res + (cnt[p] * j * j)%mod - (cnt[p]*nxt[j]*nxt[j])%mod)%mod;
if(res < 0) res += mod;
}
}
cout << res <<endl;
return 0;
}