字符串哈希 学习笔记

字符串哈希指的是用一个函数将一个字符串映射为一个整数。

OI 中常用多项式哈希,即选定基数 \(base\) 和模数 \(mod\),则 \(f(s)=\sum_{i=1}^n s_ibase^{n-i}\pmod{m}\)。可以把这种哈希理解为一个 \(base\) 进制整数对 \(mod\) 取模。当两个串的哈希值相等,说明这两个串大概率相等。

哈希的强大作用是快速判断字符串的子串是否相等。这需要快速求一个字符串的子串的哈希值。

求出字符串的每一个前缀的哈希值,有 \(f(pre_i)=f(pre_{i-1})\cdot base+s_i\)

假设求 \(f(s[l,r])=\sum_{i=l}^{r}s_ibase^{r-i}\),而 \(f(pre_r)=\sum_{i=1}^{r}s_ibase^{r-i},f(pre_{l-1})=\sum_{i=1}^{l-1}s_ibase^{l-i-1}\)。也就是说,\(f(s[l,r])=f(pre_r)-f(pre_{l-1})\cdot base^{r-l+1}\)。预处理 \(base\) 的幂即可。

有了这个功能,哈希可以干很多事。比如下面的代码是字符串匹配:

#include<bits/stdc++.h>
using namespace std;
const int base=131,mod=1000000021;
int n,m,ans,p[1000005],sum[1000005],a;
char s[1000005],t[1000005];
int gethash(int l,int r){
  return (sum[r]-1ll*sum[l-1]*p[r-l+1]%mod+mod)%mod;
} 
int main(){
  p[0]=1,cin>>s+1>>t+1,n=strlen(s+1),m=strlen(t+1);
  for(int i=1;i<=n;i++)p[i]=1ll*p[i-1]*base%mod,sum[i]=(1ll*sum[i-1]*base%mod+s[i])%mod;
  for(int i=1;i<=m;i++)a=(1ll*a*base%mod+t[i])%mod;
  for(int i=1;i<=n-m+1;i++)if(gethash(i,i+m-1)==a)ans++;
  return cout<<ans<<'\n',0;
}

为了方便,常常使用 unsigned long long 存储哈希值,相当于对 \(2^{64}\) 取模,省去了取模的过程。

#include<bits/stdc++.h>
using namespace std;
const unsigned long long base=131;
int n,m,ans;
unsigned long long p[1000005],sum[1000005],a;
char s[1000005],t[1000005];
unsigned long long gethash(int l,int r){
  return sum[r]-sum[l-1]*p[r-l+1];
} 
int main(){
  p[0]=1,cin>>s+1>>t+1,n=strlen(s+1),m=strlen(t+1);
  for(int i=1;i<=n;i++)p[i]=p[i-1]*base,sum[i]=sum[i-1]*base+s[i];\
  for(int i=1;i<=m;i++)a=a*base+t[i];
  for(int i=1;i<=n-m+1;i++)if(gethash(i,i+m-1)==a)ans++;
  return cout<<ans<<'\n',0;
}

如果有 \(n\) 个不同的串,根据生日悖论,当 \(n\) 的级别为 \(O(\sqrt{mod})\) 时,哈希碰撞的概率已经很大。因此,可以选取两组不同的基数或模数分别哈希,判断相等时需要两个哈希值相等,这就将值域扩大到模数的积,错误率可以忽略。

#include<bits/stdc++.h>
using namespace std;
const int base=131,mod1=1000000021,mod2=1000000033;
int n,m,ans,p1[1000005],p2[1000005],sum1[1000005],sum2[1000005],a1,a2;
char s[1000005],t[1000005];
int gethash1(int l,int r){
  return (sum1[r]-1ll*sum1[l-1]*p1[r-l+1]%mod1+mod1)%mod1;
}
int gethash2(int l,int r){
  return (sum2[r]-1ll*sum2[l-1]*p2[r-l+1]%mod2+mod2)%mod2;
}
int main(){
  p1[0]=p2[0]=1,cin>>s+1>>t+1,n=strlen(s+1),m=strlen(t+1);
  for(int i=1;i<=n;i++){
    p1[i]=1ll*p1[i-1]*base%mod1,p2[i]=1ll*p2[i-1]*base%mod2;
    sum1[i]=(1ll*sum1[i-1]*base%mod1+s[i])%mod1,sum2[i]=(1ll*sum2[i-1]*base%mod2+s[i])%mod2;
  }
  for(int i=1;i<=m;i++)a1=(1ll*a1*base%mod1+t[i])%mod1,a2=(1ll*a2*base%mod2+t[i])%mod2;
  for(int i=1;i<=n-m+1;i++)if(gethash1(i,i+m-1)==a1&&gethash2(i,i+m-1)==a2)ans++;
  return cout<<ans<<'\n',0;
}

[[字符串]]

posted @ 2024-03-01 09:38  lgh_2009  阅读(3)  评论(0编辑  收藏  举报