[HAOI2015]数字串拆分 题解

Description

link

Solution

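首先 \(f\) 很好求:\(f_i=f_{i-1}+f_{i-2}+\dots+f_{i-m}\)(约定 \(f_0=1\),负下标的 \(f\) 为 \(0\))。由于 \(m\) 很小,可以用矩阵快速幂把求单个 \(f_x\) 的复杂度优化到 \(O(m^3\log x)\)。设转移矩阵为 \(A\),行向量 \(v_i=[f_{i-m+1},f_{i-m+2},\dots,f_i]\),那么 \(v_{i-1}\ast A=v_i\)。

由于 \(g\) 要对所有拆分方案求和,而每种方案各段之和可能非常大,所以不能直接把和算出来再求 \(f\)。注意到 \(v_i=v_0\ast A^i\),且 \(A^{i+j}=A^i\ast A^j\);又因为矩阵乘法对加法满足分配律,只需求出所有方案的 \(A^{x_1}\ast A^{x_2}\ast\dots\ast A^{x_k}\)(\(x_1,\dots,x_k\) 为各段组成的数)之和,再从这个矩阵中读出答案即可。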

所以可以设 \(g_i\) 表示将字符串前 \(i\) 个字符任意拆分后,所有拆分方案对应的矩阵 \(A^{x_1+x_2+\dots+x_k}\)(\(x_1,\dots,x_k\) 为各段组成的数)之和,其中 \(g_0\) 为单位矩阵,那么:

\[g_i=\sum_{j=0}^{i-1}{g_j\ast D_{j+1,i}} \]

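其中 \(D_{j+1,i}=A^{x}\),\(x\) 为第 \(j+1\) 到第 \(i\) 个字符组成的整数。它可以递推得到:\(D_{j+1,i}=(D_{j+1,i-1})^{10}\ast A^{s_i}\),其中 \(s_i\) 为第 \(i\) 个字符代表的数字,\(D_{j+1,j}\) 取单位矩阵。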

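这样预处理所有 \(D\) 和做 DP 的总复杂度都是 \(O(n^2 m^3)\)(求 \(10\) 次幂只需常数次矩阵乘法)。最后,\(A^x\) 的第一行恰为 \([f_{x-m},\dots,f_{x-1}]\)(\(x\ge 1\);\(x=0\) 时为单位矩阵的第一行),其元素之和就是 \(f_x\),所以答案就是 \(g_n\) 第一行所有元素之和。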

Code

代码
#include <bits/stdc++.h>

#ifdef ORZXKR
#include <debug.h>
#else
#define debug(...) 1
#endif

using namespace std;

namespace FASTIO {
// Input side: a 2 MiB buffer refilled from stdin with fread on demand.
char ibuf[1 << 21], *p1 = ibuf, *p2 = ibuf;

// Returns the next byte of stdin (as an unsigned char value), or EOF once
// fread delivers no more data. Returning int rather than char keeps EOF
// distinguishable from a real 0xFF byte on signed- and unsigned-char
// platforms alike.
int getc() {
  if (p1 == p2) {
    p1 = ibuf;
    p2 = ibuf + fread(ibuf, 1, 1 << 21, stdin);
    if (p1 == p2) return EOF;
  }
  return static_cast<unsigned char>(*p1++);
}

// Parses the next decimal integer into x, skipping any non-digit bytes in
// front of it; a '-' anywhere among those skipped bytes marks the number as
// negative. Returns false (leaving x == 0) if the input ends before a digit
// is found, instead of spinning forever on EOF.
template<class T> bool read(T &x) {
  x = 0;
  bool neg = false;
  int ch = getc();
  while (ch < '0' || ch > '9') {
    if (ch == EOF) return false;
    neg |= ch == '-';
    ch = getc();
  }
  while (ch >= '0' && ch <= '9') x = (x * 10) + (ch - '0'), ch = getc();
  if (neg) x = -x;
  return true;
}
// Reads several values in order; stops (returning false) at the first failure.
template<typename A, typename ...B> bool read(A &x, B &...y) { return read(x) && read(y...); }

// Output side: bytes are staged in obuf and written to stdout in bulk.
char obuf[1 << 21], *o1 = obuf, *o2 = obuf + (1 << 21) - 1;
// Writes all staged bytes to stdout and empties the buffer.
void flush() { fwrite(obuf, 1, o1 - obuf, stdout), o1 = obuf; }
// Stages one byte, flushing when the buffer becomes full.
void putc(char x) { *o1++ = x; if (o1 == o2) flush(); }
// Writes x in decimal. Digits are peeled off x directly (x is never
// negated), so the most negative value of a signed type prints correctly
// instead of overflowing.
template<class T> void write(T x) {
  if (x == 0) {
    putc('0');
    return;
  }
  if (x < 0) putc('-');
  char c[40]; int tot = 0;
  while (x != 0) {
    int digit = static_cast<int>(x % 10);  // in [-9, 9]; sign follows x
    c[++tot] = static_cast<char>(digit < 0 ? -digit : digit);
    x /= 10;
  }
  for (int i = tot; i; --i) putc(static_cast<char>(c[i] + '0'));
}
// A char argument is written verbatim (used for separators like ' ' / '\n').
void write(char x) { putc(x); }
// Writes each argument in order.
template<typename A,typename ...B> void write(A x, B ...y) { write(x), write(y...); }
// Flushes whatever is still staged when the program exits normally.
struct Flusher {
  ~Flusher() { flush(); }
} flusher;
} // namespace FASTIO
using FASTIO::read; using FASTIO::putc; using FASTIO::write;

// Modulus under which every count in this program is reduced.
constexpr int kMod = 998244353;

// Small square matrix with 6x6 storage; the rest of the program only touches
// indices 1..m, so row/column 0 stay zero. The globals declared with it:
//   b     - transition matrix of the recurrence (filled in main)
//   mi[k] - b^k for each decimal digit k = 0..9 (filled in main)
struct matrix {
  int a[6][6];

  // Sets every entry (including the unused row/column 0) to zero.
  void clear() {
    for (auto &row : a) {
      for (int &cell : row) cell = 0;
    }
  }
} b, mi[10];

int n;               // length of the input digit string
int m;               // matrix dimension (indices 1..m are used)
int a[505];          // a[i] = numeric value of the i-th character (1-based)
char s[505];         // the input string, stored starting at s[1]
matrix d[505][505];  // d[i][j] = b^(number formed by characters i..j)
matrix f[505];       // f[i] = sum over splits of the first i characters of b^(sum of pieces)

// Adds two residues assumed to lie in [0, kMod) and folds the result back
// into [0, kMod) with a single conditional subtraction.
int add(int x, int y) {
  int sum = x + y;
  if (sum >= kMod) sum -= kMod;
  return sum;
}

// Returns the product a * b restricted to indices 1..m, with every entry
// reduced modulo kMod; all other entries of the result are zero.
matrix mul(matrix a, matrix b) {
  static matrix c;
  c.clear();
  for (int i = 1; i <= m; ++i) {
    for (int j = 1; j <= m; ++j) {
      // Accumulate the dot product of row i of a and column j of b,
      // reducing after every term exactly as the k-outer form did.
      long long acc = 0;
      for (int k = 1; k <= m; ++k) {
        acc = (acc + 1ll * a.a[i][k] * b.a[k][j] % kMod) % kMod;
      }
      c.a[i][j] = static_cast<int>(acc);
    }
  }
  return c;
}

// Returns the entry-wise sum a + b over indices 1..m, each entry combined
// with the scalar modular add(); all other entries of the result are zero.
matrix add(matrix a, matrix b) {
  static matrix c;
  c.clear();
  for (int row = 1; row <= m; ++row) {
    const int *lhs = a.a[row];
    const int *rhs = b.a[row];
    for (int col = 1; col <= m; ++col) {
      c.a[row][col] = add(lhs[col], rhs[col]);
    }
  }
  return c;
}

// Returns bs^idx (indices 1..m, entries modulo kMod) by binary
// exponentiation. The default exponent 10 is what main uses to append a
// decimal digit: b^(10*x) = (b^x)^10.
// For idx <= 0 the m x m identity is returned, rather than looping forever
// on a negative exponent that right-shifting can never bring to zero.
matrix qpow(matrix bs, int idx = 10) {
  matrix ret;
  ret.clear();
  for (int i = 1; i <= m; ++i) ret.a[i][i] = 1;
  while (idx > 0) {
    if (idx & 1) ret = mul(ret, bs);
    idx >>= 1;
    if (idx > 0) bs = mul(bs, bs);  // skip the squaring whose result is never used
  }
  return ret;
}


// Reads the digit string s and m, then prints g(s) modulo kMod: the sum,
// over every way of cutting s into consecutive pieces, of f(sum of pieces),
// where f satisfies f_x = f_{x-1} + ... + f_{x-m} with f_0 = 1.
// Returns 1 without output on malformed input instead of overrunning buffers.
int main() {
#ifdef ORZXKR
  freopen("in.txt", "r", stdin);
  freopen("out.txt", "w", stdout);
#endif
  // s is stored from s[1], so at most 503 characters plus the terminator
  // fit in s[505]. A non-space byte right after the field means the input
  // was longer and got truncated.
  if (scanf("%503s", s + 1) != 1) return 1;
  int next = getchar();
  if (next != EOF && !isspace(next)) return 1;
  // Matrix storage is 6x6 with indices 1..m, so m must lie in [1, 5].
  if (scanf("%d", &m) != 1 || m < 1 || m > 5) return 1;
  n = strlen(s + 1);
  for (int i = 1; i <= n; ++i) {
    if (s[i] < '0' || s[i] > '9') return 1;  // a[i] indexes mi[0..9]
    a[i] = s[i] - '0';
  }
  // b advances the row vector [f_{x-m+1}, ..., f_x] by one step: the
  // sub-diagonal shifts entries left, the last column sums them all.
  for (int i = 1; i < m; ++i) {
    b.a[i + 1][i] = 1;
  }
  for (int i = 1; i <= m; ++i) {
    b.a[i][m] = 1;
  }
  // mi[k] = b^k for every decimal digit k (mi[0] is the identity).
  for (int i = 1; i <= m; ++i) {
    mi[0].a[i][i] = 1;
  }
  for (int i = 1; i <= 9; ++i) {
    mi[i] = mul(mi[i - 1], b);
  }
  // d[i][j] = b^(number written by s[i..j]); appending digit s[j] maps x to
  // 10x + s[j], i.e. b^(10x + s[j]) = (b^x)^10 * b^(s[j]).
  for (int i = 1; i <= n; ++i) {
    matrix nw = mi[0];
    for (int j = i; j <= n; ++j) {
      d[i][j] = nw = mul(qpow(nw), mi[a[j]]);
    }
  }
  // f[i] = sum over all splits of s[1..i] of b^(sum of pieces). Because
  // b^(x+y) = b^x * b^y, the last piece s[j+1..i] contributes f[j] * d[j+1][i].
  f[0] = mi[0];
  for (int i = 1; i <= n; ++i) {
    for (int j = 0; j < i; ++j) {
      f[i] = add(f[i], mul(f[j], d[j + 1][i]));
    }
  }
  // Row 1 of b^x is [f_{x-m}, ..., f_{x-1}] for x >= 1 (row 1 of I for
  // x = 0); either way its entries sum to f_x. By linearity, the row-1 sum
  // of f[n] is the answer.
  int ans = 0;
  for (int i = 1; i <= m; ++i) {
    ans = add(ans, f[n].a[1][i]);
  }
  write(ans);
  return 0;
}
posted @ 2022-10-13 20:09  下蛋爷  阅读(44)  评论(0编辑  收藏  举报