BZOJ 4084 【Sdoi2015】双旋转字符串

题意:
有两个字符串集合\(S\)\(T\)\(S\)中所有字符串长度为\(n\)\(T\)中所有字符串长度为\(m\),保证\(n + m\)为偶数。
询问有多少对\((i, j)\)使得\(S_i\)\(T_j\)拼接之后得到的字符串满足双旋转性。
双旋转性:
将字符串分成等长的两部分,这两部分循环同构。

思路:
考虑\(Rabin-Karp\)

\[\begin{eqnarray*} H(s) = \sum\limits_{i = 0}^{|S| - 1} base^i \cdot s[i] \end{eqnarray*} \]

注意到\(n \geq m\),那么拼接之后必然在\(S_i\)中的末尾\(n - \frac{n + m}{2}\)的部分是固定拼接在\(T_j\)前面,并成为分割后的后半部分。
那么直接把\(T_j\)\(Hash\)值丢进\(map\),然后枚举\(S_i\)\(\frac{n + m}{2}\)部分的循环同构,处理一下\(S_i\)后半部分与\(T_j\)的拼接,然后计算答案即可。

如何枚举\(S_i\)\(\frac{n + m}{2}\)部分的循环同构?
考虑一个字符串\(s_1s_2s_3 \cdots s_n\),我们要得到\(s_ns_1s_2\cdots s_{n - 1}\)\(Hash\)值。
那么就是先将\(H(s)\)减去\(Base^{i - 1} \cdot s_n\),再乘上\(Base\),再加上\(s_n\)即可。

代码:

#include <bits/stdc++.h>
using namespace std;
 
#define ll long long
#define ull unsigned long long
const ull basep = 31;
int n, m, lens, lent;
vector <string> s, t;
vector <ull> Ht;
map <ull, int> used, has;    
ull qmod(ull base, int n) {
    ull res = 1;
    while (n) {
        if (n & 1) {
            res = res * base;
        }
        base = base * base;
        n >>= 1;
    }
    return res;
}
void Hash(vector <string> &s, vector <ull> &H, int sze, int len, ull base) {
    H.clear(); H.resize(sze); 
    for (int i = 0; i < sze; ++i) {
        H[i] = 0;
        ull tmp = base; 
        for (int j = 0; j < len; ++j) {
            H[i] = H[i] + tmp * s[i][j];
            tmp *= basep;
        }
    }
}
 
int main() {
    ios::sync_with_stdio(false);
    cin.tie(0); cout.tie(0);
    while (cin >> n >> m >> lens >> lent) {
        s.clear(); s.resize(n);
        t.clear(); t.resize(m); 
        Ht.clear(); Ht.resize(m);
        for (int i = 0; i < n; ++i) cin >> s[i];
        for (int i = 0; i < m; ++i) cin >> t[i];
        Hash(t, Ht, m, lent, qmod(basep, lens - (lens + lent) / 2 + 1));
        has.clear(); for (int i = 0; i < m; ++i) ++has[Ht[i]];
        int need = (lens + lent) / 2;
        ll res = 0;
        ull D = qmod(basep, need);
        for (int i = 0; i < n; ++i) {
            used.clear(); 
            ull L = 0, R = 0, base; 
            base = basep;
            for (int j = need; j < lens; ++j) {
                R = R + base * s[i][j];
                base *= basep;
            }
            base = basep;
            for (int j = 0; j < need; ++j) {
                L = L + base * s[i][j];
                base *= basep;    
            }
            for (int j = need - 1; j >= 0; --j) { 
                if (used.find(L) == used.end() && has.find(L - R) != has.end()) {
                    used[L] = 1;
                    res += has[L - R];  
                }
                L -= D * s[i][j];
                L *= basep;
                L += basep * s[i][j]; 
            }
        }
        cout << res << "\n";
    }
    return 0;
}
posted @ 2019-07-09 19:16  Dup4  阅读(116)  评论(0编辑  收藏  举报