2020牛客暑期多校训练营(第二场) A All with Pairs

求第\(i\)个字符串的前缀最大能与所有的字符串的后缀匹配长度的平方和。
首先,字符串匹配相等的问题可以考虑字符串hash。这道题我们发现,hash所有的字符串需要\(\sum(len(s_{i}))\),是可以接受的。那么我们首先进行字符串的hash,统计每一个hash值有多少个。然后从头开始暴力所有的字符串,需要时间\(\sum(len(s_{i}))\),针对前缀第i个串第j个结尾的\(s[i][j]\),贡献应该为:\(num[hash_{j}]*(j+1)*(j+1)\)
但是,我们可以发现这样计算连样例都过不了。原因在于,例如\(aba\),我们计算他的前缀,有\(a\),\(aba\),这两个前缀分别计算了后缀\(a\)有多少个,\(aba\)有多少个。但这就带来的重复。因为后缀是\(aba\)的必然会对后缀\(a\)造成影响,因为\(a\)\(aba\)的一个后缀。所以我们需要对当前串\(s[i]\),进行kmp算法,算一个前缀第i个串第j个结尾的\(s[i][j]\),他和他的前缀的匹配重复结果即,\(next[j]\),剪掉重复计算的部分即可。

#include<iostream>
#include<cstring>
#include<algorithm>
#include<cmath>
#include<cstdlib>
#include<climits>
#include<stack>
#include<vector>
#include<queue>
#include<set>
#include<bitset>
#include<map>
//#include<regex>
#include<cstdio>
#include <iomanip>
#pragma GCC optimize(2)
#define up(i,a,b)  for(int i=a;i<b;i++)
#define dw(i,a,b)  for(int i=a;i>b;i--)
#define upd(i,a,b) for(int i=a;i<=b;i++)
#define dwd(i,a,b) for(int i=a;i>=b;i--)
//#define local
typedef long long ll;
typedef unsigned long long ull;
const double esp = 1e-6;
const double pi = acos(-1.0);
const int INF = 0x3f3f3f3f;
const int inf = 1e9;
using namespace std;
ll read()
{
    char ch = getchar(); ll x = 0, f = 1;
    while (ch<'0' || ch>'9') { if (ch == '-')f = -1; ch = getchar(); }
    while (ch >= '0' && ch <= '9') { x = x * 10 + ch - '0'; ch = getchar(); }
    return x * f;
}
typedef pair<int, int> pir;
#define lson l,mid,root<<1
#define rson mid+1,r,root<<1|1
#define lrt root<<1
#define rrt root<<1|1
const int N = 1e6 + 10;
const int base = 233;
const int mod = 998244353;
ull Hash[N];
ull poww[N];
int n;
string s[N];
map<ull, int>mp;
void init(int pos)
{
    int len = s[pos].length();
    Hash[len + 1] = 0;
    poww[len + 1] = 1;
    ull k = 1;
    dwd(i, len, 1)
    {
        Hash[i] = (s[pos][i - 1] - 'a' + 1)*k + Hash[i + 1];
        mp[Hash[i]]++;
        k *= base;
    }
    upd(i, 1, len)poww[i] = poww[i - 1] * base;
}
ull querry(int l, int r)
{
    return Hash[r] - Hash[l - 1] * (poww[r - l + 1]);
}
int nxt[N];
void getnext(int pos)
{
    nxt[0] = nxt[1] = 0;
    int len = s[pos].length();
    int j = 0;
    up(i, 1, len)
    {
        j = nxt[i];
        while (j&&s[pos][i] != s[pos][j])j = nxt[j];
        nxt[i + 1] = s[pos][i] == s[pos][j] ? j + 1 : 0;
    }
}
int main()
{
    n = read();
    upd(i, 1, n)
    {
        cin >> s[i];
        init(i);
    }
    ll ans = 0;
    upd(i, 1, n)
    {
        Hash[0] = 0; poww[0] = 1;
        int len = s[i].length();
        getnext(i);
        vector<ll>temp(len + 1);
        upd(j, 1, len)
        {
            Hash[j] = Hash[j - 1] * base + s[i][j - 1] - 'a' + 1;
            temp[j - 1] += mp[Hash[j]];
        }
        up(j, 0, len)
        {
            if(nxt[j+1])
                temp[nxt[j + 1] - 1] -= temp[j];
        }
        up(j, 0, len)
            ans = (ans + (temp[j] * (j + 1) % mod*(j + 1) % mod)) % mod;
    }
    cout << ans << endl;
    return 0;
}
posted @ 2020-07-15 16:11  LORDXX  阅读(129)  评论(0编辑  收藏  举报