SAM（suffix automaton，后缀自动机）
这是广义SAM（generalized SAM：多个串插入同一个后缀自动机）
#include<bits/stdc++.h>
// rep(x, L, R): loop x over the inclusive range [L, R]; R is evaluated once.
#define rep(x, L, R) for(int x = (L), _x = (R); x <= _x; x++)
using namespace std;
// N   : size of the input buffer s[] and of str[] (one vector per string).
// md  : modulus used for all arithmetic (998244353).
// T   : N << 1, capacity for automaton nodes (indices 1..cnt) and the per-node arrays.
//       NOTE(review): presumably sized for <= 2 states per inserted character; confirm vs. input limits.
// iv2 : (md + 1) / 2, the modular inverse of 2 (2 * iv2 == md + 1); not referenced in the code shown here.
const int N = 2e6 + 10, md = 998244353, T = N << 1, iv2 = (md + 1) / 2;
// Automaton state: fa = suffix link, ch = transitions on letters 0..25,
// dep = length assigned to the node at creation, cnt = number of nodes in use.
// NOTE(review): ch alone is T * 26 ints (about 416 MB); check the memory limit.
int fa[T], ch[T][26], dep[T], cnt = 0;
// n = number of strings; p[i] = weight read for string i;
// sum[u][1] / sum[u][0] = weight accumulated at node u by forward / reversed insertions.
int n, p[T], sum[T][2];
// ans[k] = pair sums bucketed by (type_i + type_j), filled by Dfs; res = final answer (mod md).
int ans[3], res;
// Input buffer; strings are read into s + 1.
char s[N];
// Adjacency lists of the suffix-link tree (both directions are stored).
vector<int> g[T];
// str[i] = letters of string i encoded as 0..25.
vector<int> str[N];
// Fold a value in [0, 2*md) back into [0, md) with one conditional subtraction.
int add(int x) {
    if(x >= md) return x - md;
    return x;
}
// Lift a value in (-md, md) into [0, md) with one conditional addition.
int sub(int x) {
    if(x < 0) return x + md;
    return x;
}
// In-place modular addition: x = (x + y) mod md, assuming x and y are already in [0, md).
void Add(int &x, int y) {
    x += y;
    if(x >= md) x -= md;
}
// In-place modular subtraction: x = (x - y) mod md, assuming x and y are already in [0, md).
void Sub(int &x, int y) {
    x -= y;
    if(x < 0) x += md;
}
// Record the arc u -> v in the tree adjacency lists (callers add both directions).
void adde(int u, int v) {
    g[u].emplace_back(v);
}
// Hand out the next node index with length d. Every per-node field is reset here,
// so indices can be reused after Clear() without wiping the global arrays.
int newd(int d) {
    const int node = ++cnt;
    fa[node] = 0;
    sum[node][0] = sum[node][1] = 0;
    dep[node] = d;
    memset(ch[node], 0, sizeof ch[node]);
    g[node].clear();
    return node;
}
// Drop every automaton node. Their fields are not wiped here;
// newd() resets each one when its index is handed out again.
void Clear() {
    cnt = 0;
}
// Extend the generalized suffix automaton from state p by letter c (0..25) and
// return the state reached. Node 1 is the root. dep[u] is the length given to u
// when it was created (dep[p] + 1 for an extension), and fa[u] is u's suffix link.
// p is reused as a cursor while walking suffix links.
int ins(int p, int c) {
// Case 1: p already has a c-transition (created by an earlier insertion).
if(ch[p][c]) {
int q = ch[p][c];
// q already has exactly the required length: reuse it.
if(dep[q] == dep[p] + 1) return q;
// Otherwise split: clone q as np with length dep[p] + 1.
int np = newd(dep[p] + 1);
memcpy(ch[np], ch[q], sizeof(ch[q])); // copy q's transitions into the clone (dest first)
fa[np] = fa[q], fa[q] = np; // clone takes q's link; q now links to the clone
// Redirect c-transitions into q from p and its suffix-link ancestors to the clone.
for(; p && ch[p][c] == q; p = fa[p]) ch[p][c] = np;
return np;
}
// Case 2: no c-transition yet, so create a fresh state.
int np = newd(dep[p] + 1);
// Add c-transitions to np along the suffix-link path until one already exists.
for(; p && !ch[p][c]; p = fa[p]) ch[p][c] = np;
if(!p) fa[np] = 1; // walked past the root: link to the root
else {
int q = ch[p][c];
if(dep[q] == dep[p] + 1) fa[np] = q; // q has the right length: link directly
else {
// Split q: the clone nq gets length dep[p] + 1 and q's transitions and link.
int nq = newd(dep[p] + 1);
memcpy(ch[nq], ch[q], sizeof(ch[q]));
fa[nq] = fa[q], fa[q] = fa[np] = nq;
// Redirect c-transitions into q along the remaining suffix-link path to nq.
for(; p && ch[p][c] == q; p = fa[p]) ch[p][c] = nq;
}
}
return np;
}
// Walk the tree stored in g[] starting at u (fa = the neighbour to treat as u's
// parent; 0 for the root). This parameter shadows the global suffix-link array.
//
// Let w_i(x) be sum[x][i] before the call. For every ordered pair of nodes (x, y)
// in u's subtree and every i, j in {0, 1}, this adds
//     w_i(x) * w_j(y) * dep[lca(x, y)]   (mod md)
// into ans[i + j]. On return, sum[v][*] holds the subtree total of every visited node v.
//
// Uses an explicit stack instead of recursion. The tree height can reach the
// maximum string length (e.g. a run of one letter makes each prefix link to the
// next shorter one, giving a chain), and that would overflow the call stack.
// Nodes are visited, and the modular accumulations done, in the same order as a
// recursive DFS over g[].
void Dfs(int u, int fa) {
    struct Frame {
        int node;    // tree node on the current root-to-node path
        int parent;  // neighbour we arrived from (skipped when listing children)
        int next;    // index into g[node] of the next neighbour to try
    };
    std::vector<Frame> stk;

    // Entering a node: count the pairs whose endpoints are both this node.
    auto enter = [&](int x, int par) {
        rep(i, 0, 1) {
            rep(j, 0, 1) {
                Add(ans[i + j], 1ll * sum[x][i] * sum[x][j] % md * dep[x] % md);
            }
        }
        stk.push_back(Frame{x, par, 0});
    };

    enter(u, fa);
    while(!stk.empty()) {
        Frame &top = stk.back();
        if(top.next < static_cast<int>(g[top.node].size())) {
            const int v = g[top.node][top.next++];
            // `top` may dangle after enter() pushes; it is not used again this iteration.
            if(v != top.parent) enter(v, top.node);
            continue;
        }
        // All children of this node are finished: merge its subtree into its parent.
        const int v = top.node;
        stk.pop_back();
        if(stk.empty()) break;  // the starting node has no parent to merge into
        const int w = stk.back().node;
        // Pairs with one end in v's subtree and the other in w itself or an earlier
        // child subtree of w have lca w; the factor 2 accounts for both orders.
        rep(i, 0, 1) {
            rep(j, 0, 1) {
                Add(ans[i + j], 1ll * sum[w][i] * sum[v][j] % md * dep[w] % md * 2 % md);
            }
        }
        rep(i, 0, 1) Add(sum[w][i], sum[v][i]);
    }
}
// Build one generalized SAM holding every string both forwards and reversed.
// Each forward prefix adds p[i] to sum[*][1] and each reversed prefix adds
// (1 - p[i]) mod md to sum[*][0]. Then run Dfs on the suffix-link tree and add
// ans[0] + ans[1] + ans[2] (mod md) to res.
void calc() {
    Clear();
    ans[0] = ans[1] = ans[2] = 0;
    const int root = newd(0);  // always node 1 right after Clear()

    for(int i = 1; i <= n; i++) {
        const int fwdWeight = p[i];
        const int revWeight = sub(1 - p[i]);

        int cur = root;
        for(const int c : str[i]) {
            cur = ins(cur, c);
            Add(sum[cur][1], fwdWeight);
        }

        cur = root;
        for(auto it = str[i].rbegin(); it != str[i].rend(); ++it) {
            cur = ins(cur, *it);
            Add(sum[cur][0], revWeight);
        }
    }

    // Turn the suffix links into an undirected tree rooted at `root`.
    for(int v = root + 1; v <= cnt; v++) {
        adde(v, fa[v]);
        adde(fa[v], v);
    }
    Dfs(root, 0);

    int total = add(ans[0] + ans[1]);
    total = add(total + ans[2]);
    Add(res, total);
}
// Correct res for pairs that come from one string (letters `str`, weight `p`).
// Builds a SAM of just this string (forwards and reversed) with unit weights and
// runs Dfs. calc() weighted such pairs by the product of the endpoint weights;
// here that product weighting is subtracted, and ans[0]*(1-p) + ans[2]*p is added
// back (mixed forward/reversed pairs end up with weight 0).
// Note: the parameters shadow the global str[] and p[] arrays.
void sub(vector<int> &str, int p) {
    Clear();
    ans[0] = ans[1] = ans[2] = 0;
    const int root = newd(0);  // always node 1 right after Clear()

    int cur = root;
    for(const int c : str) {
        cur = ins(cur, c);
        Add(sum[cur][1], 1);
    }
    cur = root;
    for(auto it = str.rbegin(); it != str.rend(); ++it) {
        cur = ins(cur, *it);
        Add(sum[cur][0], 1);
    }

    for(int v = root + 1; v <= cnt; v++) {
        adde(v, fa[v]);
        adde(fa[v], v);
    }
    Dfs(root, 0);

    const int q = sub(1 - p);  // weight of a reversed occurrence, (1 - p) mod md
    Sub(res, 1ll * ans[0] * q % md * q % md);
    Sub(res, 1ll * ans[2] * p % md * p % md);
    Sub(res, 1ll * ans[1] * p % md * q % md);
    Add(res, 1ll * ans[0] * q % md);
    Add(res, 1ll * ans[2] * p % md);
}
int main() {
// freopen("in.in", "r", stdin);
scanf("%d", &n);
for(int i = 1; i <= n; i++) scanf("%d", &p[i]);
for(int i = 1; i <= n; i++) {
scanf("%s", s + 1);
int len = strlen(s + 1);
for(int j = 1; j <= len; j++) str[i].push_back(s[j] - 'a');
}
calc();
for(int i = 1; i <= n; i++) sub(str[i], p[i]);
printf("%d\n", res);
return 0;
}
一定要记住 memcpy 和 std::copy 的参数顺序不同：memcpy(dst, src, bytes) 的目标地址放在最前面，而 std::copy(first, last, d_first) 的目标放在最后面。