BZOJ1009GT考试 DP + KMP + 矩陣快速冪

@[DP, KMP, 矩陣快速冪]

Description

阿申准备报名参加GT考试,准考证号为\(N\)位数\(X_1 X_2 .. X_n(0 <= X_i <= 9)\),他不希望准考证号上出现不吉利的数字。
他的不吉利数学\(A_1 A_2 .. A_m (0 <= A_i <= 9)\)有M位,不出现是指\(X_1 X_2 .. X_n\)中没有恰好一段等于\(A_1 A_2 .. A_m\). \(A_1\)\(X_1\)可以为\(0\)

Input

第一行输入\(N,M,K\).接下来一行输入\(M\)位的数。 \(N<=10^9,M<=20,K<=1000\)

Output

阿申想知道不出现不吉利数字的号码有多少种,输出模K取余的结果.

Sample Input

4 3 100
111

Sample Output

81

Solution

很容易想到DP方程

\[f[i][j] = \sum_{k = 0}^{m - 1}f[i - 1][k] * trans[k][j] \]

\[ans = \sum_{j = 0}^{m - 1}f[n][j] \]

其中, \(f[i][j]\)表示考號第\(i\)位匹配到不吉利串第\(j\)位時的情況數; \(trans[k][j]\)記錄上一位位匹配至不吉利串中的第\(k\)位時, 填入\(num \in [1, 10)\)使得當前位匹配至不吉利串第\(j\)位的\(num\)數(實際上這個數量只能是\(1\)或者\(9\))
然後就會發現, \(i\)最大可以達到\(10^{9}\), 因此時間複雜度必須要優化.
想到矩陣快速冪, 發現可以直接將\(f\)整個省略掉, 只要求出\(trans^{n}\)即可
至於\(trans\)數組, 通過KMP算法與處理一下就好了
然後就可以直接看代碼了

#include<cstdio>
#include<cstring>
using namespace std;
const int M = 1 << 5;
int n, m, K;
char a[M];
int pre[M];
int trans[M][M];
int ans[M][M];
void mul(int a[M][M], int b[M][M], int res[M][M])
{
	int tmp[M][M];
	for(int i = 0; i < m; i ++)
		for(int j = 0; j < m; j ++)
		{
			tmp[i][j] = 0;
			for(int k = 0; k < m; k ++)
				tmp[i][j] = (tmp[i][j] + a[i][k] * b[k][j]) % K;	
		}
	for(int i = 0; i < m; i ++)
		for(int j = 0; j < m; j ++)
			res[i][j] = tmp[i][j];
}
int main()
{
	#ifndef ONLINE_JUDGE
	freopen("BZOJ1009.in", "r", stdin);
	freopen("BZOJ1009.out", "w", stdout);
	#endif
	scanf("%d%d%d", &n, &m, &K);
	scanf("%s", a + 1);
	for(int i = 1; i <= m; i ++)
		*(a + i) -= '0';
	pre[1] = 0;
	for(int i = 2; i <= m; i ++)
	{
		int p = pre[i - 1];
		while(p && (a[p + 1] != a[i]))
			p = pre[p];
		pre[i] = ((a[p + 1] == a[i]) ? (p + 1) : p);
	}
	memset(trans, 0, sizeof(trans));
	for(int i = 0; i < m; i ++)
		for(int j = 0; j < 10; j ++)
		{
			int p = i;
			while(p && (a[p + 1] != j))
				p = pre[p];
			if(a[p + 1] == j)
				p ++;
			trans[p][i] = (trans[p][i] + 1) % K;
		}
	memset(ans, 0, sizeof(ans));
	for(int i = 0; i < m; i ++)
		ans[i][i] = 1;
	while(n)
	{
		if(n & 1)
			mul(ans, trans, ans);
		mul(trans, trans, trans);
		n >>= 1;
	}
	int sum = 0;
	for(int i = 0; i < m; i ++)
		sum = (sum + ans[i][0]) % K;
	printf("%d", sum); 
} 
posted @ 2017-03-01 10:29  Zeonfai  阅读(465)  评论(0编辑  收藏  举报