「学习笔记」字符串哈希
定义
我们定义一个把字符串映射到整数的函数 \(f\),这个 \(f\) 称为是 Hash 函数.
我们希望这个函数 \(f\) 可以方便地帮我们判断两个字符串是否相等.
核心思想
Hash 的核心思想在于, 将输入映射到一个值域较小、可以方便比较的范围.
值域需要小到能够接受线性的空间与时间复杂度.
在字符串哈希中,值域需要小到能够快速比较 (\(10^9\)、\(10^{18}\) 都是可以快速比较的).
同时, 为了降低哈希冲突率, 值域也不能太小.
性质
具体来说, 哈希函数最重要的性质可以概括为下面两条:
-
在 Hash 函数值不一样的时候, 两个字符串一定不一样;
-
在 Hash 函数值一样的时候, 两个字符串不一定一样 (但有大概率一样,且我们当然希望它们总是一样的).
我们将 Hash 函数值一样但原字符串不一样的现象称为哈希碰撞.
实现
通常我们采用的是多项式 Hash 的方法, 对于一个长度为 \(l\) 的字符串 \(s\) 来说, 我们可以这样定义多项式 Hash 函数: \(f(s) = \sum_{i=1}^{l} s_i \times b^{l-i} \pmod M\).
例如, 对于字符串 xyz
, 其哈希函数值为 \(xb^2+yb+z\).
for (int i = 1; i <= l; ++ i) {
cin >> s[i];
Hash[i] = Hash[i - 1] * based + (s[i] - 'a' + 1);
p[i] = p[i - 1] * based;
sump[i] = sump[i - 1] + p[i];
}
快速计算子串的哈希值
令 \(f_i(s)\) 表示 \(f(s[1 \sim i])\), 即原串长度为 \(i\) 的前缀的哈希值, 那么按照定义有 \(f_i(s)=s_1 \cdot b^{i-1} + s_2 \cdot b^{i-2}+ \dots + s_{i-1} \cdot b + s_i\).
现在, 我们想要用类似前缀和的方式快速求出 \(f(s[l \sim r])\), 按照定义有字符串 s[l ~ r]
的哈希值为 \(f(s[l \sim r])=s_l \cdot b^{r-l} + s_{l+1} \cdot b^{r-l-1} + \dots + s_{r-1} \cdot b + s_r\).
对比观察上述两个式子, 我们发现 \(f(s[l \sim r])=f_r(s)-f_{l-1}(s) \times b^{r-l+1}\) 成立 (可以手动代入验证一下), 因此我们用这个式子就可以快速得到子串的哈希值. 其中 \(b^{r-l+1}\) 可以 \(O_n\) 的预处理出来然后 \(O_1\) 的回答每次询问 (当然也可以快速幂 \(O_{\log n}\) 的回答每次询问).
例题
P9453 [ZSHOI-R1] 有效打击 - 洛谷 | 计算机科学教育新生态 (luogu.com.cn)
这道题的正解是 KMP,结果被暴力的字符串哈希整过去了 (?)
/*
The code was written by yifan, and yifan is neutral!!!
*/
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
template<typename T>
inline T read() {
T x = 0;
bool fg = 0;
char ch = getchar();
while (ch < '0' || ch > '9') {
fg |= (ch == '-');
ch = getchar();
}
while (ch >= '0' && ch <= '9') {
x = (x << 3) + (x << 1) + (ch ^ 48);
ch = getchar();
}
return fg ? ~x + 1 : x;
}
const int N = 5e6 + 5;
const int based = 29;
int n, m, cnt;
int a[N], b[N], cur[N], len[N];
ll num[N];
ull Hasha[N], Hashb[N], p[N], sump[N];
int gcd(ll x, ll y) {
if (y == 0) {
return x;
}
return gcd(y, x % y);
}
ull gethash(int l, int r) {
return Hasha[r] - Hasha[l - 1] * p[r - l + 1];
}
void solve() {
int g = len[1], Len = 0, cot = 0, mx = 0;
for (int i = 1; i <= n; ++ i) {
if (a[i] != a[i - 1]) {
mx = max(mx, cot);
cot = 0;
}
++ cot;
}
mx = max(mx, cot);
for (int i = 2; i <= cnt; ++ i) {
g = gcd(g, len[i]);
}
for (int i = 1; i <= cnt; ++ i) {
len[i] /= g;
Len += len[i];
}
ll ans = 0;
for (int k = 1; k * Len <= n && k <= mx + 1; ++ k) {
ull Hash = 0;
for (int i = 1; i <= cnt; ++ i) {
Hash = Hash * p[k * len[i]] + cur[i] * sump[k * len[i] - 1];
}
for (int l = 1; l + k * Len - 1 <= n; ++ l) {
int r = l + k * Len - 1;
if (gethash(l, r) == Hash) {
++ ans;
}
}
}
printf("%lld\n", ans);
}
int main() {
n = read<int>(), m = read<int>();
p[0] = 1, sump[0] = 1;
for (int i = 1; i <= n; ++ i) {
a[i] = read<int>();
num[i] = num[i - 1] + i;
Hasha[i] = Hasha[i - 1] * based + a[i];
p[i] = p[i - 1] * based;
sump[i] = sump[i - 1] + p[i];
}
for (int i = 1; i <= m; ++ i) {
b[i] = read<int>();
if (b[i] != b[i - 1]) {
cur[++ cnt] = b[i];
}
++ len[cnt];
}
if (cnt == 1) {
int len = 0;
ll ans = 0;
for (int i = 1; i <= n; ++ i) {
if (a[i] != a[i - 1]) {
if (a[i - 1] == b[1]) {
ans += num[len];
}
len = 1;
}
else {
++ len;
}
}
if (len && a[n] == b[1]) {
ans += num[len];
}
printf("%lld\n", ans);
return 0;
}
solve();
return 0;
}