【YBTOJ】【AC自动机】组合攻击
组合攻击
小明在玩一款游戏,该游戏只有三个技能键 \(\texttt{A}\) , \(\texttt{B}\) , \(\texttt{C}\) 可用,但这些键可用形成 \(n\) 种特定的组合技。第 \(i\) 个组合技用一个字符串 \(s_i\) 表示。
小明会输入一个长度为 \(k\) 的字符串 ,而一个组合技每在 \(t\) 中出现一次,小明就会获得一分。 \(s_i\) 在 \(t\) 中出现一次指的是 \(s_i\) 是 \(t_i\) 从某个位置起的连续子串。如果 \(s_i\) 从 \(t\) 的多个位置起都是连续子串,那么算作 \(s_i\) 出现了多次。
若小明输入了恰好 \(k\) 个字符,则他最多能获得多少分。
\(n \leq 20\) , \(k\leq10^3\) , \(1\leq|s_i|\leq15\)。
题解
AC自动机 + DP 模板。
先根据每个组合技建立出 AC自动机的trie图 ,然后根据此自动机进行 dp 。
设: \(dp(i,j)\) 表示,当前填入第 \(i\) 位,填入的字符在 trie树 中的点编号是 \(j\) ,能得到的最大答案。
那么有 dp 式:
\[dp(i,c_{j,k}) = \max\{dp(i-1,j)+val(c_{j,k})\}
\]
其中 \(val(x)\) 表示 在 trie树 上匹配到第 \(x\) 号节点,有多少个字符串结束(对答案的贡献)。
这里注意:在 BFS 的过程中,要 $val(p) \leftarrow val(p) +val(fail_p) $ , 因为匹配到 \(p\) 点的同时必定匹配到 \(fail_p\)。
代码
#include <bits/stdc++.h>
#define fo(a) freopen(a".in","r",stdin),freopen(a".out","w",stdout);
using namespace std;
const int INF = 0x3f3f3f3f,N = 22,M = 17,L = 1e3+5;
typedef long long ll;
typedef unsigned long long ull;
inline ll read(){
ll ret=0;char ch=' ',c=getchar();
while(!(c>='0'&&c<='9'))ch=c,c=getchar();
while(c>='0'&&c<='9')ret=(ret<<1)+(ret<<3)+c-'0',c=getchar();
return ch=='-'?-ret:ret;
}
int n,m;
template <typename T> struct que{
T a[N*M]; int st=1,ed=0;
que(){st=1,ed=0;}
inline void clear(){st=1,ed=0;}
inline int size(){return ed-st+1;}
inline bool empty(){return !(ed-st+1);}
inline T front(){return a[st];}
inline T back(){return a[ed];}
inline void pop_front(){st++;}
inline void pop_back(){ed--;}
inline void push(T x){a[++ed] = x;}
inline T operator [] (int x){return a[st+x-1];}
};
int dp[L][N*M];
struct ACauto{
int c[N*M][5],fail[N*M],tot = 1;
int w[N*M];
que <int> q;
inline void insert(char ch[]){
int p = 1 , len = strlen(ch+1);
for(int i = 1 ; i <= len ; i ++){
int v = ch[i]-'A'+1;
if(!c[p][v]) c[p][v] = ++tot;
p = c[p][v];
}
w[p] ++;
}
void build(){
for(int i = 1 ; i <= 3 ; i ++) c[0][i] = 1;
q.push(1);
while(!q.empty()){
int p = q.front(); q.pop_front();
w[p] += w[fail[p]];
for(int i = 1 ; i <= 3 ; i ++)
if(c[p][i]) fail[c[p][i]] = c[fail[p]][i], q.push(c[p][i]);
else c[p][i] = c[fail[p]][i];
}
}
int Dp(){
for(int i = 1 ; i <= m ; i ++)
for(int j = 1 ; j <= tot ; j ++)
for(int k = 1 ; k <= 3 ; k ++)
dp[i][c[j][k]] = max(dp[i][c[j][k]],dp[i-1][j] + w[c[j][k]]);
// printf(" update [%d][%d] from [%d,%d](%d) + [%d](%d) -> (%d)\n",i,c[j][k],i-1,j,dp[i-1][j],c[j][k],w[c[j][k]],dp[i][c[j][k]]);
int ret = 0;
for(int i = 1 ; i <= tot ; i ++)
ret = max(ret,dp[m][i]);
return ret;
}
}ac;
char ch[M];
signed main(){
n = read(), m = read();
for(int i = 1 ; i <= n ; i ++)
scanf("%s",ch+1),
ac.insert(ch);
ac.build();
memset(dp,0xc0,sizeof(dp));
for(int i = 0 ; i <= m ; i ++) dp[i][1] = 0;
printf("%d",ac.Dp());
return 0;
}