AC自动机

AC自动机

AC 自动机可以理解为在多个串上的 KMP,利用 Trie 树来维护这些串,nxt 数组变为 fail 指针。

fail 指针的构造思想如下:

考虑 Trie 树中当前的节点 \(u\)\(u\) 的父节点是 \(p\)\(p\) 通过字符 \(c\) 的边指向 \(u\),即 \(\text{trie}[p,c]=u\)。假设深度小于 \(u\) 的所有节点的 \(\text{fail}\) 指针都已求得。

1.如果 \(\text{trie}[\text{fail}[p],c]\) 存在,则让 \(u\) 的 fail 指针指向 trie[fail[p], c]。相当于在 \(p\)\(\text{fail}[p]\) 后面加一个字符 \(c\),分别对应 \(u\)\(\text{fail}[u]\)

2.如果 trie[fail[p], c] 不存在,那么我们继续找到 trie[fail[fail[p]], c],重复 1 的判断过程,一直跳 fail 指针直到根节点。

3.如果真的没有,就让 fail 指针指向根节点。

可以发现,AC 自动机与 KMP 是非常相似的。

这是基本思想,具体实现我们可以直接构造 Trie 图以进行多串匹配。

实现

普通实现

【模板】AC自动机(简单版)(luoguP3808)

给定 \(n\) 个模式串 \(s_i\) 和一个文本串 \(t\),求有多少个不同的模式串在文本串里出现过,两个模式串不同当且仅当它们编号不同。

\(1\le n\le10^6\)\(1\le|t|\le10^6\)\(1\le\sum_{i=1}^n|s_i|\le10^6\)


1 在 Trie 树上插入各串

void Insert(char *s) {
    int n = strlen(s);
    int pos = 0;
    for(int i = 0; i < n; i ++) {
        int tmp = s[i] - 'a';
        if(tr[pos][tmp] == 0)
            tr[pos][tmp] = ++ct;
        pos = tr[pos][tmp];
    }
    val[pos] ++;
}

2 fail指针的构建

按照上面的思路进行构建

void build() {
    queue<int> q;
    memset(fail, 0, sizeof(fail));
    for(int i = 0; i < 26; i ++)
        if(tr[0][i])	// 0 是根节点,后面 while(tmp && ...) 其实前者是 tmp != rt
            q.push(tr[0][i]);	// 从根节点的子节点开始,类似于 KMP 从 2 开始
    while(q.empty() == false) {
        int x = q.front();
        q.pop();
        for(int i = 0; i < 26; i ++)
            if(tr[x][i]) {	// 匹配现在的
                int tmp = fail[x];
                while(tmp && tr[tmp][i] == 0)	// 没匹配上,往回跳
                    tmp = fail[tmp];
                if(tr[tmp][i])	// fail 指针那边的也有 i 这个字符,匹配上了
                    tmp = tr[tmp][i];
                fail[tr[x][i]] = tmp;
                q.push(tr[x][i]);
            }
    }
}

3 多模式匹配

多个模式串对文本串的匹配。

如果能匹配,就匹配,不能匹配,就跳 fail,注意可能多次匹配,所以要去掉重复匹配的贡献。

int query(char *s) {
    int pos = 0, ret = 0;
    int n = strlen(s);
    for(int i = 0; i < n; i ++) {
        int tmp = s[i] - 'a';
        while(pos && tr[pos][tmp] == 0)
            pos = fail[pos];
        if(tr[pos][tmp])
            pos = tr[pos][tmp];
        ret += val[pos];
        val[pos] = 0;
    }
    return ret;
}

完整代码

#include<cstdio>
#include<cstring>
#include<queue>

using namespace std;

const int N = 1000010;

int n;
char s[N];
int ct, tr[N][26], val[N], fail[N];

void Insert(char *s) {
    int n = strlen(s);
    int pos = 0;
    for(int i = 0; i < n; i ++) {
        int tmp = s[i] - 'a';
        if(tr[pos][tmp] == 0)
            tr[pos][tmp] = ++ct;
        pos = tr[pos][tmp];
    }
    val[pos] ++;
}

