cf 1202E. You Are Given Some Strings... ac自动机

传送门

就是多个子串,1个主串,然后需要把任意两个子串进行连接,求连接后的字符串在主串中出现的次数和。

可以想到,枚举主串的每一个字符,那么统计子串中以当前字符为结尾的子串个数,同时统计以这个字符后面的一个字符为开始的子串个数,两个相乘就是当前字符的贡献值。

可以发现,两个任务其实是一样的,只要把主串,子串都进行翻转一次,那么利用统计以字符为结尾,就能统计出以字符为开始的个数了。

那么开两个ac自动机,进行统计以字符为结尾的子串个数。

#include <bits/stdc++.h>
#define ll long long
using namespace std;
const int N = 2e5 + 5, M = 2e6 + 5;
template<typename T = long long> inline T read() {
    T s = 0, f = 1; char ch = getchar();
    while(!isdigit(ch)) {if(ch == '-') f = -1; ch = getchar();}
    while(isdigit(ch)) {s = (s << 3) + (s << 1) + ch - 48; ch = getchar();} 
    return s * f;
}
struct AC{
    int nex[N][26], fail[N], siz[N], end[N], tot, n;
    void init(int nn){
        tot = 0;
    }
    void insert(char *s, int nu){
        int now = 0;
        for (int j = 0; s[j]; ++j){
            int c = s[j] - 'a';
            if (!nex[now][c]) nex[now][c] = ++tot;
            now = nex[now][c];
        }
        end[now]++;
    }
    void getfail(){
        queue<int> q;
        for(int i = 0; i < 26; i++) {
            if(nex[0][i]) fail[nex[0][i]] = 0, q.push(nex[0][i]);
        }
        while (!q.empty()){
            int u = q.front();
            q.pop();
            for (int i = 0; i < 26; ++i){
                if (nex[u][i]){
                    fail[nex[u][i]] = nex[fail[u]][i];
                    end[nex[u][i]] += end[fail[nex[u][i]]]; // 关键点
                    q.push(nex[u][i]);
                }
                else nex[u][i] = nex[fail[u]][i];
            }
        }
    }
    void query(char *s){
        int now = 0;
        int id = 0;
        for(int i = 0; s[i]; i++) {
            now = nex[now][s[i] - 'a'];
            siz[i] = end[now];
        }
    }
} ac1, ac2;
char s[N], t[N];
int main(){
    scanf("%s", s);
    int n = read();
    ac1.init(n); ac2.init(n);
    for(int i = 1; i <= n; i++) {
        scanf("%s", t);
        ac1.insert(t, i);
        reverse(t, t + strlen(t));
        ac2.insert(t, i);
    }
    int len = strlen(s);
    ac1.getfail(), ac2.getfail();
    ac1.query(s); reverse(s, s + len);
    ac2.query(s); 
    ll ans = 0;
    for(int i = 0; i < len - 1; i++) {
        ans += 1ll * ac1.siz[i] * ac2.siz[len - i - 2];
    }
    printf("%lld\n", ans);
    return 0;
}
posted @ 2021-01-24 19:43  Emcikem  阅读(74)  评论(0编辑  收藏  举报