【XSY4180】串串游走(AC自动机,期望DP,高斯消元)

假瑞出的神仙题。原题 CFgym103119B

先把 \(T\) 去重。

考虑先用 \(O(nm\log k)\) 建出所有串的 AC 自动机。注意建 AC 自动机的时候,为了保证空间,假设当前点 \(u\) 没有的转移,我们不从 \(fail_u\) 那里复制;而对于当前点有的转移 \(v\),我们暴力跳 \(u\)\(fail\) 来更新 \(fail_v\)。这可以结合代码理解。

这样跳 \(fail\) 的时间复杂度是对的,你可以把它类似 kmp 中跳 \(next\) 指针的方式来理解:跳一次 \(fail\) 深度至少减一,而从父亲到儿子 \(fail\) 深度至多加一。所以总时间复杂度为 \(O(nm\log k)\),当然使用哈希实现能做到 \(O(nm)\)

那么原问题就转化为在 AC 自动机上的随机游走,直到走到某个串使得它包含 \(T\) 中的任意一个串为止。

由于 \(T\) 中的串长度都为 \(m\),所以停止的条件就转化为走到 \(T\) 中的任意一个串对应的节点为止(把这些节点称为结束节点)。

不妨设 \(g_{u,c}\) 表示 AC 自动机上点 \(u\) 往字符 \(c\) 方向的转移到哪(没有的话需要暴力跳 fail 来找)。

\(F(u)\) 表示从 \(u\) 走到任意一个结束节点的期望步数,那么有:

  • \(u\) 为一个结束节点,\(F(u)=0\)

  • \(u\) 不是一个结束节点,那么有:

    \[F(u)=1+\sum_{c=0}^{k-1}F(g_{u,c})p_c \]

然后就有 \(O(nm)\) 条方程,直接高斯消元即可。总时间复杂度是 \(O(|R|+(nm)^3)\) 的,仅可以得到 10pts。

观察到 \(k\leq 26\) 的部分分,注意到此时一条方程中 \(nm\) 个变量只有 \(k+1\) 个变量的系数非零,你感觉直接一整行一整行地高斯消元貌似太亏了。

考虑设置几个关键变量 \(x_1,x_2,\cdots,x_l\),然后利用 \(F(u)\) 的转移方程将其他的 \(F(u)\)\(x_1,x_2,\cdots,x_l\) 的线性组合加一个常数表示(即用这 \(l\) 个变量代入消掉其他的变量),最后再 \(O(l^3)\) 解出 \(x_1,x_2,\cdots,x_l\),再代入解出所有的 \(F(u)\)

我们需要设置恰当的 \(x_1,x_2,\cdots,x_l\)\(l\) 太小可能不能使所有的 \(F(u)\) 都能用 \(x_1,x_2,\cdots,x_l\) 的线性组合加一个常数表示,\(l\) 太大可能复杂度不够优。

我们可以这么考虑:先设 \(son(u)\) 表示 \(u\) 在 Trie 树上的儿子个数。对于每一个点 \(u\),若 \(son(u)>1\),那么我们把 \(u\) 在 Trie 树上的所有儿子 \(v\)\(F(v)\) 都设为关键变量。同时,我们把 \(F(rt)\) 也设为关键变量(\(rt\) 为根)。

