LOJ #6752. 「THUPC 2021 初赛」合法序列 题解
\(\mathcal{Solution}\)
考虑到 \(k\) 最大只有 \(4\),所以每一个子段限制的位置肯定是 \(\leq 2^k\) 的,也就是肯定只与前 \(2^k\) 个数相关,考虑爆搜这些数,只有 \(2^{2^k}\) 种可能,最大也就 \(2^{16}\),考虑怎样快速统计后面的方案数。
因为很多人过J所以J一定不会太难所以这个计数一定是dp
考虑 \(dp\),设 \(f_{i,j}\) 代表 \([i-k+1,i]\) 这段区间(也就是以第 \(i\) 个为结尾的子段)选的状态为 \(j\) 的方案数,这样就很好 \(dp\) 了,考虑 \(f_{i,j}\) 从 \(f_{i-1,j'}\) 转移,一定时满足 \(j'\) 的后 \(k-1\) 位和 \(j\) 的 前 \(k-1\) 位相等(也就是考虑把一个长度为 \(k\) 的区间向后推一个,这两个区间的重合部分),\(j'\) 的最前面一位此时填 \(0,1\) 都行,那么就可以直接转移,而满足能进行转移的 \(j\) 一定是要满足爆搜出来前 \(2^k\) 为第 \(j\) 位为 \(1\),那么提前存出来这些合法的 \(j\),转移直接遍历一遍即可,时间复杂度降的特别低。
大概是一个 \(\sum_{i=1}^{2^{2^k}}count\_bit(i) \times n\)。其中 \(count\_bit(x)\) 为 \((x)_2\) 中数位上为 \(1\) 的个数,写个程序简单算一下发现才仅仅 \(1e8\) 左右,足够通过此题。
上面那个复杂度可以写成 \(\mathcal{O}(2^{2^k}\times 2^k \times n)\)
\(\mathcal{Code}\)
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<vector>
#define ll long long
#define re register
#define int long long
#define pb push_back
const int mod = 998244353;
inline int Add(int x, int y) { return (x + y - mod) < 0 ? (x + y) : (x + y - mod); }
inline int Max(int x, int y) { return x > y ? x : y; }
inline int Min(int x, int y) { return x < y ? x : y; }
inline int Abs(int x) { return x < 0 ? ~x + 1 : x; }
inline int read() {
int r = 0; bool w = 0; char ch = getchar();
while(ch < '0' || ch > '9') {
if(ch == '-') w = 1;
ch = getchar();
}
while(ch >= '0' && ch <= '9') {
r = (r << 3) + (r << 1) + (ch ^ 48);
ch = getchar();
}
return w ? ~r + 1 : r;
}
#undef int
const int N = 510;
int n, k, m, vecnt, a[N];
ll f[N][N];
std::vector<int>vec;
bool check(int x) {
int now = 0;
for(int i = x; i >= x - k + 1; --i)
now |= a[i] * (1 << (x - i));
return a[now] == 0;
}
ll solve(int x) {
vec.clear(); vecnt = 0;
for(int i = 0; i < m; ++i) a[i] = 0;
for(int i = 0; i < m; ++i)
if((1 << i) & x) {
vec.pb(m - i - 1);
a[m - i - 1] = 1;
++vecnt;
}
for(int i = 0; i < n; ++i)
for(int j = 0; j < vecnt; ++j)
f[i][j] = 0;
for(int i = k-1; i < m; ++i)
if(check(i))
return 0ll;
f[m-1][x & ((1 << k) - 1)] = 1;
for(int i = m; i < n; ++i)
for(int j = 0; j < vecnt; ++j)
f[i][vec[j]] = Add(f[i][vec[j]], f[i-1][(vec[j] >> 1) | (1 << (k-1))] + f[i-1][vec[j] >> 1]);
ll ans = 0;
for(int i = 0; i < vecnt; ++i)
ans = Add(ans, f[n-1][vec[i]]);
for(int i = 0; i < n; ++i)
for(int j = 0; j < vecnt; ++j)
f[i][vec[j]] = 0;
return ans;
}
int main() {
ll ans = 0;
n = read(); k = read(); m = (1 << k);
for(int i = 0; i < 1 << m; ++i)
ans = Add(ans, solve(i));
printf("%lld\n", ans);
return 0;
}