Loading

「学习笔记」字符串哈希

定义

我们定义一个把字符串映射到整数的函数 \(f\),这个 \(f\) 称为是 Hash 函数.

我们希望这个函数 \(f\) 可以方便地帮我们判断两个字符串是否相等.

核心思想

Hash 的核心思想在于, 将输入映射到一个值域较小、可以方便比较的范围.

值域需要小到能够接受线性的空间与时间复杂度.

在字符串哈希中,值域需要小到能够快速比较 (\(10^9\)\(10^{18}\) 都是可以快速比较的).

同时, 为了降低哈希冲突率, 值域也不能太小.

性质

具体来说, 哈希函数最重要的性质可以概括为下面两条:

  1. 在 Hash 函数值不一样的时候, 两个字符串一定不一样;

  2. 在 Hash 函数值一样的时候, 两个字符串不一定一样 (但有大概率一样,且我们当然希望它们总是一样的).

我们将 Hash 函数值一样但原字符串不一样的现象称为哈希碰撞.

实现

通常我们采用的是多项式 Hash 的方法, 对于一个长度为 \(l\) 的字符串 \(s\) 来说, 我们可以这样定义多项式 Hash 函数: \(f(s) = \sum_{i=1}^{l} s_i \times b^{l-i} \pmod M\).

例如, 对于字符串 xyz, 其哈希函数值为 \(xb^2+yb+z\).

for (int i = 1; i <= l; ++ i) {
	cin >> s[i];
	Hash[i] = Hash[i - 1] * based + (s[i] - 'a' + 1);
	p[i] = p[i - 1] * based;
	sump[i] = sump[i - 1] + p[i];
}

快速计算子串的哈希值

\(f_i(s)\) 表示 \(f(s[1 \sim i])\), 即原串长度为 \(i\) 的前缀的哈希值, 那么按照定义有 \(f_i(s)=s_1 \cdot b^{i-1} + s_2 \cdot b^{i-2}+ \dots + s_{i-1} \cdot b + s_i\).

现在, 我们想要用类似前缀和的方式快速求出 \(f(s[l \sim r])\), 按照定义有字符串 s[l ~ r] 的哈希值为 \(f(s[l \sim r])=s_l \cdot b^{r-l} + s_{l+1} \cdot b^{r-l-1} + \dots + s_{r-1} \cdot b + s_r\).

对比观察上述两个式子, 我们发现 \(f(s[l \sim r])=f_r(s)-f_{l-1}(s) \times b^{r-l+1}\) 成立 (可以手动代入验证一下), 因此我们用这个式子就可以快速得到子串的哈希值. 其中 \(b^{r-l+1}\) 可以 \(O_n\) 的预处理出来然后 \(O_1\) 的回答每次询问 (当然也可以快速幂 \(O_{\log n}\) 的回答每次询问).

例题

P9453 [ZSHOI-R1] 有效打击 - 洛谷 | 计算机科学教育新生态 (luogu.com.cn)

这道题的正解是 KMP,结果被暴力的字符串哈希整过去了 (?)

/*
  The code was written by yifan, and yifan is neutral!!!
 */

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef unsigned long long ull;

template<typename T>
inline T read() {
	T x = 0;
	bool fg = 0;
	char ch = getchar();
	while (ch < '0' || ch > '9') {
		fg |= (ch == '-');
		ch = getchar();
	}
	while (ch >= '0' && ch <= '9') {
		x = (x << 3) + (x << 1) + (ch ^ 48);
		ch = getchar();
	}
	return fg ? ~x + 1 : x;
}

const int N = 5e6 + 5;
const int based = 29;

int n, m, cnt;
int a[N], b[N], cur[N], len[N];
ll num[N];
ull Hasha[N], Hashb[N], p[N], sump[N];

int gcd(ll x, ll y) {
	if (y == 0) {
		return x;
	}
	return gcd(y, x % y);
}

ull gethash(int l, int r) {
	return Hasha[r] - Hasha[l - 1] * p[r - l + 1];
}

void solve() {
	int g = len[1], Len = 0, cot = 0, mx = 0;
	for (int i = 1; i <= n; ++ i) {
		if (a[i] != a[i - 1]) {
			mx = max(mx, cot);
			cot = 0;
		}
		++ cot;
	}
	mx = max(mx, cot);
	for (int i = 2; i <= cnt; ++ i) {
		g = gcd(g, len[i]);
	}
	for (int i = 1; i <= cnt; ++ i) {
		len[i] /= g;
		Len += len[i];
	}
	ll ans = 0;
	for (int k = 1; k * Len <= n && k <= mx + 1; ++ k) {
		ull Hash = 0;
		for (int i = 1; i <= cnt; ++ i) {
			Hash = Hash * p[k * len[i]] + cur[i] * sump[k * len[i] - 1];
		}
		for (int l = 1; l + k * Len - 1 <= n; ++ l) {
			int r = l + k * Len - 1;
			if (gethash(l, r) == Hash) {
				++ ans;
			}
		}
	}
	printf("%lld\n", ans);
}

int main() {
	n = read<int>(), m = read<int>();
	p[0] = 1, sump[0] = 1;
	for (int i = 1; i <= n; ++ i) {
		a[i] = read<int>();
		num[i] = num[i - 1] + i;
		Hasha[i] = Hasha[i - 1] * based + a[i];
		p[i] = p[i - 1] * based;
		sump[i] = sump[i - 1] + p[i];
	}
	for (int i = 1; i <= m; ++ i) {
		b[i] = read<int>();
		if (b[i] != b[i - 1]) {
			cur[++ cnt] = b[i];
		}
		++ len[cnt];
	}
	if (cnt == 1) {
		int len = 0;
		ll ans = 0;
		for (int i = 1; i <= n; ++ i) {
			if (a[i] != a[i - 1]) {
				if (a[i - 1] == b[1]) {
					ans += num[len];
				}
				len = 1;
			}
			else {
				++ len;
			}
		}
		if (len && a[n] == b[1]) {
			ans += num[len];
		}
		printf("%lld\n", ans);
		return 0;
	}
	solve();
	return 0;
}
posted @ 2023-01-09 09:12  yi_fan0305  阅读(66)  评论(0编辑  收藏  举报