AC 自动机（Aho–Corasick automaton）
给出 n 个字符串，对每个字符串 s_i，求其余字符串中有多少个是 s_i 的子串，输出这些数量之和（即满足 i≠j 且 s_j 是 s_i 的子串的有序对 (i, j) 的个数）。
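利用 AC 自动机的 fail 树 + 树链的并求解：s_i 中出现的串，恰好是在 s_i 各前缀结点到根的链的并上结束的串；把这些前缀结点按 dfn 排序，用“各点到根的权值和 − 相邻两点 LCA 到根的权值和”即可求出链并上的串数。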
// Aho-Corasick automaton + union of fail-tree root paths.
//
// Reads n strings. For every string s_i it counts how many of the n strings
// (with multiplicity) occur in s_i as a substring, sums these counts over i,
// and subtracts n for the trivial match of each string with itself.
#include <algorithm>
#include <iostream>
#include <queue>
#include <string>
#include <utility>
#include <vector>
using namespace std;

// NOTE(review): assumes the total length of all strings and n are both below N
// (every character creates at most one trie node) — confirm against the limits.
const int N = 2100000;
const int LOG = 22;  // 2^21 > N: enough binary-lifting levels for any fail-tree depth

int c[N][26];       // trie edges; after build() the full goto function
int val[N];         // strings ending at a node; after dfs(): strings that are suffixes of the node's prefix
int fail[N];        // fail links; fail[x] is x's parent in the fail tree
int cnt;            // largest node id in use; node 0 is the root
vector<int> ve[N];  // ve[x] = trie nodes of every prefix of the x-th string

// Inserts s as the x-th string and records the node of each of its prefixes in ve[x].
void insert(const string &s, int x) {
    int now = 0;
    for (char ch : s) {
        int v = ch - 'a';  // assumes the alphabet is lowercase 'a'..'z'
        if (!c[now][v]) c[now][v] = ++cnt;
        now = c[now][v];
        ve[x].push_back(now);
    }
    val[now]++;
}

// Builds fail links in BFS order; missing edges are redirected through the fail
// link so that c[][] becomes a complete automaton.
void build() {
    queue<int> q;
    for (int i = 0; i < 26; i++)
        if (c[0][i]) {
            fail[c[0][i]] = 0;
            q.push(c[0][i]);
        }
    while (!q.empty()) {
        int u = q.front();
        q.pop();
        for (int i = 0; i < 26; i++) {
            if (c[u][i]) {
                fail[c[u][i]] = c[fail[u]][i];
                q.push(c[u][i]);
            } else {
                c[u][i] = c[fail[u]][i];
            }
        }
    }
}

int cnt2;              // pre-order timestamp counter
int dfn[N], depth[N];  // pre-order index and depth in the fail tree
vector<int> ve1[N];    // children lists of the fail tree
int bz[LOG][N];        // bz[k][x] = 2^k-th ancestor of x (the root maps to itself)

// Pre-order traversal of the fail-tree subtree rooted at x, whose parent is y.
// Accumulates val along root paths, fills depth/dfn and the binary-lifting table.
// Iterative on purpose: the fail tree of e.g. "aaaa...a" is a path as long as the
// string, which overflows the call stack when traversed recursively. Children are
// pushed in reverse so they are visited in the same order as a recursive DFS.
void dfs(int x, int y) {
    vector<pair<int, int>> stk;  // (node, parent)
    stk.push_back(make_pair(x, y));
    while (!stk.empty()) {
        int u = stk.back().first, p = stk.back().second;
        stk.pop_back();
        if (u != 0) {
            val[u] += val[p];  // parent was processed before u was pushed
            depth[u] = depth[p] + 1;
        }
        bz[0][u] = p;
        for (int k = 1; k < LOG; k++) bz[k][u] = bz[k - 1][bz[k - 1][u]];
        dfn[u] = ++cnt2;
        for (auto it = ve1[u].rbegin(); it != ve1[u].rend(); ++it)
            stk.push_back(make_pair(*it, u));
    }
}

// Lowest common ancestor of x and y in the fail tree via binary lifting.
int lca(int x, int y) {
    if (depth[x] < depth[y]) swap(x, y);
    for (int k = LOG - 1; k >= 0; k--)
        if (depth[bz[k][x]] >= depth[y]) x = bz[k][x];
    if (x == y) return x;
    for (int k = LOG - 1; k >= 0; k--)
        if (bz[k][x] != bz[k][y]) {
            x = bz[k][x];
            y = bz[k][y];
        }
    return bz[0][x];
}

// Orders fail-tree nodes by pre-order index.
bool cmp(int x, int y) { return dfn[x] < dfn[y]; }

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    int n;
    cin >> n;
    string s;
    for (int i = 1; i <= n; i++) {
        cin >> s;
        insert(s, i);
    }
    build();
    for (int i = 1; i <= cnt; i++) ve1[fail[i]].push_back(i);
    dfs(0, 0);

    long long ans = 0;
    for (int i = 1; i <= n; i++) {
        vector<int> &nodes = ve[i];
        sort(nodes.begin(), nodes.end(), cmp);
        int m = nodes.size();
        // Union of the root paths of all prefix nodes: add each node's root-path
        // weight, subtract the part shared by dfn-consecutive nodes (their LCA's).
        for (int j = 0; j + 1 < m; j++) ans -= val[lca(nodes[j], nodes[j + 1])];
        for (int j = 0; j < m; j++) ans += val[nodes[j]];
    }
    // Every string trivially contains itself once; drop those n matches.
    cout << ans - n << '\n';
    return 0;
}