void build() {
    queue<int> q;
    for(int i = 0; i < 26; i ++)
        if(tr[0][i]) {
            fail[tr[0][i]] = 0;
            q.push(tr[0][i]);
        }
    while(q.empty() == false) {
        int x = q.front();
        q.pop();
        for(int i = 0; i < 26; i ++)
            if(tr[x][i]) {
                int tmp = fail[x];
                while(tmp && tr[tmp][i] == 0)
                    tmp = fail[tmp];
                if(tr[tmp][i])
                    tmp = tr[tmp][i];
                fail[tr[x][i]] = tmp;
                q.push(tr[x][i]);
            }
    }
}

int query(char *s) {
    int pos = 0, ret = 0;
    int n = strlen(s);
    for(int i = 0; i < n; i ++) {
        int tmp = s[i] - 'a';
        while(pos && tr[pos][tmp] == 0)
            pos = fail[pos];
        if(tr[pos][tmp])
            pos = tr[pos][tmp];
        ret += val[pos];
        val[pos] = 0;
    }
    return ret;
}

int main() {
    scanf("%d", &n);
    for(int i = 1; i <= n; i ++) {
        scanf("%s", s+1);
        Insert(s+1);
    }
    build();
    scanf("%s", s+1);
    printf("%d\n", query(s+1));
    return 0;
}

Trie图的构建

建出真实的边,使 Trie 树变为一张图。

具体而言,只要把儿子记录上即可。

复杂度 \(O(|S|\times n)\)\(|S|\) 是字符集大小。

void build() {
    queue<int> q;
    for(int i = 0; i < 26; i ++)
        if(t[rt][i]) {
            q.push(t[rt][i]);
            fail[t[0][i]] = rt;
        }
    while(q.empty() == false) {
        int x = q.front();
        q.pop();
        for(int i = 0; i < 26; i ++)
            if(t[x][i]) {
                fail[t[x][i]] = t[fail[x]][i];
                q.push(t[x][i]);
            } else
                t[x][i] = t[fail[x]][i];
    }
}

解释:

对于 else 部分,相当于直接把这个不存在的儿子接到其失配指针的这个儿子,如果失配指针也没有这个儿子呢?那么它一定也通过 else 接上了可能的儿子,所以可以保证接上的恰好是一个最优的位置。

对于 if 部分,类似地,假如它的父亲(\(x\))有 \(i\) 这个儿子,相当于直接匹配,那么显然是对的,如果没有,等同于 else 部分的解释,它一定指向了一个最优的位置。

int query(char *s) {
    int pos = 0, ret = 0;
    int n = strlen(s);
    for(int i = 0; i < n; i ++) {
        int tmp = s[i] - 'a';
        pos = tr[pos][tmp];
        for(int j = pos; j > 0 && val[j] != -1; j = fail[j]) {  // 由于打标记后 j = fail[j] 仍然会进行下去,所以一直到根节点全都打了标记,所以如果再次走到这个位置可以直接退出
            ret += val[j];
            val[j] = -1;
        }
    }
    return ret;
}

完整代码

#include<cstdio>
#include<cstring>
#include<queue>

using namespace std;

const int N = 1000010;

int n;
char s[N];
int ct, tr[N][26], val[N], fail[N];

void Insert(char *s) {
    int n = strlen(s);
    int pos = 0;
    for(int i = 0; i < n; i ++) {
        int tmp = s[i] - 'a';
        if(tr[pos][tmp] == 0)
            tr[pos][tmp] = ++ct;
        pos = tr[pos][tmp];
    }
    val[pos] ++;
}

void build() {
    queue<int> q;
    for(int i = 0; i < 26; i ++)
        if(tr[0][i])
            q.push(tr[0][i]);
    while(q.empty() == false) {
        int x = q.front();
        q.pop();
        for(int i = 0; i < 26; i ++)
            if(tr[x][i]) {
                fail[tr[x][i]] = tr[fail[x]][i];
                q.push(tr[x][i]);
            } else
                tr[x][i] = tr[fail[x]][i];
    }
}

