花园

题意

给一个长度为\(n\)的环,要求在这个环上填上\(0\)\(1\),使得这个环满足对于任意长度为\(m\)的区间,其中\(0\)的个数不超过\(k\)。请求出所有合法的填数的方案数

将环上的结点标号为\(1\)\(n\),两种方案不同当且仅当至少存在一个节点,两种方案在此处所填的数不同

\(n\leq 10^{15},k\leq m\leq 5,mod = 10^9+7\)


解法

见数据范围识矩乘优化DP

这里引用一下miracle大佬对矩乘优化DP题目的一些特点:

  • 一定存在一个线性递推式
  • 总有一个保持不变的转移矩阵
  • 由于矩乘的复杂度是\(O(n^3)\)的,所以转移矩阵的边长不能太大
  • 矩阵只需要保留可以继续转移的项

首先,我们观察到\(m\leq 5\),显然是可以进行状压的

我们能够得到以下的转移方程:\(f[i][j]=f[i-1][k]\times a[j][k]\),其中\(a[j][k]\)意味着状态\(j\)可以转移到状态\(k\)

\(a\)数组是很好处理的,这里不再赘述

这个转移方程与floyd很类似,可以用矩乘优化,转移矩阵即是我们处理出的\(a\)矩阵

至于环的情况如何处理?

有一个经典套路:枚举第一个状态\(s\)(复杂度\(O(2^m)\))后转移\(n\)次,将\(f[n+m][s]\)作为答案加入,这样就能保证首尾的状态均为\(s\),也就连接成了一个环

我们构建一个\((2^m\times 2^m)\)的初始矩阵,初始状态分别位于\(0\to 2^m\)之间,值为\(1\)

可以发现这就是一个单位矩阵

我们直接把转移矩阵自乘\(n\)次,统计对角线上元素的和作为答案即可


代码

#include <cstdio>
#include <cstring>

using namespace std;

template<typename _T> void read(_T& x) {
	int c = getchar(); x = 0;	
	while (c < '0' || c > '9')   c = getchar();
	while (c >= '0' && c <= '9') x = x * 10 + c - 48, c = getchar();
}

const int mod = 1e9 + 7;

long long n, m, k, sz;

void add(int &x, int y) { (x += y) > mod ? x -= mod : x; }
int mul(int x, int y) { return 1LL * x * y % mod; }

struct matrix {
	
	int a[50][50];	
	
	matrix() { memset(a, 0, sizeof a); }
	matrix operator = (const matrix& rhs) {
		for (int i = 0; i < sz; ++i)
			for (int j = 0; j < sz; ++j)  a[i][j] = rhs.a[i][j];	
		return *this;
	}
	friend matrix operator * (const matrix &lhs, const matrix &rhs) {
		matrix res;
		for (int i = 0; i < sz; ++i)
			for (int j = 0; j < sz; ++j)	
				for (int k = 0; k < sz; ++k)
					add(res.a[i][j], mul(lhs.a[i][k], rhs.a[k][j]));
		return res;
	}	
	matrix operator ^ (long long k) const {
		matrix res, t = *this;
		for (int i = 0; i < sz; ++i)  res.a[i][i] = 1;
		for (; k; t = t * t, k >>= 1)
			if (k & 1)  res = res * t;	
		return res;
	}
} mt;

int calc(int x) {
	int res = 0;
	while (x)  ++res, x -= x & -x;
	return res;
}

int main() {
	
	read(n), read(m), read(k);
	
	sz = 1 << m;
	
	for (int i = 0; i < sz; ++i) {
		if (calc(i) > k)  continue;
		mt.a[(i >> 1)][i] = 1;
		mt.a[(i >> 1) | (1 << m - 1)][i] = 1;
	}
	
	mt = mt ^ n;
	
	long long ans = 0;
	for (int i = 0; i < sz; ++i) 
		ans = (ans + mt.a[i][i]) % mod;
	
	printf("%lld\n", ans);
	
	return 0;
}
posted @ 2019-09-27 21:48  四季夏目天下第一  阅读(108)  评论(0编辑  收藏  举报