【YBTOJ】【AC自动机】组合攻击

组合攻击

小明在玩一款游戏,该游戏只有三个技能键 \(\texttt{A}\)\(\texttt{B}\)\(\texttt{C}\) 可用,但这些键可用形成 \(n\) 种特定的组合技。第 \(i\) 个组合技用一个字符串 \(s_i\) 表示。

小明会输入一个长度为 \(k\) 的字符串 ,而一个组合技每在 \(t\) 中出现一次,小明就会获得一分。 \(s_i\)\(t\) 中出现一次指的是 \(s_i\)\(t_i\) 从某个位置起的连续子串。如果 \(s_i\)\(t\) 的多个位置起都是连续子串,那么算作 \(s_i\) 出现了多次。

若小明输入了恰好 \(k\) 个字符,则他最多能获得多少分。

\(n \leq 20\)\(k\leq10^3\)\(1\leq|s_i|\leq15\)

题解

AC自动机 + DP 模板。

先根据每个组合技建立出 AC自动机的trie图 ,然后根据此自动机进行 dp 。

设: \(dp(i,j)\) 表示,当前填入第 \(i\) 位,填入的字符在 trie树 中的点编号是 \(j\) ,能得到的最大答案。

那么有 dp 式:

\[dp(i,c_{j,k}) = \max\{dp(i-1,j)+val(c_{j,k})\} \]

其中 \(val(x)\) 表示 在 trie树 上匹配到第 \(x\) 号节点,有多少个字符串结束(对答案的贡献)。

这里注意:在 BFS 的过程中,要 $val(p) \leftarrow val(p) +val(fail_p) $ , 因为匹配到 \(p\) 点的同时必定匹配到 \(fail_p\)

代码

#include <bits/stdc++.h>
#define fo(a) freopen(a".in","r",stdin),freopen(a".out","w",stdout);
using namespace std;
const int INF = 0x3f3f3f3f,N = 22,M = 17,L = 1e3+5;
typedef long long ll;
typedef unsigned long long ull;
inline ll read(){
	ll ret=0;char ch=' ',c=getchar();
	while(!(c>='0'&&c<='9'))ch=c,c=getchar();
	while(c>='0'&&c<='9')ret=(ret<<1)+(ret<<3)+c-'0',c=getchar();
	return ch=='-'?-ret:ret;
}
int n,m;
template <typename T> struct que{
	T a[N*M]; int st=1,ed=0;
	que(){st=1,ed=0;}
	inline void clear(){st=1,ed=0;}
	inline int size(){return ed-st+1;}
	inline bool empty(){return !(ed-st+1);}
	inline T front(){return a[st];}
	inline T back(){return a[ed];}
	inline void pop_front(){st++;}
	inline void pop_back(){ed--;}
	inline void push(T x){a[++ed] = x;}
	inline T operator [] (int x){return a[st+x-1];}
};
int dp[L][N*M];
struct ACauto{
	int c[N*M][5],fail[N*M],tot = 1;
	int w[N*M];
	que <int> q;
	inline void insert(char ch[]){
		int p = 1 , len = strlen(ch+1);
		for(int i = 1 ; i <= len ; i ++){
			int v = ch[i]-'A'+1;
			if(!c[p][v]) c[p][v] = ++tot;
			p = c[p][v];
		} 
		w[p] ++;
	}
	void build(){
		for(int i = 1 ; i <= 3 ; i ++) c[0][i] = 1;
		q.push(1);
		while(!q.empty()){
			int p = q.front(); q.pop_front();
			w[p] += w[fail[p]];
			for(int i = 1 ; i <= 3 ; i ++)
				if(c[p][i]) fail[c[p][i]] = c[fail[p]][i], q.push(c[p][i]);
				else c[p][i] = c[fail[p]][i];
		}
	}
	int Dp(){
		for(int i = 1 ; i <= m ; i ++)
			for(int j = 1 ; j <= tot ; j ++)
				for(int k = 1 ; k <= 3 ; k ++)
					dp[i][c[j][k]] = max(dp[i][c[j][k]],dp[i-1][j] + w[c[j][k]]);
//					printf(" update [%d][%d]  from [%d,%d](%d) + [%d](%d) -> (%d)\n",i,c[j][k],i-1,j,dp[i-1][j],c[j][k],w[c[j][k]],dp[i][c[j][k]]);
		int ret = 0;
		for(int i = 1 ; i <= tot ; i ++)
			ret = max(ret,dp[m][i]);
		return ret;
	}
}ac;
char ch[M];
signed main(){
	n = read(), m = read();
	for(int i = 1 ; i <= n ; i ++)
		scanf("%s",ch+1),
		ac.insert(ch);
	ac.build();
	memset(dp,0xc0,sizeof(dp));
	for(int i = 0 ; i <= m ; i ++) dp[i][1] = 0;
	printf("%d",ac.Dp());
	return 0;
}
posted @ 2021-09-24 16:08  Last-Order  阅读(117)  评论(0编辑  收藏  举报