int query(char *s) {
    int pos = 0, ret = 0;
    int n = strlen(s);
    for(int i = 0; i < n; i ++) {
        int tmp = s[i] - 'a';
        pos = tr[pos][tmp];
        for(int j = pos; j > 0 && val[j] != -1; j = fail[j]) {  // 由于打标记后 j = fail[j] 仍然会进行下去,所以一直到根节点全都打了标记,所以如果再次走到这个位置可以直接退出
            ret += val[j];
            val[j] = -1;
        }
    }
    return ret;
}

int main() {
    scanf("%d", &n);
    for(int i = 1; i <= n; i ++) {
        scanf("%s", s+1);
        Insert(s+1);
    }
    build();
    scanf("%s", s+1);
    printf("%d\n", query(s+1));
    return 0;
}

另外一个模板

【模板】AC自动机(加强版)(luoguP3796)

\(N\) 个由小写字母组成的模式串以及一个文本串 \(T\),找出哪些模式串在文本串 \(T\) 中出现的次数最多。

\(1\le N\le 150\),单个模式串长度 \(\le70\),文本串长度 \(\le10^6\)


由于 Trie 数每个节点表示的字符串具有唯一性,并且 AC 自动机每到达一个位置就表示和当前位置表示的字符串匹配了,我们可以这么做:

记录每个模式串结尾位置,匹配时每到一个位置就累加这个位置的模式串出现次数。

#include<cstdio>
#include<cstring>
#include<queue>
#include<algorithm>

using namespace std;

const int N = 20010, M = 1000010, H = 80, P = 160;

int n;
char s[M];
char str[P][H];
int ct, t[N][26], fail[N], ed[N];
int tim[P]; // 各个串的匹配次数
int top, stk[P], maxtim;

void Insert(char *s, int rk) {
    int pos = 0;
    int len = strlen(s);
    for(int i = 0; i < len; i ++) {
        int tmp = s[i] - 'a';
        if(t[pos][tmp] == 0)
            t[pos][tmp] = ++ct;
        pos = t[pos][tmp];
    }
    ed[pos] = rk;
}

void build() {
    queue<int> q;
    for(int i = 0; i < 26; i ++)
        if(t[0][i])
            q.push(t[0][i]);
    while(q.empty() == false) {
        int x = q.front();
        q.pop();
        for(int i = 0; i < 26; i ++)
            if(t[x][i]) {
                fail[t[x][i]] = t[fail[x]][i];
                q.push(t[x][i]);
            } else
                t[x][i] = t[fail[x]][i];
    }
}

void query(char *s) {
    int pos = 0;
    int len = strlen(s);
    for(int i = 0; i < len; i ++) {
        int tmp = s[i] - 'a';
        pos = t[pos][tmp];
        for(int j = pos; j; j = fail[j])
            tim[ed[j]] ++;
    }
}

void solv() {
    for(int i = 1; i <= n; i ++) {
        scanf("%s", str[i]+1);
        Insert(str[i]+1, i);
    }
    build();
    scanf("%s", s+1);
    query(s+1);
    maxtim = 0;
    top = 0;
    for(int i = 1; i <= n; i ++)
        if(tim[i] > maxtim) {
            maxtim = tim[i];
            top = 0;
            stk[++top] = i;
        } else if(tim[i] == maxtim)
            stk[++top] = i;
    printf("%d\n", maxtim);
    for(int i = 1; i <= top; i ++)
        printf("%s\n", str[stk[i]] + 1);
}

void clea() {
    memset(ed, 0, sizeof(ed));
    memset(fail, 0, sizeof(fail));
    ct = 0;
    memset(t, 0, sizeof(t));
    memset(tim, 0, sizeof(tim));
}

int main() {
    while(scanf("%d", &n)) {
        if(n == 0)
            break;
        solv();
        clea();
    }
    return 0;
}

Fail树

Fail 指针构成了一个树形结构。

