[字符串专题] KMP、Hash、Trie
KMP
核心思想:在每次失配时,不是把 p
串往后移一位,而是把 p
串往后移动至下一次可以和前面部分匹配的位置,这样就可以跳过大多数的失配步骤。而每次 p
串移动的步数就是通过查找 next
数组确定的。
KMP主要分两步:求 next
数组、匹配字符串,其难点在于如何求 next
数组
for(int i = 1, j = 0; i <= n; i++)
{
while(j > 0 && s[i] != p[j + 1]) j = ne[j];
//如果j有对应p串的元素, 且s[i] != p[j + 1], 则失配, 移动p串
//用while是由于移动后可能仍然失配,所以要继续移动直到匹配或整个p串移到后面(j = 0)
if(s[i] == p[j + 1]) j++;
//当前元素匹配,j移向p串下一位
if(j == m)
{
//匹配成功,进行相关操作
j = next[j]; //继续匹配下一个子串
}
}
剩下的就简单了,我们以洛谷模板题 P3375 【模板】KMP 为例
#include <bits/stdc++.h>
#define rint register int
#define endl '\n'
using namespace std;
const int N = 1e6 + 5;
char p[N], s[N];
int ne[N];
int pos[N], idx = 0;
int main()
{
cin >> s + 1 >> p + 1;
int n = strlen(s + 1);
int m = strlen(p + 1);
for (rint i = 2, j = 0; i <= m; i++)
{
while (j > 0 && p[i] != p[j + 1]) j = ne[j];
if (p[i] == p[j + 1]) j++;
ne[i] = j;
}
for (rint i = 1, j = 0; i <= n; i++)
{
while (j > 0 && s[i] != p[j + 1]) j = ne[j];
if (s[i] == p[j + 1]) j++;
if (j == m)
{
pos[++idx] = i - m + 1;
j = ne[j];
}
}
for (rint i = 1; i <= idx; i++)
{
cout << pos[i] << endl;
}
for (rint i = 1; i <= m; i++)
{
cout << ne[i] << " ";
}
return 0;
}
Hash
Hash
,一般翻译做散列、杂凑,或音译为哈希,是把任意长度的输入通过散列算法变换成固定长度的输出,该输出就是散列值。这种转换是一种压缩映射,也就是,散列值的空间通常远小于输入的空间,不同的输入可能会散列成相同的输出,所以不可能从散列值来确定唯一的输入值。简单的说就是一种将任意长度的消息压缩到某一固定长度的消息摘要的函数。
Hash
一般有两种实现方法,一种是直接用 map
,一种是手写
map
实现
map
是个很常见的 STL,常见操作就是映射 string
为一个 int
数组。
我们以 [JLOI2011] 不重复数字
对于这道水题目,其实有很多做法,但是我们想一想 map
怎么做。
我们开一个 unordered_map<int, bool> v
,然后对于每次输入的 x
,看看是否在 v
中存在,不存在就标记上,然后输出即可,对于每组数据要清空一次 map
#include <bits/stdc++.h>
#define rint register int
#define endl '\n'
using namespace std;
int T, n;
unordered_map<int, bool> v;
int main()
{
cin >> T;
while (T--)
{
cin >> n;
v.clear();
for (rint i = 1; i <= n; i++)
{
int x;
cin >> x;
if (!v[x])
{
cout << x << " ";
v[x] = 1;
}
}
cout << endl;
}
return 0;
}
我们再来看 P1381 单词背诵
这个题我们可以直接开两个 map
map<string, int> sum
存储这个单词在文章中有多少个map<string, bool> v
存储这个单词是否出现
#include <bits/stdc++.h>
#define rint register int
#define endl '\n'
using namespace std;
const int N = 1e5 + 5;
map<string, int> sum;
map<string, bool> v;
int ans1, ans2;
int n, m;
int hh = 1;
string s[N];
int main()
{
cin >> n;
for (rint i = 1; i <= n; i++)
{
string a;
cin >> a;
v[a] = 1;
}
cin >> m;
for (rint i = 1; i <= m; i++)
{
cin >> s[i];
if (v[s[i]])
{
sum[s[i]]++;
}
if (sum[s[i]] == 1)
{
ans1++;
ans2 = i - hh + 1;//i 可以看成当前段的尾巴
}
while (hh <= i)
{
if (!v[s[hh]])
{
hh++;
continue;
}
if (sum[s[hh]] >= 2)
{
sum[s[hh]]--;
hh++;
continue;
}
break;
}
ans2 = std::min(ans2, i - hh + 1);
}
cout << ans1 << endl << ans2 << endl;
return 0;
}
当然,map
也不是万能的,在某些数据中,map
会被卡,一般会炸空间。
以 [CSP-S 2023] 消消乐 为例
通过 map
,我们可以快速实现这个题:
#include<bits/stdc++.h>
#define rint register int
#define int long long
#define x first
#define y second
using namespace std;
map<string, int> m;
string s, k;
int n, ans;
signed main()
{
cin >> n >> s;
m[""]++;
for(auto i = s.begin(); i != s.end(); i++)
{
auto tail = k.end();
if(tail != k.begin())
{
tail--;
if((*tail) == *i)
{
k.erase(tail);
m[k]++;
}
else
{
k += *i;
m[k]++;
}
}
else
{
k += *i;
m[k]++;
}
}
for(auto i = m.begin(); i != m.end(); i++)
{
ans += (*i).y*((*i).y - 1) / 2;
}
cout << ans << endl;
return 0;
}
然后呢?洛谷民间数据 80pts
or 90pts
,小图灵民间数据 65pts
所以我们就要用到正经的 hash
来解决了,但是很麻烦,这里建议用字典树,第三个算法会给出此题正解。
正常 Hash
实现
我们以 P3370 【模板】字符串哈希 为例
我们已经知道,这个题可以直接 map
查询当前字符串是否存在过就可以了。我们知道, map
有炸空间的坏处,所以如何使用正经 hash
呢?
#include <bits/stdc++.h>
#define rint register int
#define endl '\n'
using namespace std;
const int base = 131;
const long long mod = 1145141919ll;
//base 一般取 131,mod 随便取
const int N = 1e6 + 5;
int n;
unsigned long long a[N];
char s[N];
unsigned long long Hash(char s[])
//使一个字符串有它自己的值
{
int len = strlen(s + 1);
unsigned long long ans = 0;
for (rint i = 1; i <= len; i++)
{
ans = (ans * base + (unsigned long long)(s[i])) % mod;
}
return ans;
}
int main()
{
cin >> n;
for (rint i = 1; i <= n; i++)
{
scanf("%s", s + 1);
a[i] = Hash(s);
}
sort(a + 1, a + n + 1);
cout << unique(a + 1, a + n + 1) - (a + 1) << endl;
return 0;
}
我们再来看一道 [CTSC2014] 企鹅 QQ
题目大意为给定 \(n\) 个字符串 问两个字符串只差一个字符的字符串对的数量
相比于普通 hash,此题要做以下两点:
- 1.预处理每一位之间的关系,方便 check
- 2.预处理出所有 hash 值,降低单次查询计算时空复杂度
#include <bits/stdc++.h>
#define rint register int
#define endl '\n'
using namespace std;
const int N = 3e4 + 5;
const int M = 2e2 + 5;
const int base = 131;
int n, l, s;
unsigned long long h[N][M],t[N];
unsigned long long q[M];
int main()
{
cin >> n >> l >> s;
for (rint i = 1; i <= n; i++)
{
for (rint j = 1; j <= l; j++)
{
char c;
cin >> c;
h[i][j] = h[i][j - 1] * base + c;
}
}
q[0] = 1;
for (rint i = 1; i <= l; i++)
//预处理出每一位
{
q[i] = q[i - 1] * base;
}
int ans = 0;
for (rint i = 1; i <= l; i++)
{
for (rint j = 1; j <= n; j++)
{
t[j] = h[j][l] - (h[j][i] - h[j][i - 1] * base) * q[l - i] - h[j][i - 1] * (q[l - i + 1] - q[l - i]);
//去掉给第 j 个字符串第 i 位后的 hash 值, 是整个的 hash 值
//减掉这个位置的hash * Hina[l - i] , 再减掉前面的数的 hash*(Hina[l - i + 1] - Hina[l - i])
}
sort(t + 1,t + n + 1);
int idx = 1;
for (rint j = 1; j < n; j++)
{
if (t[j] != t[j + 1])
{
idx = 1;
}
else
{
ans += idx;
idx++;
}
}
}
cout << ans << endl;
return 0;
}
最后,我们看一下去年的 CSP-S T3,[CSP-S 2022] 星战
这是我初中最大的遗憾,想到了 hash 之后没敢写,选择了打部分分,最后还不如人家输出个 No 分高。
既然 hash 的思想是给一个字符串相应给出一个数值,为什么我们不能给一个点随机出一个数值呢?我们最后只需要判断它和原来一不一样不就好了??
#include <bits/stdc++.h>
#define rint register int
#define endl '\n'
#define int long long
const int N = 5e5 + 5;
int n, m;
int now[N];
int w[N], in[N];
int ans;
int cnt;
signed main()
{
srand(time(0));
scanf("%lld%lld", &n, &m);
for (rint i = 1; i <= n; i++)
{
w[i] = rand();
ans += w[i];
}
for (rint i = 1; i <= m; i++)
{
int u, v;
scanf("%lld%lld", &u, &v);
now[v] += w[u];
in[v] = now[v];
cnt += w[u];
}
int T;
scanf("%lld", &T);
while (T--)
{
int op, u;
scanf("%lld%lld", &op, &u);
if (op == 1)
{
int v;
scanf("%lld", &v);
now[v] -= w[u];
cnt -= w[u];
}
if (op == 2)
{
cnt -= now[u];
now[u] = 0;
}
if (op == 3)
{
int v;
scanf("%lld", &v);
now[v] += w[u];
cnt += w[u];
}
if (op == 4)
{
cnt += in[u] - now[u];
now[u] = in[u];
}
if (cnt == ans)
{
puts("YES");
}
else
{
puts("NO");
}
}
return 0;
}
Trie 字典树
之前因为懒一直没学,后来到了初三听 hs_black 讲的,笑死,根本听不懂.......后来自己扣了很久才整出来。
直接看模板题 AcWing 835. Trie字符串统计
#include <bits/stdc++.h>
#define rint register int
#define int long long
#define endl '\n'
using namespace std;
const int N = 1e5 + 5;
int son[N][26], cnt[N], idx;
char str[N];
/*
son[][]存储子节点的位置,分支最多26条;
cnt[]存储以某节点结尾的字符串个数(同时也起标记作用)
idx表示当前要插入的节点是第几个,每创建一个节点值+1
*/
void insert(char s[])
{
int p = 0;
for (int i = 0; s[i]; i++)
{
int u = s[i] - 'a';
if (!son[p][u]) son[p][u] = ++idx;
//该节点不存在,创建节点
p = son[p][u];
}
cnt[p]++;
//结束时的标记,也是记录以此节点结束的字符串个数
}
int query(char s[])
{
int p = 0;
for (int i = 0; s[i]; i++)
{
int u = str[i] - 'a';
if (!son[p][u]) return 0;
p = son[p][u];
}
return cnt[p];
}
signed main()
{
int T;
cin >> T;
while(T--)
{
char op;
cin >> op;
scanf("%s", str);
if (op == 'I')
{
insert(str);
}
if (op == 'I')
{
cout << query(str) << endl;
}
}
return 0;
}