HNOI2008 GT考试 (KMP + 矩阵乘法)
这道题目的题意描述,通俗一点说就是这样:有一个长度为n的数字串(其中每一位都可以是0到9之间任意一个数字),给定一个长度为m的模式串,求有多少种情况,使得此模式串不为数字串的任意一个子串。结果对给定的模数取模。
我们为了阅读方便,将数字串称为P串,给定的模式串称为T串。
一开始有这么个暴力想法,就是直接把T串往P串里面匹配,算出有多少种不合法的情况再计算,不过这样并不行,因为在这种算法中有很多种不合法情况被重复计算了。
于是乎看了题解(看题解也看不懂的我)。我们使用dp[i][j]表示在P串中枚举到第i位时,P串有长为j的后缀与T串中长为j的前缀相匹配的情况数。
注意指的是P串枚举位置之前。这样的话,答案就是dp[n][0] + dp[n][1] + …… dp[n][m-1].
也许有疑问,为什么一定要把枚举到P串末尾的情况才算作答案呢?如果在中间随意匹配几个数不可以吗?亦或者,如果中间有不合法情况呢?
我们注意一下转移方式,肯定dp[i][j]是从dp[i-1][j]转移过来的。转移过来的状态肯定是合法的,所以也就保证了枚举到P串第i位时,所有的情况都能保证P串前面没有不合法情况,即完全与T串匹配的情况。所以前面第一种所说的情况被包含,第二种保证不会存在。(具体可以看下面)
既然如此,那我们说说怎么转移,上面说过,枚举到第i位时的情况必然是由枚举到第i-1位时的情况转移过来的。而对于正在被枚举的第i位,可以填入0~9之中任意一个数字,填入之后对于当前匹配的前后缀长度(就是j)有以下三种影响:
1.匹配长度变为0
2.匹配长度在原来的基础上+1
3.匹配长度变为一个在0~原匹配长度之间的数
我们看到这里突然发现,这里和KMP的过程十分相似。KMP就是在逐位枚举的过程中不断确定在这一位字符之前的最长可匹配的前后缀长度,这里也一样。
所以我们只需要先求一遍T串的next数组,之后在DP的时候使用next计算转移方式即可。这里注意一下,虽然分为三种情况,不过实际上写代码的时候直接写一种处理就可以,因为第1,3中情况是失配时候递归调用next的,而第二种情况是匹配上的,直接写在一起就可以啦。
讲了这么多还没说转移方程……通过上面的思路可以知道,转移方程为:dp[i][j] = dp[i-1][0] * a[o][j] + dp[i-1][1] * a[1][j] + …… + dp[i-1][m-1] * a[m-1][j];
其中a[i][j]表示当前匹配长度由i变为j有多少种方法,这个是固定的数目,可以预处理出来。具体的预处理方法就是按照上面的方法,枚举当前匹配长度,从0~m-1,再枚举当前的数是多少(0~9),比如你现在枚举当前匹配长度为i,如果下一位与T串中对应位置匹配,那么a[i][i+1]++,否则令i = next[i],重复上述过程即可。(很像KMP)这样就顺便可以解释为什么没有不合法情况,因为,我们只枚举匹配长度在0~m-1之间的情况,之后的情况我们不枚举,这样自然不会有不合法情况产生。
可能有人会觉得奇怪,比如你从长度为2不可能直接变到长度为5,6……,那么这些情况自然就是0了,还有就是有人会疑问每一位的数字可能不同,怎么方法数是固定的,因为这是你已经枚举的,枚举的时候是会自动考虑所有情况的,可能有点难理解,不过想一想意会一下还是可以的。
DP方程已经好了,不过还没完,n的范围到了10^9,这样O(n)是承受不起的。我们需要使用某些手段加速。重新观察一下DP方程,既然a数组是一个固定的值……
那就可以使用矩阵乘法优化。(因为每一步,dp[i][0~m-1]都是由dp[i-1][0~m-1]通过固定方式转化过来的,想想Fibonacci!)
我们偷一下某位大神的图来说明这件事。
这样就很一目了然了。调用一下矩阵快速幂,直接求a的n次幂,并与一开始构造的单位矩阵相乘即可。(一开始很明显只有f[i][i]是1,其余都是0)
(唉我真是学啥忘啥,这KMP和矩阵乘法都快不会了)
看一下代码。
#include<iostream> #include<cstdio> #include<cmath> #include<algorithm> #include<queue> #include<cstring> #define rep(i,a,n) for(int i = a;i <= n;i++) #define per(i,n,a) for(int i = n;i >= a;i--) #define enter putchar('\n') using namespace std; const int maxn = 30; int n,mod,len,nxt[maxn],res; char str[maxn]; struct matrix { int f[maxn][maxn]; matrix() { memset(f,0,sizeof(f)); } friend matrix operator * (const matrix &a,const matrix &b) { matrix c; rep(i,0,len-1) rep(j,0,len-1) rep(k,0,len-1) c.f[i][j] += a.f[i][k] * b.f[k][j],c.f[i][j] %= mod; return c; } }m,ans; void getnext() { int j = 0; nxt[1] = 0; rep(i,2,len) { while(j && str[j+1] != str[i]) j = nxt[j];//这种求next是从0开始的,从-1也可以 if(str[j+1] == str[i]) j++; nxt[i] = j; } // rep(i,1,len) printf("%d ",nxt[i]); } void dp() { rep(i,0,len-1) rep(j,0,9) { int k; for(k = i; k; k = nxt[k]) if(str[k+1] - '0' == j) break; if(str[k+1] - '0' == j) k++; m.f[i][k]++,m.f[i][k] %= mod; } return; } void pow() { rep(i,0,len-1) ans.f[i][i] = 1; //ans是构造的单位矩阵 while(n) { if(n&1) ans = ans * m; m = m * m; n >>= 1; } return; } int read() { int ans = 0,op = 1; char ch = getchar(); while(ch < '0' || ch > '9') { if(ch == '-') op = -1; ch = getchar(); } while(ch >='0' && ch <= '9') { ans *= 10; ans += ch - '0'; ch = getchar(); } return ans * op; } int main() { n = read(),len = read(),mod = read(); scanf("%s",str+1); getnext(),dp(),pow(); rep(i,0,len-1) res = (res + ans.f[0][i]) % mod; printf("%d\n",res); return 0; } /* 4 3 100 111 */