1.除了根节点都有 Fail 指针。

2.各个节点都能跳到根。

3.所以有 \(N\) 个点 \(N-1\) 条边,且连通,所以是树。

【模板】AC自动机(二次加强版)(luoguP5357)

给定一个文本串 \(S\)\(n\) 个模式串 \(T_{1\sim n}\),请你分别求出每个模式串 \(T_i\)\(S\) 中出现的次数。

不保证任意两个模式串不同。

\(1\le n\le2\times10^5\)\(\sum_{i=1}^n|T_i|\le2\times10^5\)\(|S|\le2\times10^6\)


观察一下代码片段:

for(int i = 0; i < len; i ++) {
    int tmp = s[i] - 'a';
    pos = t[pos][tmp];
    for(int j = pos; j; j = fail[j])
        tim[ed[j]] ++;
}

可以发现一个问题,就是 j 循环的部分,跳 fail[j] 实际上复杂度是不确定的,深度可能很深,复杂度会退化。

但是我们发现其实每次转移是类似的,可以建立出 fail 树然后通过遍历一遍 fail 树将所有转移都做了。

具体地,j 恰好使一个链都 ++,于是我们可以在底端打一个标记,然后构建 fail 树后在树上做 DP 向上传递并累加。

#include<cstdio>
#include<vector>
#include<queue>
#include<cstring>

using namespace std;

const int N = 200010, M = 2000010;

int n;
char s[M];
int ct, t[N][26], fail[N];
vector<int> ed[N];
int ans[N];
int ctb, hd[N], ver[N<<1], nxt[N<<1];
int f[N];   // 树形 DP

void add(int u, int v) {
    ver[++ctb] = v;
    nxt[ctb] = hd[u];
    hd[u] = ctb;
}

void Insert(char *str, int rk) {
    int pos = 0;
    int len = strlen(str);
    for(int i = 0; i < len; i ++) {
        int tmp = str[i] - 'a';
        if(t[pos][tmp] == 0)
            t[pos][tmp] = ++ct;
        pos = t[pos][tmp];
    }
    ed[pos].push_back(rk);
}

void build() {
    queue<int> q;
    for(int i = 0; i < 26; i ++)
        if(t[0][i])
            q.push(t[0][i]);
    while(q.empty() == false) {
        int x = q.front();
        q.pop();
        for(int i = 0; i < 26; i ++)
            if(t[x][i]) {
                fail[t[x][i]] = t[fail[x]][i];
                q.push(t[x][i]);
            } else
                t[x][i] = t[fail[x]][i];
    }
}

void query(char *str) {
    int pos = 0;
    int len = strlen(str);
    for(int i = 0; i < len; i ++) {
        int tmp = s[i] - 'a';
        pos = t[pos][tmp];
        f[pos] ++;
    }
}

void dfs(int x, int fa) {
    for(int i = hd[x]; i; i = nxt[i]) {
        int y = ver[i];
        if(y == fa)
            continue;
        dfs(y, x);
        f[x] += f[y];
    }
    int tmpsiz = ed[x].size();
    for(int i = 0; i < tmpsiz; i ++)
        ans[ed[x][i]] += f[x];
}

int main() {
    scanf("%d", &n);
    for(int i = 1; i <= n; i ++) {
        scanf("%s", s);
        Insert(s, i);
    }
    build();
    scanf("%s", s);
    query(s);
    for(int i = 1; i <= ct; i ++) {
        add(i, fail[i]);
        add(fail[i], i);
    }
    dfs(0, -1);
    for(int i = 1; i <= n; i ++)
        printf("%d\n", ans[i]);
    return 0;
}

例题

阿狸的打字机 (NOI2011) (luoguP2414)

有一个打字机,有 \(28\) 个按键,分别是 \(26\) 个小写英文字母和 BP 两个字母,打字机是这样工作的:

  • 输入小写字母,会把这个字母加在凹槽最后。
  • 按下 B,凹槽中的最后一个字母会消失
  • 按下 P,会打印出凹槽中的字母,并换行,但凹槽中的字母不会消失

