2020牛客暑期多校训练营(第二场)A All with Pairs AC自动机
题意
给出\(n\)个字符串\(s_1,s_2,\dots,s_n\),定义\(f(s,t)\)为字符串\(s\)和字符串\(t\)的最长公共前后缀(字符串\(s\)的前缀,字符串\(t\)的后缀)。
让你计算
\[\sum_{i=1}^{n} \sum_{j=1}^{n} f(s_i,s_j)^2~(mod~998244353)
\]
分析
将所有字符串\(s_i\)插入AC自动机中,构建\(fail\)树,因为\(fail\)指针的含义是\(fail[v]\)结点所代表的前缀串是\(v\)结点代表的前缀串能匹配到的最长后缀。所以字符串\(s_i\)的前缀能匹配到的后缀都在其所在的\(fail\)链上,根据\(fail\)指针反向建图,从根节点开始\(dfs\),\(dfs\)过程中对于每个\(s_i\)更新它的最长前缀,并统计答案,回溯时再撤销更新。
Code
#include<algorithm>
#include<iostream>
#include<cstring>
#include<iomanip>
#include<sstream>
#include<cstdio>
#include<string>
#include<vector>
#include<bitset>
#include<queue>
#include<cmath>
#include<stack>
#include<set>
#include<map>
#define rep(i,x,n) for(int i=x;i<=n;i++)
#define per(i,n,x) for(int i=n;i>=x;i--)
#define sz(a) int(a.size())
#define rson mid+1,r,p<<1|1
#define pii pair<int,int>
#define lson l,mid,p<<1
#define ll long long
#define pb push_back
#define mp make_pair
#define se second
#define fi first
using namespace std;
const double eps=1e-8;
const int mod=998244353;
const int N=1e6+10;
const int inf=1e9;
int n;
char s[N];
struct ACtree{
int son[N][26],fail[N],len[N],end[N],c[N],tot;
vector<int>q[N],g[N];
stack<pii>st;
ll res,ans;
int newnode(){
for(int i=0;i<26;i++) son[tot][i]=0;
end[tot++]=0;
return tot-1;
}
void init(){
ans=res=tot=0;
newnode();
}
void ins(char s[],int x){
int rt=0,m=strlen(s);
for(int i=0;i<m;i++){
if(!son[rt][s[i]-'a']) son[rt][s[i]-'a']=newnode();
rt=son[rt][s[i]-'a'];
len[rt]=i+1;
q[rt].pb(x);
}
end[rt]++;
}
void gao(){
queue<int>q;
for(int i=0;i<26;i++) if(son[0][i]) fail[son[0][i]]=0,q.push(son[0][i]);
while(!q.empty()){
int u=q.front();q.pop();
for(int i=0;i<26;i++){
if(son[u][i]){
fail[son[u][i]]=son[fail[u]][i];
q.push(son[u][i]);
}else son[u][i]=son[fail[u]][i];
}
}
for(int i=1;i<tot;i++) g[fail[i]].pb(i);
}
void dfs(int u){
for(int x:q[u]){
st.push(mp(x,c[x]));
res=(res-c[x]+mod)%mod;
c[x]=1ll*len[u]*len[u]%mod;
res=(res+c[x])%mod;
}
ans=(ans+1ll*res*end[u]%mod)%mod;
for(int x:g[u]){
dfs(x);
}
for(int x:q[u]){
pii tmp=st.top();
st.pop();
res=(res-c[tmp.fi]+mod)%mod;
c[tmp.fi]=tmp.se;
res=(res+c[tmp.fi])%mod;
}
}
ll qy(){
dfs(0);
return ans;
}
}AC;
int main(){
//ios::sync_with_stdio(false);
//freopen("in","r",stdin);
AC.init();
scanf("%d",&n);
rep(i,1,n){
scanf("%s",s);
AC.ins(s,i);
}
AC.gao();
printf("%lld\n",AC.qy());
return 0;
}