AC 自动机（Aho–Corasick automaton）

给出 n 个字符串。对每个字符串，求其余字符串中有多少个是它的子串（重复出现的字符串分别计数），输出所有字符串答案之和。

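利用 AC 自动机 + 树链的并求解：把所有字符串插入 AC 自动机，并在 fail 树上把"以该结点结尾的字符串个数"沿 fail 链做前缀和；对每个字符串，把它所有前缀对应的结点按 DFS 序排序，累加各结点的值并减去相邻结点 LCA 处的值，就得到这些结点到根的树链并上的总和；最后减去每个串匹配到自身的那 1 次。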

//#pragma GCC optimize("Ofast,no-stack-protector,unroll-loops,fast-math")
//#pragma GCC target("sse,sse2,sse3,ssse3,sse4.1,sse4.2,avx,avx2,popcnt,tune=native")

//#include <immintrin.h>
//#include <emmintrin.h>
#include <bits/stdc++.h>
using namespace std;
#define rep(i,h,t) for (int i=h;i<=t;i++)
#define dep(i,t,h) for (int i=t;i>=h;i--)
#define ll long long
#define me(x) memset(x,0,sizeof(x))
#define IL inline
#define rint register int
/// Reads the next integer from stdin via getchar().
/// Non-digit characters before the number are skipped; a '-' seen among them
/// makes the result negative. Returns 0 if EOF is reached before any digit
/// (previously this case looped forever).
inline long long rd(){
    long long x=0;
    int c=getchar();   // int, not char: EOF must stay distinct from byte 0xFF
    bool neg=false;
    while(c!=EOF&&!isdigit(c)){if(c=='-')neg=true;c=getchar();}
    while(isdigit(c)){x=x*10+(c-'0');c=getchar();}   // isdigit(EOF) is false
    return neg?-x:x;
}
// 16 MiB stdin buffer for gc(); [A, B) is the not-yet-consumed window.
char ss[1<<24],*A=ss,*B=ss;
/// Returns the next byte of stdin as an int in [0, 255], or EOF once input is
/// exhausted. Refills the buffer with fread() when it runs dry. Returning int
/// (not char) keeps a 0xFF input byte distinguishable from EOF.
inline int gc()
{
    if (A==B)
    {
        A=ss;
        B=ss+fread(ss,1,sizeof(ss),stdin);
        if (A==B) return EOF;
    }
    return static_cast<unsigned char>(*A++);
}
/// In-place maximum: overwrite x with y only when y > x.
template<class T>void maxa(T &x,T y)
{
    if (!(y > x)) return;
    x = y;
}
/// In-place minimum: overwrite x with y only when y < x.
template<class T>void mina(T &x,T y)
{
    if (!(y < x)) return;
    x = y;
}
/// Reads the next integer into x using the buffered gc() reader.
/// Characters before the first digit are skipped; a '-' among them makes the
/// value negative. If input ends before any digit, x is set to 0 and the
/// function returns (previously it spun forever on EOF).
template<class T>void read(T &x)
{
    int sign=1;
    int c=gc();
    // Skip everything up to the first digit, remembering any '-' seen on the way.
    while (c<'0'||c>'9')
    {
        if (c==EOF) { x=0; return; }
        if (c=='-') sign=-1;
        c=gc();
    }
    x=(c-'0');
    while (c=gc(),c>='0'&&c<='9') x=x*10+(c-'0');
    x*=sign;
}
// Modulus used by fsp(): 1e9+7.
const int mo=1e9+7;
/// Modular exponentiation by repeated squaring: returns x^y mod mo in [0, mo).
/// @param x base; any int, including negatives and values >= mo (reduced first).
/// @param y exponent; y == 0 yields 1. Negative exponents are not supported
///          and also yield 1.
long long fsp(int x,int y)
{
    long long base=(static_cast<long long>(x)%mo+mo)%mo;
    long long result=1;
    while (y>0)
    {
        if (y&1) result=result*base%mo;
        base=base*base%mo;   // base < mo, so the product fits in 64 bits
        y>>=1;
    }
    return result;
}
/// 2-D point/vector with 64-bit integer coordinates.
struct cp {
    long long x,y;
    /// Component-wise sum.
    cp operator +(const cp &B) const
    {
        return cp{x+B.x,y+B.y};
    }
    /// Component-wise difference.
    cp operator -(const cp &B) const
    {
        return cp{x-B.x,y-B.y};
    }
    /// 2-D cross product: x*B.y - y*B.x.
    long long operator *(const cp &B) const
    {
        return x*B.y-y*B.x;
    }
    /// 1 if the vector is in the lower half-plane (y < 0) or on the negative
    /// x-axis, else 0 (the origin gives 0).
    int half() const { return y < 0 || (y == 0 && x < 0); }
};
// Plain aggregate of three ints (not referenced in the code visible here).
struct re{
    int a;
    int b;
    int c;
};
// Capacity (2.1 million) for trie nodes and every per-node array.
const int N=2100000;
// c[u][ch]: child of trie node u on letter 'a'+ch (build() later fills missing entries).
// val[u]: number of inserted strings that end exactly at node u.
// fail[u]: failure link of u; cnt: index of the most recently allocated node (0 is the root).
int c[N][26],val[N],fail[N],cnt;

// ve[x]: trie node of each non-empty prefix of the x-th string, shortest prefix first.
std::vector<int> ve[N];

/// Inserts the NUL-terminated string s (string number x) into the trie,
/// records the node of every prefix in ve[x], and bumps val at the final node.
/// Assumes s contains only 'a'..'z'; any other character indexes out of range.
void insert(char *s,int x)
{
  int node=0;
  for (const char *p=s; *p!='\0'; ++p)
  {
    int &child=c[node][*p-'a'];
    if (child==0) child=++cnt;
    node=child;
    ve[x].push_back(node);
  }
  val[node]++;
}
queue<int> q; 

void build()
{
  for (int i=0;i<26;i++)
    if (c[0][i]) fail[c[0][i]]=0,q.push(c[0][i]);
  while (!q.empty())
  {
    int u=q.front(); q.pop();
    for (int i=0;i<26;i++)
      if (c[u][i])
      {
          fail[c[u][i]]=c[fail[u]][i];
          q.push(c[u][i]);
      }  else c[u][i]=c[fail[u]][i];
  }  
}
// Input buffer: main() reads each string here before passing it to insert().
char s[N];
// Running pre-order counter used to assign dfn[].
int cnt2;
// dfn[u]: pre-order index of node u in the fail tree; dep[u]: depth of u (root 0 has depth 0).
int dfn[N],dep[N];
// ve1[u]: children of u in the fail tree (main() fills it from fail[]).
std::vector<int> ve1[N];
// bz[k][u]: the 2^k-th ancestor of u in the fail tree; jumps past the root stay at 0.
int bz[22][N];

/// Pre-order traversal of the fail tree starting at x, whose parent is y.
/// For every visited node u with parent p it:
///  - for u != 0, adds val[p] into val[u] (accumulating val down the fail chain)
///    and sets dep[u] = dep[p] + 1;
///  - fills the binary-lifting table bz[0..21][u];
///  - assigns dfn[u] from the running counter cnt2.
/// An explicit stack is used because the fail tree can be as deep as the
/// longest inserted string (e.g. "aaa...a" forms one chain), which could
/// overflow the call stack if this were done recursively.
void dfs(int x,int y)
{
    std::vector<std::pair<int,int>> todo;   // (node, parent) pairs still to visit
    todo.emplace_back(x,y);
    while (!todo.empty())
    {
        const int u=todo.back().first;
        const int p=todo.back().second;
        todo.pop_back();
        if (u!=0)
        {
            val[u]+=val[p];
            dep[u]=dep[p]+1;
        }
        bz[0][u]=p;
        for (int k=1;k<=21;k++) bz[k][u]=bz[k-1][bz[k-1][u]];
        dfn[u]=++cnt2;
        // Push children in reverse so they are popped in their original order,
        // giving the same pre-order numbering as a recursive walk.
        const std::vector<int> &kids=ve1[u];
        for (auto it=kids.rbegin();it!=kids.rend();++it) todo.emplace_back(*it,u);
    }
}
/// Lowest common ancestor of x and y in the fail tree, via the bz[][] binary
/// lifting table and dep[] depths.
int lca(int x,int y)
{
    int lo=x, hi=y;   // lo ends up as the deeper endpoint
    if (dep[lo]<dep[hi]) std::swap(lo,hi);
    // Lift the deeper node to the same depth as the shallower one.
    for (int k=21;k>=0;--k)
    {
        if (dep[bz[k][lo]]>=dep[hi]) lo=bz[k][lo];
    }
    if (lo==hi) return lo;
    // Raise both while their ancestors differ; they then sit just below the LCA.
    for (int k=21;k>=0;--k)
    {
        const int a=bz[k][lo];
        const int b=bz[k][hi];
        if (a!=b)
        {
            lo=a;
            hi=b;
        }
    }
    return bz[0][lo];
}
/// Orders fail-tree nodes by their pre-order index dfn[].
bool cmp(int x,int y)
{
    const int rank_x=dfn[x];
    const int rank_y=dfn[y];
    return rank_x<rank_y;
}
int main()
{
   freopen("11.in","r",stdin);
   freopen("1.out","w",stdout); 
   ios::sync_with_stdio(false);
   int n;
   cin>>n;
   rep(i,1,n)
   {
        cin>>s;
        insert(s,i);
   }
   build();
   for (int i=1;i<=cnt;i++) ve1[fail[i]].push_back(i);
   dfs(0,0);
   ll ans=0,ans2=0; 
   rep(i,1,n)
   {
        
        sort(ve[i].begin(),ve[i].end(),cmp);
        int n=ve[i].size();
        rep(j,0,n-2)
        {
            ans-=val[lca(ve[i][j],ve[i][j+1])];
        }
        rep(j,0,n-1) ans+=val[ve[i][j]];
   // cout<<ans<<endl;
   }
   cout<<ans-n<<endl;
   return 0;
}
View Code

 

posted @ 2021-06-08 17:11  尹吴潇  阅读(38)  评论(0)  编辑  收藏  举报