牛客多校第二场

A - All with Pairs (AC自动机+kmp)

题意

定义\(f(s, t)\)为字符串s的前缀和t的后缀的最长公共长度。给定n个字符串\(s_1,s_2,...,s_n\)\(\sum_{i=1}^{n}{\sum_{j=1}^{n}{f(s_i, s_j)^2}}(\mod 998244353)\)

思路

AC自动机的trie树的从根节点开始的路径都是某个单词的前缀,fail指针跳的是当前前缀的后缀。所以可以预处理出每个前缀的贡献(记录在trie树上)。

然后匹配一个由所有单词连接而成的文本串,用一个特殊字符隔开。每次匹配到特殊字符,说明已经到达当前单词的结尾。然后跳fail累加前缀的贡献即可。

由于跳fail会导致重复累加贡献,所以一开始需要kmp预处理每个单词的前缀贡献的差分。

#include <bits/stdc++.h>
#define endl '\n'
#define IOS std::ios::sync_with_stdio(0); cin.tie(0); cout.tie(0)
#define FILE freopen(".//data_generator//in.txt","r",stdin),freopen("res.txt","w",stdout)
#define FI freopen(".//data_generator//in.txt","r",stdin)
#define FO freopen("res.txt","w",stdout)
#define pb push_back
#define mp make_pair
#define seteps(N) fixed << setprecision(N) 
typedef long long ll;

using namespace std;
/*-----------------------------------------------------------------*/

ll gcd(ll a, ll b) {return b ? gcd(b, a % b) : a;}
#define INF 0x3f3f3f3f

const int N = 1e6 + 10;
const int M = 998244353;
int tr[N][30];
int fail[N * 30]; 
int si = 0;
string s;
string tex;
ll add[N];
ll cont[N * 30];
int nt[N];


void getnext(const char t[]) {
    int i = 0, j = -1;
    nt[i] = j;
    int n = 0;
    while(t[i]) {
        n++;
        if(j == -1 || t[i]  == t[j]) {
            i++;
            j++;
            nt[i] = j;
        } else {
            j = nt[j];
        }
    }

    for(int i = n; i >= 1; i--) {
        add[i] = ((1ll * i * i - 1ll * nt[i] * nt[i]) % M + M) % M;
    }
}

namespace AC {
    void init() {
        memset(tr[0], 0,sizeof tr[0]); 
        memset(cont, 0,sizeof cont);
        si = 0;
    }

    void insert(const char s[]) { 
        int cur = 0;
        fail[cur] = 0;
        for(int i = 0; s[i]; i++) {
            if(tr[cur][s[i] - 'a']) cur = tr[cur][s[i] - 'a'];
            else {
                tr[cur][s[i] - 'a'] = ++si;
                cur = si;
                memset(tr[si], 0, sizeof tr[si]);
                fail[cur] = 0;
                cont[cur] = 0;
            }
            cont[cur] += add[i + 1]; 
            cont[cur] %= M;
        }
    }

    void build() {
        queue<int> q;
        for(int i = 0; i < 27; i++) 
            if(tr[0][i]) q.push(tr[0][i]);

        while(!q.empty()) {
            int cur = q.front();
            q.pop();
            for(int i = 0; i < 26; i++) {
                if(tr[cur][i]) {
                    fail[tr[cur][i]] = tr[fail[cur]][i]; 
                    q.push(tr[cur][i]);
                } else {
                    tr[cur][i] = tr[fail[cur]][i];
                }
            }
        } 
    }

    ll query(const char s[]) { //返回有多少模式串在s中出现过
        ll ans = 0;
        int cur = 0;
        for(int i = 0; s[i]; i++) {
            cur = tr[cur][s[i] - 'a'];
            if(s[i + 1] == '{')
                for(int j = cur; j; j = fail[j]) {
                    ans += cont[j];
                    ans %= M; 
                }
        }
        return ans;
    }
}

int main() {
    IOS;
    int n;
    cin >> n;
    for(int i = 1; i <= n; i++) {
        cin >> s;
        getnext(s.c_str());
        AC::insert(s.c_str());
        tex += s + "{";
    }
    AC::build();
    ll ans = AC::query(tex.c_str());
    cout << (ans % M + M) % M << endl;
}
posted @ 2020-08-04 21:37  limil  阅读(71)  评论(0编辑  收藏  举报