牛客第二场_A_All with Pairs

题意

给定n个字符串,求所有字符串前缀与后缀相等的个数与前后缀的长度的平方的和。如样例,匹配长度为1,2,3的分别有4,4,1个,所以答案为
4 *1 ^2 +4 *2 ^2 +1 *3 ^2=29

思路

可以枚举每一个字符的前缀看该前缀能和那些后缀组成f(t, s),然后在减去枚举时多出的部分即可。构建AC自动机,利用fail数组构建成
一个树(fail的每条边反向就可以构成一颗树),利用这棵树来对cnt数组进行累加。每次遍历一个字符串的前缀时构建一个KMP的ne数组,
答案累加cnt[p]*len*len(cnt[p]为该前缀能与多少后缀匹配,len是该前缀的长度)在减去cnt[p]*ne[j]*ne[j](表示ne[j]这个前缀和
自己的后缀(j结尾)匹配成功的情况)

代码

#pragma GCC optimize(2)
#include<unordered_map>
#include<algorithm>
#include<iostream>
#include<cstring>
#include<cstdio>
#include<string>
#include<vector>
#include<queue>
#include<stack>
#include<cmath>
#include<map>
#include<set>
#define Buff ios::sync_with_stdio(false)
#define rush() int Case = 0; int T; cin >> T;  while(T--)
#define rep(i, a, b) for(int i = a; i <= b; i ++)
#define per(i, a, b) for(int i = a; i >= b; i --)
#define reps(i, a, b) for(int i = a; b; i ++)
#define clc(a, b) memset(a, b, sizeof(a))
#define Buff ios::sync_with_stdio(false)
#define readl(a) scanf("%lld", &a)
#define readd(a) scanf("%lf", &a)
#define readc(a) scanf("%c", &a)
#define reads(a) scanf("%s", a)
#define read(a) scanf("%d", &a)
#define lowbit(n) (n&(-n))
#define pb push_back
#define lson rt<<1
#define rson rt<<1|1
#define ls lson, l, mid
#define rs rson, mid+1, r
#define y second
#define x first
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int, int>PII;
const int mod = 998244353;
const double eps = 1e-6;
const int N = 1e6+1;
const int M = 1e5+7;
string s[M];
ll tr[N][26], ne[N], q[N], idx;
ll cnt[N], nxt[N], id[N];
vector<int> G[N];
void insert(int x)
{
    int p = 0;
    reps(i, 1, s[x][i])
    {
        int t = s[x][i] - 'a';
        if(!tr[p][t])   tr[p][t] = ++idx;
        p = tr[p][t];
    }
    cnt[p] ++;
    id[x] = p;
}
void build()
{
    int hh = 0, tt = -1;
    rep(i, 0, 25)   if(tr[0][i])    q[++ tt] = tr[0][i];
    while(hh <= tt)
    {
        int t = q[hh ++];
        rep(i, 0, 25)
        {
            ll& p = tr[t][i];
            // printf("tr[%d][%d]: %lld ----- %lld\n", t, i, tr[t][i], tr[ne[t]][i]);
            if(!p)  p = tr[ne[t]][i];
            else
            {
                ne[p] = tr[ne[t]][i];
                q[++ tt] = p;    
            }
        }
    }
}
void dfs(int x)
{
    sort(G[x].begin(), G[x].end());
    G[x].erase(unique(G[x].begin(), G[x].end()), G[x].end());
    for(int t:G[x])
    {
        dfs(t);
        cnt[x] += cnt[t];
    }
}
void kmp(int x)
{
    for(int i = 2, j = 0; s[x][i]; i ++)
    {
        while(j && s[x][i] != s[x][j+1])    j = nxt[j];
        if(s[x][i] == s[x][j+1]) j ++;
        nxt[i] = j;
    }
}
int main()
{
    int n;
    cin >> n;
    rep(i, 1, n)
    {
        cin >> s[i];    s[i] = " " + s[i];
        insert(i);
    }
    build();
    for(int i = 1; i <= n; i ++)
    {
        int p = id[i];
        while(p)
        {
            G[ne[p]].push_back(p);
            p = ne[p];
        }
    }
    dfs(0);
    ll res = 0;
    // for(int i = 1; i <= idx; i ++)
    // cout << cnt[i] <<"\n";
    for(int i = 1; i <= n; i ++)
    {
        kmp(i);
        int p = 0;
        for(int j = 1; s[i][j]; j ++)
        {
            p = tr[p][s[i][j]-'a'];
            // cout <<"i, j: " << i <<" "<< j <<" "<< (cnt[p] * j * j) % mod - (cnt[p] * nxt[j] * nxt[j]) <<endl;
            res = (res + (cnt[p] * j * j)%mod - (cnt[p]*nxt[j]*nxt[j])%mod)%mod;
            if(res < 0) res += mod;
        }
    }
    cout << res <<endl;
    return 0;
}

posted @ 2020-11-25 14:18  youngman-f  阅读(54)  评论(0编辑  收藏  举报