给定一个字符串,表示按键情况。

然后把打印出来的字符串按照 \(1\sim n\) 编号,有 \(m\) 组询问,每次给定一个 \((x,y)\) 表示询问第 \(x\) 个打印的字符串在第 \(y\) 个打印的字符串中出现了多少次。

\(1\le n\le10^5\)\(1\le m\le10^5\),表示按键情况的字符串长度 \(\le10^5\)


最暴力的做法自然是跑 \(m\) 遍 KMP,复杂度 \(O(nm)\),是不能接受的。

本题的特点就在于它是由打字机打出来的,由于按键次数至多 \(10^5\) 次,所以打出来的字符串相差必然都比较小。

如果放到 Trie 树上,必然不超过 \(10^5\) 个节点。

我们想到构建 AC 自动机。

类似于 KMP,如果一个 fail 指针是从 \(x\) 指向 \(y\) 的,那么字符串 \(y\) 在字符串 \(x\) 中出现了一次。

对于一个字符串,从根到叶子,一路上每个节点都跳 fail 对应出来的若干个字符串全都在这个字符串内,并且不会漏,当然注意还要加上恰好就在这个串里的,即一路过来的这些子串。

于是我们可以在 Trie 树上 DFS,用一个全局桶,到每个节点跳 fail 把该加的都加进来。

然后光荣地 TLE 了。

跳 fail 复杂度是不对的,但是我们已经找到了策略,如何保证复杂度呢?

然后这时发现了一个问题,就是构建 AC 自动机太慢了,如果模拟的话,可能会插入许多许多次相同但很长的字符串,导致爆炸,于是我们需要特化本题的插入函数。

具体地,仍然是模拟,但是不从头插了,而是在 Trie 图上走,就可以保证复杂度。

然后我们就可以继续思考如何解决不可跳 fail 的问题。

分析可以发现,我们复杂度瓶颈在于一路上把各种串都插到桶里了,但是我们要查询的却比较少,于是我们考虑一对查询 \((x,y)\) 的关系。

按照上面的想法,就是 Trie 树上从根到 \(y\) 一路上的点全插进来,fail 指针对应在 fail 树上的一条链全插进来,然后判断指向 \(x\) 的个数,我们反向思考可以发现,只有 fail 树上 \(x\) 的子树内的才能指向它,而且发现是一条链上全有贡献,所以子树内全都有贡献,于是我们每次不要跳 fail 到根,而是仅打一个标记,然后查 \(x\) 只需要查 \(x\) 子树内的标记数目。

然后这个东西可以数据结构优化,子树可以通过 DFS 序化为区间,然后就是单点修改区间查询。

#include<cstdio>
#include<cstring>
#include<algorithm>
#include<queue>

using namespace std;

const int N = 100010;
typedef pair<int, int> pii;

int m, ct, tr[N][26], tmptr[N][26], fail[N];
char s[N];
int ctb, top, stk[N];
vector<int> strk[N];    // 各个点对应的串的编号们
vector<pii> qry[N];    // 各个点的询问,第二关键字为询问序号
int p[N];   // 第 i 个串在 Trie 中的位置
int cnt[N]; // 全局桶,记录各个串出现次数,为了处理恰好就在这个串里的情况
int ans[N];
int ctc, hd[N], ver[N<<1], nxt[N<<1];
int ctd, dfn[N], siz[N];
int t[N];   // BIT

void addedge(int u, int v) {
    ver[++ctc] = v;
    nxt[ctc] = hd[u];
    hd[u] = ctc;
}

void Insert(char *str) {
    int pos = 0;
    int len = strlen(str);
    for(int i = 0; i < len; i ++) {
        if(str[i] >= 'a' && str[i] <= 'z') {
            int tmp = str[i] - 'a';
            if(tr[pos][tmp] == 0)
                tr[pos][tmp] = ++ct;
            stk[++top] = tr[pos][tmp];
            pos = tr[pos][tmp];
        } else if(str[i] == 'B') {
            top --;
            pos = stk[top];
        } else if(str[i] == 'P') {
            ++ ctb;
            strk[pos].push_back(ctb);
            p[ctb] = pos;
        }
    }
}

