题意简述
给定长度为 \(n\) 的字符串 \(S,T\) ,求有多少个不同的 \(T\) 的子串 \(t\) ,满足 \(t\) 是 \(S\) 的一个子序列。
\(1\le n\le 3000\)
算法分析
子串的个数是 \(\mathcal{O}(n^2)\) 的,子序列的个数是 \(\mathcal{O}(2^n)\) 的,因此考虑枚举所有子串,判断是否是 \(S\) 的子序列。
如何快速判断一个字符串是母串的子序列?直接上子序列自动机就好了。由于枚举过程是增量枚举的,因此总复杂度为 \(\mathcal{O}(n^2\log n)\) 或者 \(\mathcal{O}(n^2+n|\Sigma|)\) 的,取决于子序列自动机的实现方法。
但是我们枚举的子串可能有相同的,需要去重,hash即可,因为字符串总量比较大,用 双模hash
比较保险。
熟悉子序列自动机的可以跳过下面一段:
子算法1 子序列自动机
由名称,不难得出其用途。子序列自动机可以判断一个串是否是母串的子序列。
下设询问串为 \(P\) ,母串为 \(S\) 。
考虑这个询问串在母串上匹配的过程,假设当前询问串的前 \(i\) 位都是母串的子序列,且在母串中匹配到 \(cur\) 。形式化的讲, \(P[1:i]\) 是 \(S\) 的子序列,且 \(P[i]=S[cur]\) 。
现在我们要匹配 \(P[i+1]\) ,如果能匹配上,那么 \(S\) 串在 \(cur\) 位置后一定存在一个位置 \(k\) 能匹配上,即 \(\exists k>cur\ ,\ P[i+1]=S[k]\) 。
但是 \(S\) 串后面可能有若干个合法的 \(k\) ,我们应该取哪一个呢?
我们应该取最靠前的那一个,即 \(k>cur\ ,\ \forall j\in (cur,k]\ ,\ P[i+1]\neq S[j]\) 。
为什么这样的贪心是正确的?因为这个过程有决策包容性。即我们取最靠前的符合要求的 \(k\) ,不会使得答案变差。
后面黄色框表示如果选择 \(k_2\) , \(P\) 串后面可能的一种子序列匹配,在我们选择 \(k_1\) 的时候这种后面的匹配仍然是可达的,因此不会丢失答案。
接下来有两种实现,根据不同情况应选择不同实现方法:
- 记
nxt[i][c]
表示位置 \(i\) 之后第一个为 \(c\) 的字符,记录一个lst[c]
表示当前范围内 \(c\) 最后一次的出现位置,倒序扫描一遍即可。构建时空复杂度为 \(\mathcal{O}(n|\Sigma|)\),查询时间复杂度为 \(\mathcal{O}(|P|)\)。 - 开 \(|\Sigma|\) 个
vector
,存储每一种字符的出现位置,查询的时候二分位置即可,构造时空复杂度为 \(\mathcal{O}(n)\) ,查询时间复杂度为 \(\mathcal{O}(|P|\log n)\)。
一般来说,对于字符集较小,查询量较大的题目,推荐使用第一种写法。对于字符集较大,或者空间较为紧张的题目,推荐使用第二种写法。
实现方法1:
int nxt[maxn][26];//假定为字符集为所有小写字符
int lst[26];
int n;
void build(char *S){
n=strlen(S+1);
for(int j=0;j<26;++j)lst[j]=n+1;
for(int i=n;i>=0;--i){
for(int j=0;j<26;++j)nxt[i][j]=lst[j];
lst[S[i]-'a']=i;
}
}
bool query(char *P){
int cur=0,np=strlen(P+1);
for(int i=1;i<=np;++i){
cur=nxt[cur][P[i]-'a'];
if(cur>n)return 0;
}
return 1;
}
实现方法2:
int n;
vector<int>ps[26];
void build(char *S){
n=strlen(S+1);
for(int i=1;i<=n;++i)ps[S[i]-'a'].push_back(i);
for(int j=0;j<26;++j)ps[j].push_back(n+1);//防止越界,便于处理
}
bool query(char *P){
int cur=0,np=strlen(P+1);
for(int i=1;i<=np;++i){
int nxt=*upper_bound(ps[P[i]-'a'].begin(),ps[P[i]-'a'].end(),cur);
if(nxt>n)return 0;
cur=nxt;
}
return 1;
}
能够正确写出双模HASH的可以跳过下面一段:
子算法2 HASH
可能有很多同学在初学字符串 HASH
的时候写的 HASH
是假的(错误率很高)(包括我自己)。
字符串 HASH
核心思想是把字符串看作一个 BAS
进制数,因为显然存不下,考虑取模,比较常用的 BAS
=\(131,13331\),常用的取模是unsigned long long
自然溢出。
第一个要注意的地方是模数要足够大。由生日悖论, \(\sqrt n\) 个值域为 \([0,n)\) 的数存在相同数的概率超过 \(50\%\) ,如果模数是 int
范围的,则长度为 \(10^5\) 左右的随机字符串已经很容易产生冲突。可参见 Hash Killer II 。
但是我们仅使用自然溢出也会出问题,因为有对着卡的方法,参见 Hash Killer I 。
因此我们通过双底数/双模数的方法处理,具体的,我们取两个不同的BAS
和Mod
,分别计算 HASH
,两个 HASH
都相同才认为是相同的的。
这种方法目前似乎没有很好的方法卡掉,具体可参见 Hash Killer Ⅲ。
处理完错误率的问题,下面来处理效率问题。
先是构建的过程,考虑定义式(可能有多种定义,仅举一例):
为了方便后面计算,还应记录前缀和,即:
暴力计算是 \(\mathcal{O}(n\log n)\) 的,这个过程可以用秦九韶算法优化:
这样避免了快速幂,时间复杂度为 \(\mathcal{O}(n)\) 。
接下来是查询过程。我们查询子串 \(S[l:r]\) ,则对应答案为:
暴力实现是 \(\mathcal{O}(\log n)\) 的,我们预处理出所有的 \(BAS^k\) ,这样复杂度降为 \(\mathcal{O}(1)\) 。
现在的复杂度是线性的,接下来是一些常数优化和一些细节:
- 使用
unsigned long long
自然溢出,减少取模 - 如果是两个数相加/相减,且能保证都在 \([0,Mod)\) 范围内,可以使用减法代替取模
- 底数和模数不能过大,应保证 \(\max\{BAS,Mod\}\times Mod< 2^{62}\),否则在乘法过程中可能会超出
long long/unsigned long long
范围 - 推荐使用
unsigned long long
而不是long long
。尤其注意 自然溢出不能使用long long
,因为long long
的溢出是UB
!
有关本题的一个细节:
由于只有一个询问,去重应使用 sort+unique
或手写哈希表,map/unordered_map
常数巨大,通过此题比较困难。
代码实现
我的代码里采用的是第二种子序列自动机实现方法。
有关 HASH
,我的代码没有完全做到上面的优化,且第二个模数是 int
范围的,有一定的优化空间。
#include<bits/stdc++.h>
using namespace std;
#define maxn 1000005
#define maxm 2000005
#define inf 0x3f3f3f3f
#define LL long long
#define ull unsigned long long
#define db double
#define ldb long double
#define mod 1000000007
#define eps 1e-9
#define local
void file(string s){freopen((s+".in").c_str(),"r",stdin);freopen((s+".out").c_str(),"w",stdout);}
template <typename Tp> void read(Tp &x){
int fh=1;char c=getchar();x=0;
while(c>'9'||c<'0'){if(c=='-'){fh=-1;}c=getchar();}
while(c>='0'&&c<='9'){x=(x<<1)+(x<<3)+(c&15);c=getchar();}x*=fh;
}
int n,m;
char S[maxn],T[maxn];
vector<int>ps[26];
struct HS_node{
ull hs1,hs2;
HS_node operator +(HS_node y)const{
return (HS_node){hs1+y.hs1,(hs2+y.hs2)%mod};
}
HS_node operator -(HS_node y)const{
return (HS_node){hs1-y.hs1,(hs2-y.hs2+mod)%mod};
}
HS_node operator *(HS_node y)const{
return (HS_node){hs1*y.hs1,(hs2*y.hs2)%mod};
}
bool operator <(HS_node y)const{
return hs1==y.hs1?hs2<y.hs2:hs1<y.hs1;
}
bool operator ==(HS_node y)const{
return hs1==y.hs1&&hs2==y.hs2;
}
};
struct MY_Hash{
const ull Bas1=131,Bas2=13331;
HS_node pw[maxn],sh[maxn];
void build(const char *str){//构建hash
int nn=strlen(str+1);
pw[0]=(HS_node){1,1};
for(int i=1;i<=nn;++i)pw[i]=pw[i-1]*(HS_node){Bas1,Bas2};
for(int i=1;i<=nn;++i)sh[i]=sh[i-1]*(HS_node){Bas1,Bas2}+(HS_node){str[i],str[i]};
}
HS_node get_hash(int l,int r){
return sh[r]-sh[l-1]*pw[r-l+1];
}
}hh;
HS_node aa[9000005];
int ans;
signed main(){
#ifndef local
file("block");
#endif
read(n);
scanf("%s",S+1);
scanf("%s",T+1);
hh.build(T);
for(int i=1;i<=n;++i)ps[S[i]-'a'].push_back(i);//子序列自动机构建
for(int i=0;i<26;++i)ps[i].push_back(n+1);//防止超出边界,push一个终止符
for(int i=1;i<=n;++i){
int cur=0;
for(int j=i;j<=n;++j){
int nxt=*upper_bound(ps[T[j]-'a'].begin(),ps[T[j]-'a'].end(),cur);//子序列自动机的转移
if(nxt>n)break;
aa[++m]=hh.get_hash(i,j);
cur=nxt;
}
}
sort(aa+1,aa+m+1);
ans=unique(aa+1,aa+m+1)-aa-1;//去重
printf("%d\n",ans);
fclose(stdin);
fclose(stdout);
return 0;
}