花园
题意
给一个长度为\(n\)的环,要求在这个环上填上\(0\)或\(1\),使得这个环满足对于任意长度为\(m\)的区间,其中\(0\)的个数不超过\(k\)。请求出所有合法的填数的方案数
将环上的结点标号为\(1\)到\(n\),两种方案不同当且仅当至少存在一个节点,两种方案在此处所填的数不同
\(n\leq 10^{15},k\leq m\leq 5,mod = 10^9+7\)
解法
见数据范围识矩乘优化DP
这里引用一下miracle大佬对矩乘优化DP题目的一些特点:
- 一定存在一个线性递推式
- 总有一个保持不变的转移矩阵
- 由于矩乘的复杂度是\(O(n^3)\)的,所以转移矩阵的边长不能太大
- 矩阵只需要保留可以继续转移的项
首先,我们观察到\(m\leq 5\),显然是可以进行状压的
我们能够得到以下的转移方程:\(f[i][j]=f[i-1][k]\times a[j][k]\),其中\(a[j][k]\)意味着状态\(j\)可以转移到状态\(k\)
\(a\)数组是很好处理的,这里不再赘述
这个转移方程与floyd很类似,可以用矩乘优化,转移矩阵即是我们处理出的\(a\)矩阵
至于环的情况如何处理?
有一个经典套路:枚举第一个状态\(s\)(复杂度\(O(2^m)\))后转移\(n\)次,将\(f[n+m][s]\)作为答案加入,这样就能保证首尾的状态均为\(s\),也就连接成了一个环
我们构建一个\((2^m\times 2^m)\)的初始矩阵,初始状态分别位于\(0\to 2^m\)之间,值为\(1\)
可以发现这就是一个单位矩阵
我们直接把转移矩阵自乘\(n\)次,统计对角线上元素的和作为答案即可
代码
#include <cstdio>
#include <cstring>
using namespace std;
template<typename _T> void read(_T& x) {
int c = getchar(); x = 0;
while (c < '0' || c > '9') c = getchar();
while (c >= '0' && c <= '9') x = x * 10 + c - 48, c = getchar();
}
const int mod = 1e9 + 7;
long long n, m, k, sz;
void add(int &x, int y) { (x += y) > mod ? x -= mod : x; }
int mul(int x, int y) { return 1LL * x * y % mod; }
struct matrix {
int a[50][50];
matrix() { memset(a, 0, sizeof a); }
matrix operator = (const matrix& rhs) {
for (int i = 0; i < sz; ++i)
for (int j = 0; j < sz; ++j) a[i][j] = rhs.a[i][j];
return *this;
}
friend matrix operator * (const matrix &lhs, const matrix &rhs) {
matrix res;
for (int i = 0; i < sz; ++i)
for (int j = 0; j < sz; ++j)
for (int k = 0; k < sz; ++k)
add(res.a[i][j], mul(lhs.a[i][k], rhs.a[k][j]));
return res;
}
matrix operator ^ (long long k) const {
matrix res, t = *this;
for (int i = 0; i < sz; ++i) res.a[i][i] = 1;
for (; k; t = t * t, k >>= 1)
if (k & 1) res = res * t;
return res;
}
} mt;
int calc(int x) {
int res = 0;
while (x) ++res, x -= x & -x;
return res;
}
int main() {
read(n), read(m), read(k);
sz = 1 << m;
for (int i = 0; i < sz; ++i) {
if (calc(i) > k) continue;
mt.a[(i >> 1)][i] = 1;
mt.a[(i >> 1) | (1 << m - 1)][i] = 1;
}
mt = mt ^ n;
long long ans = 0;
for (int i = 0; i < sz; ++i)
ans = (ans + mt.a[i][i]) % mod;
printf("%lld\n", ans);
return 0;
}