1282. 搜索关键词
题目链接
1282. 搜索关键词
给定 \(n\) 个长度不超过 \(50\) 的由小写英文字母组成的单词,以及一篇长为 \(m\) 的文章。
请问,其中有多少个单词在文章中出现了。
注意:每个单词不论在文章中出现多少次,仅累计 \(1\) 次。
输入格式
第一行包含整数 \(T\),表示共有 \(T\) 组测试数据。
对于每组数据,第一行一个整数 \(n\),接下去 \(n\) 行表示 \(n\) 个单词,最后一行输入一个字符串,表示文章。
输出格式
对于每组数据,输出一个占一行的整数,表示有多少个单词在文章中出现。
数据范围
\(1≤n≤10^4,\)
\(1≤m≤10^6\)
输入样例:
1
5
she
he
say
shr
her
yasherhs
输出样例:
3
解题思路
ac自动机
ac自动机=字典树+KMP
将所有匹配串放入字典树中,关键在于 \(next[i]\) 数组:表示某个点开始与到节点 \(i\) 形成的字符串相等的最长前缀的那个节点,其写法类似于KMP的写法,即某个节点匹配不上时往前匹配直到匹配成功或到根节点。在搜索母串的过程中也是这样,即将所有的模式串与母串进行匹配,当节点存在时统计答案即可
优化:
trie图
本质上ac自动机在匹配的过程中失败的话会往前走,trie图的目的就是让每个节点都直接指向它最后的位置
- 时间复杂度:\(O(50n)\)
代码
// Problem: 搜索关键词
// Contest: AcWing
// URL: https://www.acwing.com/problem/content/1284/
// Memory Limit: 64 MB
// Time Limit: 1000 ms
//
// Powered by CP Editor (https://cpeditor.org)
// %%%Skyqwq
#include <bits/stdc++.h>
//#define int long long
#define help {cin.tie(NULL); cout.tie(NULL);}
#define pb push_back
#define fi first
#define se second
#define mkp make_pair
using namespace std;
typedef long long LL;
typedef pair<int, int> PII;
typedef pair<LL, LL> PLL;
template <typename T> bool chkMax(T &x, T y) { return (y > x) ? x = y, 1 : 0; }
template <typename T> bool chkMin(T &x, T y) { return (y < x) ? x = y, 1 : 0; }
template <typename T> void inline read(T &x) {
int f = 1; x = 0; char s = getchar();
while (s < '0' || s > '9') { if (s == '-') f = -1; s = getchar(); }
while (s <= '9' && s >= '0') x = x * 10 + (s ^ 48), s = getchar();
x *= f;
}
const int N=1e4+5,M=1e6+5;
int t,n,trie[N*50][26],idx,cnt[N*50],q[N*50],ne[50*N];
char s[55],str[M];
bool v[N*50];
void insert()
{
int p=0;
for(int i=0;s[i];i++)
{
int t=s[i]-'a';
if(!trie[p][t])trie[p][t]=++idx;
p=trie[p][t];
}
cnt[p]++;
}
void build()
{
int hh=0,tt=-1;
for(int i=0;i<26;i++)
if(trie[0][i])q[++tt]=trie[0][i];
while(hh<=tt)
{
int t=q[hh++];
for(int i=0;i<26;i++)
{
int c=trie[t][i];
if(!c)continue;
int j=ne[t];
while(j&&!trie[j][i])j=ne[j];
if(trie[j][i])j=trie[j][i];
ne[c]=j;
q[++tt]=c;
}
}
}
int main()
{
for(cin>>t;t;t--)
{
idx=0;
memset(v,0,sizeof v);
memset(ne,0,sizeof ne);
memset(trie,0,sizeof trie);
memset(cnt,0,sizeof cnt);
cin>>n;
for(int i=1;i<=n;i++)
{
cin>>s;
insert();
}
build();
cin>>str;
int res=0;
for(int i=0,j=0;str[i];i++)
{
int t=str[i]-'a';
while(j&&!trie[j][t])j=ne[j];
if(trie[j][t])j=trie[j][t];
int p=j;
while(p)
{
if(v[p])break;
v[p]=true;
res+=cnt[p];
p=ne[p];
}
}
cout<<res<<'\n';
}
return 0;
}
- trie图优化
// Problem: 搜索关键词
// Contest: AcWing
// URL: https://www.acwing.com/problem/content/1284/
// Memory Limit: 64 MB
// Time Limit: 1000 ms
//
// Powered by CP Editor (https://cpeditor.org)
// %%%Skyqwq
#include <bits/stdc++.h>
//#define int long long
#define help {cin.tie(NULL); cout.tie(NULL);}
#define pb push_back
#define fi first
#define se second
#define mkp make_pair
using namespace std;
typedef long long LL;
typedef pair<int, int> PII;
typedef pair<LL, LL> PLL;
template <typename T> bool chkMax(T &x, T y) { return (y > x) ? x = y, 1 : 0; }
template <typename T> bool chkMin(T &x, T y) { return (y < x) ? x = y, 1 : 0; }
template <typename T> void inline read(T &x) {
int f = 1; x = 0; char s = getchar();
while (s < '0' || s > '9') { if (s == '-') f = -1; s = getchar(); }
while (s <= '9' && s >= '0') x = x * 10 + (s ^ 48), s = getchar();
x *= f;
}
const int N=1e4+5,M=1e6+5;
int t,n,trie[N*50][26],idx,cnt[N*50],q[N*50],ne[50*N];
char s[55],str[M];
bool v[N*50];
void insert()
{
int p=0;
for(int i=0;s[i];i++)
{
int t=s[i]-'a';
if(!trie[p][t])trie[p][t]=++idx;
p=trie[p][t];
}
cnt[p]++;
}
void build()
{
int hh=0,tt=-1;
for(int i=0;i<26;i++)
if(trie[0][i])q[++tt]=trie[0][i];
while(hh<=tt)
{
int t=q[hh++];
for(int i=0;i<26;i++)
{
int p=trie[t][i];
if(!p)trie[t][i]=trie[ne[t]][i];
else
{
ne[p]=trie[ne[t]][i];
q[++tt]=p;
}
}
}
}
int main()
{
for(cin>>t;t;t--)
{
idx=0;
memset(v,0,sizeof v);
memset(ne,0,sizeof ne);
memset(trie,0,sizeof trie);
memset(cnt,0,sizeof cnt);
cin>>n;
for(int i=1;i<=n;i++)
{
cin>>s;
insert();
}
build();
cin>>str;
int res=0;
for(int i=0,j=0;str[i];i++)
{
int t=str[i]-'a';
j=trie[j][t];
int p=j;
while(p)
{
if(v[p])break;
v[p]=true;
res+=cnt[p];
p=ne[p];
}
}
cout<<res<<'\n';
}
return 0;
}