先考虑为什么这样设置能使所有的 \(F(u)\) 都能用 \(x_1,x_2,\cdots,x_l\) 的线性组合加一个常数表示(下面简称为 “能被表示”:

我们考虑在 Trie 树上 bfs,假设当前点为 \(u\),假设之前 bfs 遍历到的点的 \(F\) 都能被表示,我们考虑 \(F(u)\) 是否能被表示。假设 \(u\) 的父亲为 \(f\)(若无父亲则 \(u\) 是根,此时 \(F(u)\) 显然能被表示)。

  • \(son(f)>1\),那么 \(F(u)\) 本来就是关键变量,显然它都能被表示。

  • \(son(f)=1\),那么 \(f\) 在 Trie 树上只有一个儿子 \(u\),假设 \(g_{f,t}=u\)。根据刚刚的转移方程:

    \[\begin{aligned} F(f)&=1+\sum_{c=0}^{k-1}F(g_{f,c})p_c\\ &=1+F(u)p_t+\sum_{c\neq t}F(g_{f,c})p_c\\ \end{aligned} \]

    注意到 \(f\) 在 Trie 树上只有一个儿子,所以对于任意 \(c\neq t\)\(g_{f,c}\)\(g_{f,c}\) 一定是从 \(f\)\(fail\) 树的某个祖先那里向 \(c\) 转移得到的,于是必然有 \(len(g_{f,c})\leq len(f)<len(u)\)。又由于我们是按 bfs 的顺序,\(g_{f,c}\) 一定比 \(u\) 先被遍历到,所以遍历到 \(u\)\(F(g_{f,c})\) 肯定已经被表示。

    于是又根据上面的等式移项得:

    \[F(u)=-\dfrac{1-F(f)+\sum_{c\neq t}F(g_{f,c})p_c}{p_t} \]

    于是 \(F(u)\) 能被表示。

然后是考虑 \(l\) 的级别:\(F(u)\) 是关键变量当且仅当 \(son(f)>1\),考虑 Trie 树上有多少个这样的 \(u\)

注意到 Trie 树上只有 \(n\) 个叶子,所以显然这样的点只有 \(O(n)\) 个。

然后高斯消元,理论上本来 \(nm\) 条方程 \(nm\) 个变量中,对于某一条方程 \(F(f)=...\),要么在 \(son(f)=1\) 的情况中用来去表示 \(F(u)\) 了(可以理解为消掉一个变量 \(F(u)\)),要么你没用过,这些没用过的方程理论上应该剩下 \(l\) 条。那么具体剩下的是哪几条方程呢?显然是 \(son(f)>1\)\(F(f)\) 的方程 \(F(f)=1+\sum\limits_{c=0}^{k-1}F(g_{f,c})p_c\)\(son(f)=0\)\(F(f)\) 的方程 \(F(f)=0\)(注意方程中的 \(F(f)\)\(F(g_{f,c})\) 全部都用 \(x_1,x_2,\cdots,x_l\) 表示了)。直接用这几条方程高斯消元即可。

总时间复杂度 \(O(n^2mk+n^3)\)

\(k\) 很大时怎么办?一个瓶颈是在 \(son(f)=1\) 时:

\[F(u)=-\dfrac{1-F(f)+\sum_{c\neq t}F(g_{f,c})p_c}{p_t} \]

这里需要枚举所有的 \(c\) 并每次 \(O(n)\) 累加。

\(f=rt\),那么 \(F(u)=F(f)-\dfrac{1}{p_t}\)

\(f\neq rt\),解决这个问题就需要利用 \(fail\) 的性质了。我们刚刚说到,对于 \(c\neq t\)\(g_{f,c}=g_{fail_f,c}\),于是:

\[\begin{aligned} F(u)&=-\dfrac{1-F(f)+\sum_{c}F(g_{fail_f,c})p_c-F(g_{fail_f,t})p_t}{p_t}\\ &=-\dfrac{1-F(f)+\bigg(F(fail_f)-1\bigg)-F(g_{fail_f,t})p_t}{p_t}\\ &=\dfrac{F(f)-F(fail_f)}{p_t}+F(g_{fail_f,t}) \end{aligned} \]

同时因为 \(g_{f,t}=u\),所以有 \(g_{fail_f,t}=fail_u\),于是:

\[F(u)=\dfrac{F(f)-F(fail_f)}{p_t}+F(fail_u) \]

于是就能 \(O(n)\) 求出 \(F(u)\) 了。

另一个瓶颈是在列 \(son(f)>1\) 的方程时:

\[F(f)=1+\sum\limits_{c=0}^{k-1}F(g_{f,c})p_c \]

这里也需要枚举所有的 \(c\) 并每次 \(O(n)\) 累加。

我们把 \(f\) 在 Trie 树上的儿子的 \(c\) 的集合记为 \(T\),于是:

\[F(f)=1+\sum\limits_{c\in T}F(g_{f,c})p_c+\sum_{c\not \in T}F(g_{f,c})p_c \]

显然对于 \(c\in T\) 的情况,由于 \(F(g_{f,c})\) 是关键变量,\(x_1,x_2,\cdots,x_l\) 中只有一项的系数非零,所以很好处理。

对于 \(c\not\in T\) 的情况:若 \(f=rt\),则 \(g_{f,c}=f\);若 \(f\neq rt\),我们沿用刚刚的做法:

\[\begin{aligned} F(f)&=1+\sum\limits_{c\in T}F(g_{f,c})p_c+\sum_{c\not \in T}F(g_{f,c})p_c\\ &=1+\sum_{c\in T}F(g_{f,c})p_c+\sum_cF(g_{fail_f,c})p_c-\sum_{c\in T}F(g_{fail_f,c})p_c\\ &=1+\sum_{c\in T}F(g_{f,c})p_c+\bigg(F(fail_f)-1\bigg)-\sum_{c\in T}F(g_{fail_f,c})p_c\\ &=\sum_{c\in T}F(g_{f,c})p_c+F(fail_f)-\sum_{c\in T}F(g_{fail_f,c})p_c\\ &=\sum_{c\in T}F(g_{f,c})p_c+F(fail_f)-\sum_{c\in T}F(fail_{g_{f,c}})p_c\\ \end{aligned} \]

然后枚举 \(\sum_{c\in T}\) 的次数就是关键变量的个数,为 \(O(n)\),所以复杂度就对了。

总时间复杂度 \(O(nm\log k+n^2m+n^3+|R|)\)

#include<bits/stdc++.h>
 
#define N 110
#define NM 100010
 
using namespace std;
 
namespace modular
{
    const int mod=1000000007;
    inline int add(int x,int y){return x+y>=mod?x+y-mod:x+y;}
    inline int dec(int x,int y){return x-y<0?x-y+mod:x-y;}
    inline int mul(int x,int y){return 1ll*x*y%mod;}
    inline void Add(int &x,int y){x=x+y>=mod?x+y-mod:x+y;}
    inline void Dec(int &x,int y){x=x-y<0?x-y+mod:x-y;}
    inline void Mul(int &x,int y){x=1ll*x*y%mod;}
}using namespace modular;
 
inline int poww(int a,int b)
{
    int ans=1;
    while(b)
    {
        if(b&1) ans=mul(ans,a);
        a=mul(a,a);
        b>>=1;
    }
    return ans;
}
 
inline int read()
{
    int x=0,f=1;
    char ch=getchar();
    while(ch<'0'||ch>'9')
    {
        if(ch=='-') f=-1;
        ch=getchar();
    }
    while(ch>='0'&&ch<='9')
    {
        x=(x<<1)+(x<<3)+(ch^'0');
        ch=getchar();
    }
    return x*f;
}
 
int n,m,k,R,p[NM],invp[NM];
int node,fa[NM],fail[NM];
int nn,id[NM];
int x[NM],dp[NM];
 
struct Vector
{
    int a[N<<1];
    void clear(){for(int i=1;i<=nn+1;i++)a[i]=0;}
}F[NM];
 
Vector operator + (Vector a,Vector b)
{
    for(int i=1;i<=nn+1;i++) Add(a.a[i],b.a[i]);
    return a;
}
 
Vector operator - (Vector a,Vector b)
{
    for(int i=1;i<=nn+1;i++) Dec(a.a[i],b.a[i]);
    return a;
}
 
Vector operator * (int a,Vector b)
{
    for(int i=1;i<=nn+1;i++) Mul(b.a[i],a);
    return b;
}
 
map<int,int>ch[NM];
 
void insert(int id)
{
    int u=0;
    for(int i=1;i<=m;i++)
    {
        int v=read();
        if(!ch[u][v]) fa[ch[u][v]=++node]=u;
        u=ch[u][v];
    }
}
 
void getfail()
{
    queue<int>q;
    for(auto it:ch[0]) q.push(it.second);
    while(!q.empty())
    {
        int u=q.front();
        q.pop();
        for(auto it:ch[u])
        {
            int c=it.first,v=it.second;
            int now=fail[u];
            while(now&&ch[now].find(c)==ch[now].end()) now=fail[now];
            if(ch[now].find(c)!=ch[now].end()) fail[v]=ch[now][c];
            q.push(v);
        }
    }
}
 
void findkey()
{
    id[0]=++nn;
    for(int i=1;i<=node;i++)
    {
        int f=fa[i];
        if((int)ch[f].size()!=1)
            id[i]=++nn;
    }
}
 
void DP()
{
    queue<int>q;
    for(auto it:ch[0]) q.push(it.second);
    F[0].a[id[0]]=1;
    while(!q.empty())
    {
        int u=q.front();
        q.pop();
        for(auto it:ch[u]) q.push(it.second);
        int f=fa[u];
        if((int)ch[f].size()!=1)
        {
            F[u].a[id[u]]=1;
            continue;
        }
        int t=(*ch[f].begin()).first;
        if(f) F[u]=(invp[t]*(F[f]-F[fail[f]]))+F[fail[u]];
        else F[u]=F[f],Dec(F[u].a[nn+1],invp[t]);
    }
}
 
int equ,a[N<<1][N<<1];
 
void getEqu()
{
    Vector now;
    if((int)ch[0].size()!=1)
    {
        equ++;
        now.clear();
        Add(now.a[nn+1],1);
        int sump=0;
        for(int i=0;i<k;i++) Add(sump,p[i]);
        for(auto it:ch[0])
        {
            int c=it.first,v=it.second;
            sump=dec(sump,p[c]);
            Add(now.a[id[v]],p[c]);
        }
        Add(now.a[id[0]],dec(sump,1));
        now.a[nn+1]=dec(0,now.a[nn+1]);
        for(int i=1;i<=nn+1;i++) a[equ][i]=now.a[i];
    }
    for(int i=1;i<=node;i++)
    {
        if((int)ch[i].size()!=1)
        {
            equ++;
            if((int)ch[i].size()>1)
            {
                now=F[fail[i]]-F[i];
                for(auto it:ch[i])
                {
                    int c=it.first,v=it.second;
                    now=now+p[c]*(F[v]-F[fail[v]]);
                }
            }
            else now=F[i];
            now.a[nn+1]=dec(0,now.a[nn+1]);
            for(int i=1;i<=nn+1;i++) a[equ][i]=now.a[i];
        }
    }
    assert(equ==nn);
}
 
void Gauss()
{
    for(int i=1;i<=nn;i++)
    {
        int p=i;
        for(int j=i+1;j<=nn;j++)
            if(a[j][i]) p=j;
        if(i!=p) swap(a[i],a[p]);
        int inv=poww(a[i][i],mod-2);
        for(int j=i+1;j<=nn;j++)
        {
            int div=mul(a[j][i],inv);
            for(int k=i;k<=nn+1;k++)
                Dec(a[j][k],mul(a[i][k],div));
        }
    }
    for(int i=nn;i>=1;i--)
    {
        x[i]=a[i][nn+1];
        for(int j=nn;j>i;j--)
            Dec(x[i],mul(a[i][j],x[j]));
        Mul(x[i],poww(a[i][i],mod-2));
    }
}
 
void calc()
{
    for(int i=0;i<=node;i++)
    {
        for(int j=1;j<=nn;j++)
            Add(dp[i],mul(F[i].a[j],x[j]));
        Add(dp[i],F[i].a[nn+1]);
    }
}
 
int main()
{
//  freopen("string2.in","r",stdin);
//  freopen("string2_my.out","w",stdout);
    n=read(),m=read(),k=read();
    for(int i=0;i<k;i++)
        p[i]=read(),invp[i]=poww(p[i],mod-2);
    for(int i=1;i<=n;i++) insert(i);
    getfail();
    findkey();
    DP();
    getEqu();
    Gauss();
    calc();
    R=read();
    int u=0;
    bool flag=0;
    for(int i=1;i<=R;i++)
    {
        int c=read();
        if(flag)
        {
            printf("%d\n",i);
            continue;
        }
        while(u&&ch[u].find(c)==ch[u].end()) u=fail[u];
        if(ch[u].find(c)!=ch[u].end()) u=ch[u][c];
        if(!(int)ch[u].size()) flag=1;
        printf("%d\n",add(i,dp[u]));
    }
    return 0;
}
posted @ 2022-10-31 07:23  ez_lcw  阅读(43)  评论(0编辑  收藏  举报