All with Pairs

All with Pairs

题意:

给你n个字符串求, 字符串两两的最长前缀和后缀相等的长度 的平方和。

题解:

前置知识 :字符串hash, kmp的next数组

将字符串hash 将所有字符串的后缀hash值丢进map里。

然后枚举所有字符串的前缀的hash值。

查询 前缀hash值再mp中的个数求出答案。

用 next数组去重。

为啥要去重?

比如 aba

查出有 aba的后缀 , 枚举的时候是不是把 a 也算上了。

next数组怎么去重?

首先要理解next数组, next数组球的就是前缀和后缀相等的最大长度。

自己仔细想下就明白了。

map会超时所有我用 unorder_map, hash我用双hash。

代码:

#include<bits/stdc++.h>

using namespace std;

const int N = 1e6 + 7;
typedef long long ll;
string s[N];

const ll mod = 998244353;

class hash_str {
private:
    string str;
    const long long d = 238;
    const long long MOD = 1610612741;
    vector<long long>h, p;
    int n;
public:
    hash_str(string s) {
        this->str = s;
        this->n = str.length();
        h.resize(n + 1);
        p.resize(n + 1);
    }
    hash_str() {}
    void make_hash() {
        h[0] = str[0];
        p[0] = 1;
        for(int i = 1; i < n; i++) {
            h[i] = ((h[i - 1] * d ) % MOD + str[i] ) % MOD;
            p[i] = (p[i - 1] * d ) % MOD;
        }
    }
    long long  get_hash(int l, int r) {
        if(!l) return h[r];
        return (h[r] - (h[l - 1] % MOD * p[r - l + 1] % MOD) % MOD + MOD) % MOD;
    }

};

class hash_str1 {
private:
    string str;
    const long long d = 32;
    const long long MOD = 1e9 + 7;
    vector<long long>h, p;
    int n;
public:
    hash_str1(string s) {
        this->str = s;
        this->n = str.length();
        h.resize(n + 1);
        p.resize(n + 1);
    }
    hash_str1() {}
    void make_hash() {
        h[0] = str[0];
        p[0] = 1;
        for(int i = 1; i < n; i++) {
            h[i] = ((h[i - 1] * d ) % MOD + str[i] ) % MOD;
            p[i] = (p[i - 1] * d ) % MOD;
        }
    }
    long long  get_hash(int l, int r) {
        if(!l) return h[r];
        return (h[r] - (h[l - 1] % MOD * p[r - l + 1] % MOD) % MOD + MOD) % MOD;
    }

};

void get(string x, int next[]) {
    for (int i = 0; i <= x.length(); i++) next[i] = 0;
    int i = 0;
    int k = -1;
    int len = x.length();
    next[i] = -1;
    while (i < len) {
        if (k == -1 || x[i]  == x[k]) {
            next[++i] = ++k;
        } else {
            k = next[k];
        }
    }
}

struct pair_hash{
    template<class T1, class T2>
    std::size_t operator() (const std::pair<T1, T2>& p) const{
        auto h1 = std::hash<T1>{}(p.first);
        auto h2 = std::hash<T2>{}(p.second);
        return h1 ^ h2;
    }
};


unordered_map<pair<ll, ll>, ll, pair_hash>mp;
int n, nxt[N];
int main() {
    ios::sync_with_stdio(0);
    cin.tie(0), cout.tie(0);
    cin >> n;
    for (int i = 1; i <= n; i++) {
        cin >> s[i];
    }
    for (int i = 1; i <= n; i++) {
        hash_str str(s[i]);
        str.make_hash();
        hash_str1 str1(s[i]);
        str1.make_hash();
        for (int j = 0; j < s[i].length(); j++) {
            mp[{str.get_hash(j, s[i].length() - 1), str1.get_hash(j, s[i].length() - 1)}]++;
        }

    }
    ll ans = 0;
    for (int i = 1; i <= n; i++) {
        hash_str str(s[i]);
        hash_str1 str1(s[i]);
        str.make_hash();
        str1.make_hash();
        get(s[i], nxt);
       
        int len = s[i].length();
        for (int j = len - 1; j >= 0; j--) {
            ll cnt = mp[{str.get_hash(0, j), str1.get_hash(0, j)}];
          
            ans = (ans + cnt % mod * 1ll * (j + 1)% mod * 1ll* (j + 1) % mod) % mod;
            if (nxt[j + 1] <= 0) continue;
            ans = (ans - 1ll * (nxt[j + 1]) * (nxt[j + 1]) * cnt % mod + mod) % mod; //去重
        }

    }
    printf("%lld\n", ans);



}
posted @ 2020-07-17 01:14  ccsu_zhaobo  阅读(132)  评论(0编辑  收藏  举报