void build() {
    memcpy(tmptr, tr, sizeof(tr));
    queue<int> q;
    for(int i = 0; i < 26; i ++)
        if(tmptr[0][i]) {
            q.push(tmptr[0][i]);
            fail[tmptr[0][i]] = 0;
        }
    while(q.empty() == false) {
        int x = q.front();
        q.pop();
        for(int i = 0; i < 26; i ++)
            if(tmptr[x][i]) {
                fail[tmptr[x][i]] = tmptr[fail[x]][i];
                q.push(tmptr[x][i]);
            } else
                tmptr[x][i] = tmptr[fail[x]][i];    // 这种方法会影响 Trie 树结构,所以我们需要新建一个副本来构建 fail 指针
    }
}

void modify(int pos, int dlt) {
    for(; pos <= ct+1; pos += pos & (-pos)) // 注意有 ct+1 个点!
        t[pos] += dlt;
}

int subquery(int pos) {
    int ret = 0;
    for(; pos; pos -= pos & (-pos))
        ret += t[pos];
    return ret;
}

int query(int l, int r) {
    return subquery(r) - subquery(l-1);
}

void dfsa(int x, int fa) {  // 给 fail 树标 dfn 标记并求出辅助的一些量
    siz[x] = 1;
    dfn[x] = ++ctd;
    for(int i = hd[x]; i; i = nxt[i]) {
        int y = ver[i];
        if(y == fa)
            continue;
        dfsa(y, x);
        siz[x] += siz[y];
    }
}

void dfsb(int x) {
    int tmpsiz = strk[x].size();
    for(int i = 0; i < tmpsiz; i ++)
    	cnt[strk[x][i]] ++; // 把当前串插入桶
    modify(dfn[fail[x]], 1);    // 打标记
    tmpsiz = qry[x].size();
    for(int i = 0; i < tmpsiz; i ++) {
        pii ttmp = qry[x][i];
        ans[ttmp.second] = cnt[ttmp.first] + query(dfn[p[ttmp.first]], dfn[p[ttmp.first]] + siz[p[ttmp.first]] - 1);    // 得到答案
    }
    for(int i = 0; i < 26; i ++)
        if(tr[x][i])
            dfsb(tr[x][i]); // dfs
    tmpsiz = strk[x].size();   // 回溯时撤回
    for(int i = 0; i < tmpsiz; i ++)
        cnt[strk[x][i]] --;
    modify(dfn[fail[x]], -1);
}

int main() {
    scanf("%s", s+1);
    Insert(s+1);
    build();
    scanf("%d", &m);
    for(int i = 1, tx, ty; i <= m; i ++) {
        scanf("%d%d", &tx, &ty);
        qry[p[ty]].push_back(make_pair(tx, i));
    }
    for(int i = 1; i <= ct; i ++) {
        addedge(i, fail[i]);
        addedge(fail[i], i);
    }
    dfsa(0, -1);
    dfsb(0);
    for(int i = 1; i <= m; i ++)
        printf("%d\n", ans[i]);
    return 0;
}


魔法咒语 (BJOI2017) (luoguP3715)

给定 \(n\) 个基本词汇,\(m\) 个忌讳词语,它们都是字符串。求满足下列条件的字符串的个数:

1.长度等于 \(L\)

2.可以被分割为若干个基本词汇。

3.不存在任何一个禁忌词语在字符串中出现过。

这里字符串不同的条件比较特殊:

把基本词汇标号 \(1\sim n\),把字符串分割为若干个基本词汇,设 \(i\) 号基本词汇的出现次数为 \(c_i\),两个字符串不同当且仅当存在一个 \(c_i\) 不相等。

而书写形式相同也可能是两个不同的字符串。

答案对 \(10^9+7\) 取模。

