[HNOI2008]GT考试
嘟嘟嘟
这道题刚开始我连dp方程都没设出来,现在想一想还是我对dp的理解不够深。
令\(dp[i][j]\)表示长串匹配到第\(i\)位,短串匹配到第\(j\)位时的方案数。因为题中说不让匹配成功,所以答案是\(dp[n][m - 1]\)。
但转移不好写,因为这个状态不够具体。应该在加一个条件:长串\(s\)[\(1\)~\(i\)]的后缀和短串的前缀最长的公共部分为\(j\)。这样转移就好办了。
如果还想不出来,可以想\(dp[i][j]\)能转移到什么状态:
1.匹配成功:\(dp[i][j]\) -> \(dp[i + 1][j + 1]\)
2.匹配不成功:这个时候\(dp[i][j]\) -> \(dp[i + 1][k]\)。这个\(k\)跟\(i + 1\)这个位置填什么字符有关。
也就是说:
\[dp[i][j] = \sum _ {k = 0} ^ {m - 1} dp[i - 1][k] * f[k][j]
\]
这个\(f[k][j]\)表示短串的第\(k\)个位置有多少种方案能转移到\(j\)。由此可见,这个数组跟长串无关。
所以可以先预处理这个数组:用kmp即可。
然后我们就有了一个\(O(nm ^ 2)\)的算法,交上去能得40分。
优化:
看上面的那个转移方程
\[dp[i][j] = \sum _ {k = 0} ^ {m - 1} dp[i - 1][k] * f[k][j]
\]
发现就是一个普通的矩阵乘法。
然后我们矩阵快速幂一下就可以啦。
#include<cstdio>
#include<iostream>
#include<cmath>
#include<algorithm>
#include<cstring>
#include<cstdlib>
#include<cctype>
#include<vector>
#include<stack>
#include<queue>
using namespace std;
#define enter puts("")
#define space putchar(' ')
#define Mem(a, x) memset(a, x, sizeof(a))
#define rg register
typedef long long ll;
typedef double db;
const int INF = 0x3f3f3f3f;
const db eps = 1e-8;
const int maxn = 1e6 + 5;
const int maxm = 25;
inline ll read()
{
ll ans = 0;
char ch = getchar(), last = ' ';
while(!isdigit(ch)) last = ch, ch = getchar();
while(isdigit(ch)) ans = (ans << 1) + (ans << 3) + ch - '0', ch = getchar();
if(last == '-') ans = -ans;
return ans;
}
inline void write(ll x)
{
if(x < 0) x = -x, putchar('-');
if(x >= 10) write(x / 10);
putchar(x % 10 + '0');
}
int n, m, mod;
char s[maxm];
struct Mat
{
int a[maxm][maxm];
Mat operator * (const Mat& oth)const
{
Mat ret; Mem(ret.a, 0);
for(int i = 0; i < m; ++i)
for(int j = 0; j < m; ++j)
for(int k = 0; k < m; ++k)
ret.a[i][j] += a[i][k] * oth.a[k][j], ret.a[i][j] %= mod;
return ret;
}
}F;
Mat quickpow(Mat A, int b)
{
Mat ret; Mem(ret.a, 0);
for(int i = 0; i < m; ++i) ret.a[i][i] = 1;
for(; b; b >>= 1, A = A * A)
if(b & 1) ret = ret * A;
return ret;
}
int nxt[maxm];
void kmp()
{
for(int i = 2, j = 0; i <= m; ++i)
{
while(j && s[j + 1] != s[i]) j = nxt[j];
if(s[j + 1] == s[i]) j++; nxt[i] = j;
}
for(int i = 0; i < m; ++i)
for(int j = 0; j <= 9; ++j)
{
int k = i;
while(k && s[k + 1] != j + '0') k = nxt[k];
if(s[k + 1] == j + '0') k++;
if(k < m) F.a[i][k]++;
}
}
int dp[maxn][maxm];
int main()
{
n = read(); m = read(); mod = read();
scanf("%s", s + 1);
Mem(F.a, 0); kmp();
Mat A = quickpow(F, n);
int ans = 0;
for(int i = 0; i < m; ++i) ans = (ans + A.a[0][i]) % mod;
write(ans), enter;
return 0;
}