HDU4758 Walk Through Squares AC自动机&&dp

这道题当时做的时候觉得是数论题,包含两个01串什么的,但是算重复的时候又很蛋疼,赛后听说是字符串,然后就觉得很有可能。昨天队友问到这一题,在学了AC自动机之后就觉得简单了许多。那个时候不懂AC自动机,不知道什么是状态,因此没有想到有效的dp方法。

题意是这样的,给定两个RD串,譬如RRD,DDR这样子的串,然后现在要你向右走(R)m步,向下走(D)n步,问有多少种走法能够包含给定的两个串。

一个传统的dp思想是这样的 dp[i][j][x][y][k],表示走了i步R,j步D,x,y表示两个串各匹配了多少各,k表示的是1,2串匹配的一个4进制数(00,01,10,11,你懂的,11表示都匹配了,10表示匹配了1串)。 但是这样一来空间开不下,二来当某个点失配的时候我们不知道当前的x,y会转移到哪里,这个时候很自然的,我们就想到了AC自动机,AC自动机压入两个串只需要不超过串的总长度的结点,而且当我们在自动机上转移的时候,我们可以知道失配的时候转移到哪里。所以重新定义一下就是 dp[i][j][k][x] k表示自动机上的状态,x表示4进制数。转移的时候就考虑由当前的状态dp[i][j][k][x]转移到dp[i+1][j][nxt1][nxtx] dp[i][j+1].... 其中新的状态nxt以及对应的四进制数转移就需要根据AC自动机的失配算出来。 如果预处理出当失配时回到的那个结点感觉可能会更快一些。 我代码里多写了个dfs,主要是预处理了 到达改状态时对应的四进制数,所以转移的时候只需要或一下就可以了。

第一次做AC自动机上的dp然后 1A了,好开心!

#pragma warning(disable:4996)
#include<iostream>
#include<cstring>
#include<string>
#include<cstdio>
#include<algorithm>
#include<vector>
#include<cmath>
#include<queue>
#define maxn 2000
#define mod 1000000007
using namespace std;

struct Trie
{
	Trie * go[2];
	Trie *fail;
	int sta;
	void init() { 
		memset(go, 0, sizeof(go)), fail == NULL; 
		sta = 0;
	}
}pool[maxn],*root;
int tot;

void insert(char *c,int type)
{
	int len = strlen(c); Trie *p = root;
	for (int i = 0; i < len; i++){
		int ind = c[i] == 'R' ? 0 : 1;
		if (p->go[ind] != NULL){
			p = p->go[ind];
		}
		else{
			pool[tot].init();
			p->go[ind] = &pool[tot++];
			p = p->go[ind];
		}
	}
	p->sta |= type;
}

void getFail()
{
	queue<Trie*> que;
	que.push(root);
	root->fail = NULL;
	while (!que.empty())
	{
		Trie *temp = que.front(); que.pop();
		Trie *p = NULL;
		for (int i = 0; i < 2; i++){
			if (temp->go[i] != NULL){
				if (temp == root) temp->go[i]->fail = root;
				else{
					p = temp->fail;
					while (p != NULL){
						if (p->go[i] != NULL){
							temp->go[i]->fail = p->go[i];
							break;
						}
						p = p->fail;
					}
					if (p == NULL) temp->go[i]->fail = root;
				}
				que.push(temp->go[i]);
			}
		}
	}
}

int dfs(Trie *x){
	if (x == NULL) return 0;
	return x->sta |= dfs(x->fail);
}

int m, n;
int dp[120][120][240][4];
char str[120];
int main()
{
	int T; cin >> T;
	while (T--)
	{

		tot = 0; root = &pool[tot++]; root->init();
		scanf("%d%d", &m, &n);
		for (int i = 1; i <= 2; i++){
			scanf("%s", str);
			insert(str, i);
		}
		getFail();
		for (int i = 0; i < tot; i++){
			dfs(&pool[i]);
		}
		for (int i = 0; i <= m; i++){
			for (int j = 0; j <= n; j++){
				for (int k = 0; k <= tot; k++){
					for (int x = 0; x < 4; x++){
						dp[i][j][k][x] = 0;
					}
				}
			}
		}
		dp[0][0][0][0] = 1;
		for (int i = 0; i <= m; i++){
			for (int j = 0; j <= n; j++){
				for (int k = 0; k < tot; k++){
					for (int x = 0; x < 4; x++){
						Trie* p = &pool[k];
						if (p->go[0] != NULL){
							(dp[i + 1][j][p->go[0] - pool][x | p->go[0]->sta] += dp[i][j][k][x]) %= mod;
						}
						else{
							Trie *temp = p->fail;
							while (temp != NULL) {
								if (temp->go[0] != NULL){
									(dp[i + 1][j][temp->go[0] - pool][x | temp->go[0]->sta] += dp[i][j][k][x]) %= mod;
									break;
								}
								temp = temp->fail;
							}
							if (temp == NULL) (dp[i + 1][j][0][x | root->sta] += dp[i][j][k][x]) %= mod;
						}
						if (p->go[1] != NULL){
							(dp[i][j + 1][p->go[1] - pool][x | p->go[1]->sta] += dp[i][j][k][x]) %= mod;
						}
						else{
							Trie *temp = p->fail;
							while (temp != NULL) {
								if (temp->go[1] != NULL){
									(dp[i][j + 1][temp->go[1] - pool][x | temp->go[1]->sta] += dp[i][j][k][x]) %= mod;
									break;
								}
								temp = temp->fail;
							}
							if (temp == NULL) (dp[i][j + 1][0][x | root->sta] += dp[i][j][k][x]) %= mod;
						}
					}
				}
			}
		}
		int ans = 0;
		for (int i = 0; i < tot; i++){
			ans = ans + dp[m][n][i][3]; ans %= mod;
		}
		printf("%d\n", ans);
	}
	return 0;
}

 

posted @ 2014-02-26 16:59  chanme  阅读(409)  评论(0编辑  收藏  举报