BZOJ 1009 [HNOI2008]GT考试(矩阵快速幂优化DP+KMP)
题意:
求长度为n的不含长为m的指定子串的字符串的个数
1s, n<=1e9, m<=50
思路:
长见识了。。
设那个指定子串为s
f[i][j]表示长度为i的字符串(其中后j个字符与s的前j个字符一致的情况下)的方法数
若匹配到s串长度为i的后缀加一个字符num可以组成最长长度为j的后缀,设a[i][j]为num的方法数
例如,s为12312,a为
9 1 0 0 0 0
8 1 1 0 0 0
8 1 0 1 0 0
9 0 0 0 1 0
8 1 0 0 0 1
(i,j都是从0到m-1)
如a[1][2]表示从“1”到“12”可以加的字符方法数,显然加“2”才可以,所以a[1][2]=1
而a[2][0]表示从“12”到“”可以加的字符方法数:显然不能加“3”,不然s串会匹配到"123";也不能加“1”,不然s串会匹配成"1"。所以a[2][0]=8
求a矩阵的方法是kmp,感觉只可意会(我写不出来QAQ)
显然f[i][x]只能由f[i-1][k]转移而来,而k为多少,要看a数组了
然后状态转移方程为:$f[i][j] = f[i-1][0]*a[0][j]+f[i-1][1]*a[1][j] +\dots + f[i-1][m-1]*a[m-1][j]$
这个状态转移方程可以用矩阵快速幂来加速
答案就是$\sum f[n][i]$
代码:
#include<iostream> #include<cstdio> #include<algorithm> #include<cmath> #include<cstring> #include<string> #include<stack> #include<queue> #include<deque> #include<set> #include<vector> #include<map> #include<functional> #define fst first #define sc second #define pb push_back #define mem(a,b) memset(a,b,sizeof(a)) #define lson l,mid,root<<1 #define rson mid+1,r,root<<1|1 #define lc root<<1 #define rc root<<1|1 #define lowbit(x) ((x)&(-x)) using namespace std; typedef double db; typedef long double ldb; typedef long long ll; typedef unsigned long long ull; typedef pair<int,int> PI; typedef pair<ll,ll> PLL; const db eps = 1e-6; //const int mod = 1e9+7; const int maxn = 2e3+100; const int maxm = 2e6+100; const int inf = 0x3f3f3f3f; const db pi = acos(-1.0); int a[60][60]; int f[60][60]; int n, m, mod; char s[maxn]; int Next[maxn]; void mtpl(int a[60][60], int b[60][60], int s[60][60]){ int tmp[60][60]; for(int i = 0; i < m; i++){ for(int j = 0; j < m; j++){ tmp[i][j] = 0; for(int k = 0; k < m; k++){ tmp[i][j]+=a[i][k]*b[k][j]%mod; tmp[i][j]%=mod; } } } for(int i = 0; i < m; i++){ for(int j = 0; j < m; j++){ s[i][j] = tmp[i][j]; } } return; } void fp(int x){ while(x){ if(x&1)mtpl(f,a,f); mtpl(a,a,a); x>>=1; } return; } void kmp(){ int fix = 0; for(int i = 2; i <= m; i++){ while(fix && s[fix+1]!=s[i])fix=Next[fix]; if(s[fix+1]==s[i])++fix; Next[i]=fix; } for(int i = 0; i < m; i++){ for(char j = '0'; j <= '9'; j++){ fix = i; while(fix&&s[fix+1]!=j)fix=Next[fix]; if(j==s[fix+1])a[i][fix+1]++; else a[i][0]++; } } return; } int main(){ scanf("%d %d %d", &n, &m, &mod); scanf("%s", s+1); mem(a, 0); kmp(); mem(f,0); f[0][0]=1; fp(n); int ans = 0; for(int i = 0; i < m; i++){ ans += f[0][i]; ans%=mod; } printf("%d", ans); return 0; } /* 5 3 4 5 1 2 */