牛客多校第二场
A - All with Pairs (AC自动机+kmp)
题意
定义\(f(s, t)\)为字符串s的前缀和t的后缀的最长公共长度。给定n个字符串\(s_1,s_2,...,s_n\),\(\sum_{i=1}^{n}{\sum_{j=1}^{n}{f(s_i, s_j)^2}}(\mod 998244353)\)
思路
AC自动机的trie树的从根节点开始的路径都是某个单词的前缀,fail指针跳的是当前前缀的后缀。所以可以预处理出每个前缀的贡献(记录在trie树上)。
然后匹配一个由所有单词连接而成的文本串,用一个特殊字符隔开。每次匹配到特殊字符,说明已经到达当前单词的结尾。然后跳fail累加前缀的贡献即可。
由于跳fail会导致重复累加贡献,所以一开始需要kmp预处理每个单词的前缀贡献的差分。
#include <bits/stdc++.h>
#define endl '\n'
#define IOS std::ios::sync_with_stdio(0); cin.tie(0); cout.tie(0)
#define FILE freopen(".//data_generator//in.txt","r",stdin),freopen("res.txt","w",stdout)
#define FI freopen(".//data_generator//in.txt","r",stdin)
#define FO freopen("res.txt","w",stdout)
#define pb push_back
#define mp make_pair
#define seteps(N) fixed << setprecision(N)
typedef long long ll;
using namespace std;
/*-----------------------------------------------------------------*/
ll gcd(ll a, ll b) {return b ? gcd(b, a % b) : a;}
#define INF 0x3f3f3f3f
const int N = 1e6 + 10;
const int M = 998244353;
int tr[N][30];
int fail[N * 30];
int si = 0;
string s;
string tex;
ll add[N];
ll cont[N * 30];
int nt[N];
void getnext(const char t[]) {
int i = 0, j = -1;
nt[i] = j;
int n = 0;
while(t[i]) {
n++;
if(j == -1 || t[i] == t[j]) {
i++;
j++;
nt[i] = j;
} else {
j = nt[j];
}
}
for(int i = n; i >= 1; i--) {
add[i] = ((1ll * i * i - 1ll * nt[i] * nt[i]) % M + M) % M;
}
}
namespace AC {
void init() {
memset(tr[0], 0,sizeof tr[0]);
memset(cont, 0,sizeof cont);
si = 0;
}
void insert(const char s[]) {
int cur = 0;
fail[cur] = 0;
for(int i = 0; s[i]; i++) {
if(tr[cur][s[i] - 'a']) cur = tr[cur][s[i] - 'a'];
else {
tr[cur][s[i] - 'a'] = ++si;
cur = si;
memset(tr[si], 0, sizeof tr[si]);
fail[cur] = 0;
cont[cur] = 0;
}
cont[cur] += add[i + 1];
cont[cur] %= M;
}
}
void build() {
queue<int> q;
for(int i = 0; i < 27; i++)
if(tr[0][i]) q.push(tr[0][i]);
while(!q.empty()) {
int cur = q.front();
q.pop();
for(int i = 0; i < 26; i++) {
if(tr[cur][i]) {
fail[tr[cur][i]] = tr[fail[cur]][i];
q.push(tr[cur][i]);
} else {
tr[cur][i] = tr[fail[cur]][i];
}
}
}
}
ll query(const char s[]) { //返回有多少模式串在s中出现过
ll ans = 0;
int cur = 0;
for(int i = 0; s[i]; i++) {
cur = tr[cur][s[i] - 'a'];
if(s[i + 1] == '{')
for(int j = cur; j; j = fail[j]) {
ans += cont[j];
ans %= M;
}
}
return ans;
}
}
int main() {
IOS;
int n;
cin >> n;
for(int i = 1; i <= n; i++) {
cin >> s;
getnext(s.c_str());
AC::insert(s.c_str());
tex += s + "{";
}
AC::build();
ll ans = AC::query(tex.c_str());
cout << (ans % M + M) % M << endl;
}