bzoj1009[HNOI2008]GT考试
Description
阿申准备报名参加GT考试,准考证号为N位数X1X2....Xn(0<=Xi<=9),他不希望准考证号上出现不吉利的数字。
他的不吉利数学A1A2...Am(0<=Ai<=9)有M位,不出现是指X1X2...Xn中没有恰好一段等于A1A2...Am. A1和X1可以为
0
Input
第一行输入N,M,K.接下来一行输入M位的数。 N<=10^9,M<=20,K<=1000
Output
阿申想知道不出现不吉利数字的号码有多少种,输出模K取余的结果.
Sample Input
4 3 100
111
111
Sample Output
81
其实不是很难。。。
首先我们可以想到一个简单的dp:
f[i][j] 表示准考证号填完第i位不吉利数字与准考证号(后缀)最多匹配到第j位
我们可以枚举第i 位填入的数字, 然后用不吉利数字的kmp来转移
如果 f[i][j] 可以从 f[i - 1][k] 转移过来
必然是不吉利数字在第k + 1位失配并且不断查找nxt 知道第 j 位重新匹配
之后就可以矩乘加速了
1 #include <iostream> 2 #include <cstdio> 3 #include <algorithm> 4 #include <cstdio> 5 #include <cstring> 6 #define LL long long 7 8 using namespace std; 9 10 int N, M, K; 11 int nxt[30]; 12 char str[30]; 13 LL sum = 0; 14 int s[30]; 15 struct mat { 16 int a[30][30]; 17 void init() 18 { 19 memset(a, 0, sizeof a); 20 } 21 } tran, temp, iit, cnt, res, ans; 22 inline LL read() 23 { 24 LL x = 0, w = 1; char ch = 0; 25 while(ch < '0' || ch > '9') { 26 if(ch == '-') { 27 w = -1; 28 } 29 ch = getchar(); 30 } 31 while(ch >= '0' && ch <= '9') { 32 x = x * 10 + ch - '0'; 33 ch = getchar(); 34 } 35 return x * w; 36 } 37 38 void init() 39 { 40 nxt[1] = 0; 41 for(int i = 2; i <= M; i++) { 42 int j = nxt[i - 1]; 43 while(j && s[i] != s[j + 1]) { 44 j = nxt[j]; 45 } 46 if(s[i] == s[j + 1]) { 47 j++; 48 } 49 nxt[i] = j; 50 } 51 iit.init(); 52 for(int i = 0; i <= M; i++) { 53 iit.a[i][i] = 1; 54 } 55 } 56 57 mat operator *(mat a, mat b) 58 { 59 temp.init(); 60 for(int i = 0; i <= M; i++) { 61 for(int j = 0; j <= M; j++) { 62 for(int k = 0; k <= M; k++) { 63 temp.a[i][j] = (temp.a[i][j] + a.a[i][k] * b.a[k][j]) % K; 64 } 65 } 66 } 67 return temp; 68 } 69 70 mat fast(mat x, int k) 71 { 72 cnt = x; 73 res = iit; 74 while(k) { 75 if(k % 2 == 1) { 76 res = res * cnt; 77 k--; 78 } 79 cnt = cnt * cnt; 80 k = k / 2; 81 } 82 return res; 83 } 84 85 void print(mat a) 86 { 87 for(int i = 0; i <= M; i++) { 88 for(int j = 0; j <= M; j++) { 89 cout<<a.a[i][j]<<" "; 90 } 91 cout<<endl; 92 } 93 cout<<endl; 94 } 95 int main() 96 { 97 N = read(), M = read(), K = read(); 98 scanf("%s", str + 1); 99 for(int i = 1; i <= M; i++) { 100 s[i] = str[i] - '0'; 101 } 102 init(); 103 for(int k = 0; k <= 9; k++) { 104 for(int i = 0; i < M; i++) { 105 int j = i; 106 while(j && s[j + 1] != k) { 107 j = nxt[j]; 108 } 109 if(s[j + 1] == k) { 110 j++; 111 } 112 tran.a[i][j]++; 113 } 114 } 115 // print(tran); 116 ans = fast(tran, N); 117 // print(ans); 118 for(int i = 0; i < M; i++) { 119 sum = (sum + ans.a[0][i]) % K; 120 } 121 printf("%lld\n", sum); 122 return 0; 123 } 124 125 /* 126 4 3 100 127 128 111 129 */