[HAOI2015]数字串拆分 题解
Description
Solution
首先 \(f\) 很好求:\(f_i=f_{i-1}+f_{i-2}+\cdots+f_{i-m}\)(\(f_0=1\),下标为负时取 \(0\))。注意到 \(m\) 很小,可以用矩阵快速幂把求单个 \(f_x\) 的复杂度优化到 \(O(m^3\log x)\)。记转移矩阵为 \(A\)。设行向量 \(m_i=[f_{i-m+1},f_{i-m+2},\ldots,f_i]\),那么 \(m_{i-1}\ast A=m_i\)。
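由于 \(g\) 的定义比较特殊,不能直接套用上面的递推。注意到 \(m_i=m_0\ast A^i\),且 \(A^{i+j}=A^i\ast A^j\),因此可以直接维护矩阵 \(A^x\) 而不是向量;又因为 \(\sum m_0\ast A^{x}=m_0\ast\sum A^{x}\),可以先把所有拆分方案对应的矩阵加起来,最后再统一取值。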
所以可以设 \(g_i\) 表示将字符串的前 \(i\) 个字符进行分割的所有方案对应矩阵之和(\(g_0\) 为单位矩阵),那么:
\[g_i=\sum_{j=0}^{i-1}g_j\ast D_{j+1,i}
\]
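其中 \(D_{j+1,i}=A^{x}\),\(x\) 为第 \(j+1\) 到第 \(i\) 个字符组成的十进制数。它可以递推求出:\(D_{j+1,i}=\left(D_{j+1,i-1}\right)^{10}\ast A^{s_i}\),其中 \(D_{j+1,j}\) 取单位矩阵。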
这样总复杂度为 \(O(n^2 m^3\log 10)\),即 \(O(n^2m^3)\)。最终答案为 \(g_n\) 第一行所有元素之和(可以验证 \(A^x\) 第一行元素之和恰为 \(f_x\))。
Code
代码
#include <bits/stdc++.h>
#ifdef ORZXKR
#include <debug.h>
#else
#define debug(...) 1
#endif
using namespace std;
// Buffered stdin/stdout helpers backed by 2 MiB fread/fwrite buffers.
namespace FASTIO {
char ibuf[1 << 21], *p1 = ibuf, *p2 = ibuf;
// Returns the next input byte (as an unsigned char value), refilling ibuf from stdin
// when it is exhausted, or EOF once stdin has no more data. Returns int rather than
// char so that EOF can never be confused with a real 0xFF byte.
int getc() {
  return p1 == p2 && (p2 = (p1 = ibuf) + fread(ibuf, 1, 1 << 21, stdin), p1 == p2)
             ? EOF
             : static_cast<unsigned char>(*p1++);
}
// Reads one decimal integer into x, skipping any non-digit characters before it.
// A '-' anywhere among the skipped characters makes the value negative.
// Returns false (with x == 0) if the input ends before any digit is found; the
// previous version looped forever in that case because EOF fell into the skip loop.
template<class T> bool read(T &x) {
  x = 0;
  bool neg = false;
  int ch = getc();
  while (ch != EOF && (ch < '0' || ch > '9')) {
    neg |= ch == '-';
    ch = getc();
  }
  if (ch == EOF) return false;
  while (ch >= '0' && ch <= '9') {
    x = x * 10 + (ch - '0');
    ch = getc();
  }
  if (neg) x = -x;
  return true;
}
// Reads several integers in order; stops (returning false) at the first failure.
template<typename A, typename ...B> bool read(A &x,B &...y) { return read(x) && read(y...); }
char obuf[1 << 21], *o1 = obuf, *o2 = obuf + (1 << 21) - 1;
// Writes the buffered output to stdout and empties the buffer.
void flush() { fwrite(obuf, 1, o1 - obuf, stdout), o1 = obuf; }
// Appends one character, flushing when the buffer is full.
void putc(char x) { *o1++ = x; if (o1 == o2) flush(); }
// Writes an integer in decimal. Digits are taken from x directly (negated per digit)
// instead of negating x itself, so the type's minimum value (e.g. INT_MIN) is safe.
template<class T> void write(T x) {
  if (x == 0) {
    putc('0');
    return;
  }
  const bool neg = x < 0;
  if (neg) putc('-');
  char c[40]; int tot = 0;
  while (x != 0) {
    const int d = static_cast<int>(x % 10);  // in [-9, 0] when x is negative
    c[++tot] = static_cast<char>(neg ? -d : d);
    x /= 10;
  }
  for (int i = tot; i; --i) putc(static_cast<char>(c[i] + '0'));
}
// Writes a single character verbatim (preferred over the integer template for char).
void write(char x) { putc(x); }
// Writes several values back to back with no separator.
template<typename A,typename ...B> void write(A x, B ...y) { write(x), write(y...); }
// Flushes pending output when static objects are destroyed at program exit.
struct Flusher {
  ~Flusher() { flush(); }
} flusher;
} // namespace FASTIO
using FASTIO::read; using FASTIO::putc; using FASTIO::write;
constexpr int kMod = 998244353;  // all counts are reported modulo this prime

// Small square matrix addressed with 1-based indices 1..m; row/column 0 stay unused,
// so the 6x6 storage supports m <= 5.
struct matrix {
  int a[6][6];
  // Sets every entry (including the unused row/column 0) to zero.
  void clear() {
    for (auto &row : a) std::fill(std::begin(row), std::end(row), 0);
  }
} b, mi[10];  // b: one-step transition matrix; mi[k] holds b^k for digit k (filled in main)
int n, m;                    // n: length of the digit string; m: matrix dimension
int a[505];                  // a[i]: value of the i-th digit (1-based)
char s[505];                 // input digit string, stored starting at s + 1
matrix d[505][505], f[505];  // d[i][j]: matrix for digits i..j; f[i]: DP over prefixes
// Modular addition: for x, y already in [0, kMod) returns (x + y) mod kMod
// with a single conditional subtraction instead of a division.
int add(int x, int y) {
  const int sum = x + y;
  return sum < kMod ? sum : sum - kMod;
}
// Returns the matrix product a * b over Z_kMod, using indices 1..m in every dimension
// (m is the global matrix size). Row/column 0 of the result is zero.
matrix mul(matrix a, matrix b) {
  matrix prod{};
  for (int row = 1; row <= m; ++row) {
    for (int col = 1; col <= m; ++col) {
      // Accumulate the dot product term by term in ascending k, reducing each
      // product and the running sum modulo kMod.
      long long acc = 0;
      for (int k = 1; k <= m; ++k) {
        acc = (acc + 1ll * a.a[row][k] * b.a[k][col] % kMod) % kMod;
      }
      prod.a[row][col] = static_cast<int>(acc);
    }
  }
  return prod;
}
// Returns the entrywise sum a + b over Z_kMod for indices 1..m, reducing each entry
// with add(int, int). Row/column 0 of the result is zero.
matrix add(matrix a, matrix b) {
  matrix sum{};
  for (int row = 1; row <= m; ++row) {
    const int *lhs = a.a[row];
    const int *rhs = b.a[row];
    int *out = sum.a[row];
    for (int col = 1; col <= m; ++col) out[col] = add(lhs[col], rhs[col]);
  }
  return sum;
}
// Returns bs^idx over Z_kMod by binary exponentiation built on mul().
// With the default idx = 10 this shifts a decimal prefix: (A^v)^10 = A^(10 * v).
// idx <= 0 returns the m x m identity (negative exponents are unsupported). Previously
// idx == 0 was decremented to -1, and since -1 >> 1 == -1 the loop never terminated.
matrix qpow(matrix bs, int idx = 10) {
  if (idx <= 0) {
    matrix id{};
    for (int i = 1; i <= m; ++i) id.a[i][i] = 1;
    return id;
  }
  matrix ret = bs;  // already holds one factor of bs; multiply in the other idx - 1
  for (int rest = idx - 1; rest > 0; rest >>= 1) {
    if (rest & 1) ret = mul(ret, bs);
    if (rest > 1) bs = mul(bs, bs);  // skip the final squaring, whose result is unused
  }
  return ret;
}
// Reads the digit string s and the integer m, then prints the answer modulo kMod.
// Input is rejected (exit code 1) when it is malformed or would overflow the fixed arrays.
int main() {
#ifdef ORZXKR
  freopen("in.txt", "r", stdin);
  freopen("out.txt", "w", stdout);
#endif
  // s holds 505 bytes and the string starts at s + 1, so at most 503 chars fit
  // (plus the terminating NUL); the unbounded "%s" could overflow s.
  if (scanf("%503s", s + 1) != 1) return 1;
  if (scanf("%d", &m) != 1) return 1;
  // matrix only has indices 0..5, so m must not exceed 5.
  const int kMaxM = static_cast<int>(sizeof(b.a) / sizeof(b.a[0])) - 1;
  if (m > kMaxM) return 1;
  n = strlen(s + 1);
  for (int i = 1; i <= n; ++i) {
    // a[i] indexes mi[0..9], so anything other than a decimal digit is rejected.
    if (s[i] < '0' || s[i] > '9') return 1;
    a[i] = s[i] - '0';
  }
  // One-step transition: [v1, ..., vm] * b = [v2, ..., vm, v1 + ... + vm].
  for (int i = 1; i < m; ++i) {
    b.a[i + 1][i] = 1;
  }
  for (int i = 1; i <= m; ++i) {
    b.a[i][m] = 1;
  }
  // mi[k] = b^k for every digit k.
  for (int i = 1; i <= m; ++i) {
    mi[0].a[i][i] = 1;
  }
  for (int i = 1; i <= 9; ++i) {
    mi[i] = mul(mi[i - 1], b);
  }
  // d[i][j] = b^(decimal value of digits i..j): appending digit a[j] turns b^v into
  // (b^v)^10 * b^a[j] = b^(10v + a[j]).
  for (int i = 1; i <= n; ++i) {
    matrix nw = mi[0];
    for (int j = i; j <= n; ++j) {
      d[i][j] = nw = mul(qpow(nw), mi[a[j]]);
    }
  }
  // f[i] = sum over all splits of digits 1..i of b^(sum of the parts)
  //      = sum_{j < i} f[j] * d[j + 1][i], with f[0] = identity.
  f[0] = mi[0];
  for (int i = 1; i <= n; ++i) {
    for (int j = 0; j < i; ++j) {
      f[i] = add(f[i], mul(f[j], d[j + 1][i]));
    }
  }
  // Row 1 of b^x sums to the x-th term of the recurrence generated by b (starting
  // from 1 at x = 0), so the row-1 sum of f[n] totals that term over all splits.
  int ans = 0;
  for (int i = 1; i <= m; ++i) {
    ans = add(ans, f[n].a[1][i]);
  }
  write(ans);
  return 0;
}