\(1\le n,m\le50\)\(1\le L\le10^8\),当 \(L>100\) 时,保证基本词汇长度不超过 \(2\)\(M\le20\)

基本词汇长度之和、忌讳词语的长度之和不超过 \(100\),基本词汇不重复,禁忌词汇不重复。


由于字符串中不能出现忌讳词语,而判断是否出现忌讳词语就相当于匹配这个忌讳词语,容易想到要把所有忌讳词语插入 AC 自动机。

然后构建 fail 树并把所有不能到达的点做标记,然后在 AC 自动机上 DP,每次考虑加入一个基本词汇,然后枚举可以转移到的状态,为了使复杂度更优秀,我们可以 \(O(n^3)\) 暴力预处理每个点插入每个字符串会到达的点,这样把我们 DP 的复杂度从 \(O(L\times n^3)\) 优化到了 \(O(L\times n^2)\)

这样,我们就解决了第一部分,本题另外一部分是 \(L\le10^8\),但保证基本词汇长度不超过 \(2\),这明显提示我们写矩乘。

这里已经不是这题在 AC 自动机方面作为例题的作用所在了,但还是说一下如何写吧。

主要还是建矩阵嘛。

设我们前面预处理那个数组是 \(p[i][j]\),表示在点 \(i\) 添加串 \(j\) 到达的位置。转移都是形如下面这样的:

\[f[i][j]\rarr f[p[i][k]][j+l_k] \]

除了填完填了一位时不能填长度为 \(2\) 的串以外,其他时候,转移都是随便填,这样就是个常系数的递推。

设初始列向量为 \(f[0\sim ct][1],f[0\sim ct][0]\),然后矩阵的前 \(ct+1\) 行就是看是否有转移,后 \(ct+1\) 行是继承旧的前半部分。

当然,看是否有转移其实麻烦了,因为那样就还需要一个逆映射,没有必要,我们有 \(p\) 这个映射就够了,我们仍然枚举所有转移,然后寻找转移在矩阵上的位置就行了。

#include<cstdio>
#include<cstring>
#include<queue>

using namespace std;

const int N = 100, md = 1000000007;
typedef long long ll;

struct Matrix {
    int n, m;
    ll a[N<<1][N<<1];
    void id() {
        for(int i = 1; i <= n; i ++)
            a[i][i] = 1ll;
    }
    void init(int _n, int _m) {
        n = _n;
        m = _m;
        memset(a, 0, sizeof(a));
    }
    Matrix operator * (Matrix B) {
        Matrix res;
        res.init(n, B.m);
        for(int i = 1; i <= n; i ++)
            for(int j = 1; j <= B.m; j ++)
                for(int k = 1; k <= m; k ++)
                    res.a[i][j] = (res.a[i][j] + a[i][k] * B.a[k][j]) % md;
        return res;
    }
    Matrix qpow(int b) {
        Matrix res, mat = *this;
        res.init(n, n);
        res.id();
        for(; b; b >>= 1, mat = mat * mat)
            if(b & 1)
                res = res * mat;
        return res;
    }
};

int n, m, l;
char s[N][N], tmps[N];
int len[N];
int ct, tr[N][26], fail[N];
bool tag[N];    // 忌讳词语标记
int ctb, hd[N], ver[N<<1], nxt[N<<1];
int p[N][N];   // p[i][j] 表示在点 i,添加第 j 个串会到达哪里
int f[N][N]; // 在点 i,填了 j 个字符的方案数
Matrix st, trans;  // 初始列向量,状态转移矩阵

void addedge(int u, int v) {
    ver[++ctb] = v;
    nxt[ctb] = hd[u];
    hd[u] = ctb;
}

void Insert(char *s) {
    int pos = 0;
    int len = strlen(s);
    for(int i = 0; i < len; i ++) {
        int tmp = s[i] - 'a';
        if(tr[pos][tmp] == 0)
            tr[pos][tmp] = ++ct;
        pos = tr[pos][tmp];
    }
    tag[pos] = true;
}

void buildfail() {
    queue<int> q;
    for(int i = 0; i < 26; i ++)
        if(tr[0][i])
            q.push(tr[0][i]);
    while(q.empty() == false) {
        int x = q.front();
        q.pop();
        for(int i = 0; i < 26; i ++)
            if(tr[x][i]) {
                fail[tr[x][i]] = tr[fail[x]][i];
                q.push(tr[x][i]);
            } else
                tr[x][i] = tr[fail[x]][i];
    }
}

void buildfailtree() {
    for(int i = 1; i <= ct; i ++) {
        addedge(i, fail[i]);
        addedge(fail[i], i);
    }
}

void dfs(int x, int fa, bool flg) {
    flg |= tag[x];
    tag[x] = flg;
    for(int i = hd[x]; i; i = nxt[i]) {
        int y = ver[i];
        if(y == fa)
            continue;
        dfs(y, x, flg);
    }
}

void getp() {
    for(int i = 0; i <= ct; i ++) { // 预处理 p[i][j]
        if(tag[i])
            continue;
        for(int j = 1; j <= n; j ++) {
            int pos = i;
            for(int k = 1; k <= len[j]; k ++) {
                int tmp = s[j][k] - 'a';
                pos = tr[pos][tmp];
                if(tag[pos]) {
                    p[i][j] = -1;
                    break;
                }
            }
            if(p[i][j] != -1)
                p[i][j] = pos;
        }
    }
}

void solv1() {
    int ans = 0;
    f[0][0] = 1;
    for(int i = 0; i <= l; i ++)
        for(int j = 1; j <= n; j ++) {
            if(i + len[j] > l)
                continue;
            for(int k = 0; k <= ct; k ++) {
                if(tag[k] || p[k][j] == -1)
                    continue;
                f[p[k][j]][i + len[j]] = (f[p[k][j]][i + len[j]] + f[k][i]) % md;
            }
        }
    for(int i = 0; i <= ct; i ++)
        ans = (ans + f[i][l]) % md;
    printf("%d\n", ans);
}

void solv2() {
    int ans = 0;
    f[0][0] = 1;
    for(int i = 1; i <= n; i ++) {
        if(len[i] > 1)
            continue;
        for(int j = 0; j <= ct; j ++) {
            if(tag[j] || p[j][i] == -1)
                continue;
            f[p[j][i]][1] = (f[p[j][i]][1] + f[j][0]) % md;
        }
    }
    st.init(2*(ct+1), 1);
    for(int i = 1; i <= ct+1; i ++) {   // 初始列向量
        st.a[i][1] = f[i-1][1];
        st.a[i+(ct+1)][1] = f[i-1][0];
    }
    trans.init(2*(ct+1), 2*(ct+1));
    for(int i = ct+2; i <= 2*ct+2; i ++)    // 下边直接继承的部分
        trans.a[i][i-(ct+1)] = 1;
    for(int i = 1; i <= n; i ++)  // 上边的部分
        for(int j = 0; j <= ct; j ++) {
            if(tag[j] || p[j][i] == -1)
                continue;
            if(len[i] == 1)
                trans.a[p[j][i] + 1][j+1] ++;
            else
                trans.a[p[j][i] + 1][(j+1) + (ct+1)] ++;
        }
    trans = trans.qpow(l-1);
    st = trans * st;
    for(int i = 1; i <= ct+1; i ++)
        ans = (ans + st.a[i][1]) % md;
    printf("%d\n", ans);
}

int main() {
    scanf("%d%d%d", &n, &m, &l);
    for(int i = 1; i <= n; i ++) {
        scanf("%s", s[i]+1);
        len[i] = strlen(s[i]+1);
    }
    for(int i = 1; i <= m; i ++) {
        scanf("%s", tmps);
        Insert(tmps);
    }
    buildfail();
    buildfailtree();
    dfs(0, -1, false);
    getp();
    if(l <= 100)
        solv1();
    else
        solv2();
    return 0;
}

posted @ 2022-03-07 14:21  RevolutionBP  阅读(30)  评论(0编辑  收藏  举报