【YBT2023寒假Day4 A】网格染色(DP)(矩阵乘法)(DFT)

网格染色

题目链接:YBT2023寒假Day4 A

题目大意

有一个 n*3 的网格,你可以选恰好 m 个格子染成黑色。
然后对于一个黑点,我们对它周围的 \(8\) 个点中的一些有限制不能是黑点,用一个矩阵给出。
问你有多少满足条件限制的条件。

思路

首先看到我们可以一行一行的 DP,每一行只跟前面的一行有关。
(预处理出限制,因为一行长度只有 \(3\),可以直接预处理出 \(8\) 个状态后面是否可以放那 \(8\) 个状态)

然后如果要用什么矩阵乘法的转移,会发现一个问题是你只能保留 \(n\) 这个,却无法维护 \(m\)
考虑 \(m\) 那一维是啥,发现是卷积。
于是你考虑把每一个状态它对应要乘的式子先都 DFT 了,然后对于式子的每一位都做一次矩阵乘法。
然后把得到的式子再 IDFT 回去,就是答案的式子了。

不过要注意的是因为你要全部 DFT 的结果乘起来,所以你的 \(limit\) 是不能变的。
那你就用 \(3*n\) 来求 \(limit\) 即可。

代码

#include<cstdio>
#include<cstring>
#include<algorithm>
#define mo 998244353


using namespace std;

const int N = 1e5 + 100;
const int pN = N * 8;
int n, m, a[10][10], b[10][10], cnt[10], p[10][pN], f[10][pN];
int an[pN], G, Gv;

inline int add(int x, int y) {return x + y >= mo ? x + y - mo : x + y;}
inline int dec(int x, int y) {return x < y ? x - y + mo : x - y;}
inline int mul(int x, int y) {return 1ll * x * y % mo;}

int ksm(int x, int y) {
	int re = 1;
	while (y) {
		if (y & 1) re = mul(re, x);
		x = mul(x, x); y >>= 1; 
	}
	return re;
}

struct matrix {
	int n, m, a[8][8];
	
	matrix() {
		n = 0; m = 0;
		for (int i = 0; i < 8; i++)
			for (int j = 0; j < 8; j++)
				a[i][j] = 0;
	}
}T;

matrix operator *(matrix x, matrix y) {
	matrix z; z.n = x.n; z.m = y.m;
	for (int i = 0; i < z.n; i++)
		for (int j = 0; j < z.m; j++)
			for (int k = 0; k < x.m; k++)
				z.a[i][j] = add(z.a[i][j], mul(x.a[i][k], y.a[k][j]));
	return z;
}

matrix jzksm(matrix x, int y) {
	matrix re; re.n = re.m = 8;
	for (int i = 0; i < 8; i++) re.a[i][i] = 1;
	while (y) {
		if (y & 1) re = re * x;
		x = x * x; y >>= 1;
	}
	return re;
}

void get_an(int limit, int l_size) {
	for (int i = 0; i < limit; i++)
		an[i] = (an[i >> 1] >> 1) | ((i & 1) << (l_size - 1));
}

void NTT(int *f, int op, int limit) {
	for (int i = 0; i < limit; i++)
		if (an[i] < i) swap(f[i], f[an[i]]);
	for (int mid = 1; mid < limit; mid <<= 1) {
		int Wn = ksm((op == 1) ? G : Gv, (mo - 1) / (mid << 1));
		for (int R = (mid << 1), j = 0; j < limit; j += R) {
			for (int w = 1, k = 0; k < mid; k++, w = mul(w, Wn)) {
				int x = f[j | k], y = mul(w, f[j | mid | k]);
				f[j | k] = add(x, y); f[j | mid | k] = dec(x, y);
			}
		}
	}
	if (op == -1) {
		int limv = ksm(limit, mo - 2);
		for (int i = 0; i < limit; i++) f[i] = mul(f[i], limv);
	}
}

int main() {
	freopen("final.in", "r", stdin);
	freopen("final.out", "w", stdout);
	
	scanf("%d %d", &n, &m); G = 3; Gv = ksm(G, mo - 2);
	for (int i = 1; i <= 3; i++)
		for (int j = 1; j <= 3; j++)
			scanf("%d", &a[j][i]);//麻了 n*m 和长乘宽又反了/lh 
	for (int i = 1; i <= 3; i++)
		for (int j = 1; j <= 3; j++)
			a[i][j] |= a[4 - i][4 - j];
	
	for (int i = 0; i < 8; i++) {
		int a1 = (i & 1), a2 = ((i >> 1) & 1), a3 = ((i >> 2) & 1);
		if (a[1][2] && ((a1 & a2) || (a2 & a3))) continue;
		for (int j = 0; j < 8; j++) {
			int b1 = (j & 1), b2 = ((j >> 1) & 1), b3 = ((j >> 2) & 1);
			if (a[1][2] && ((b1 & b2) || (b2 & b3))) continue;
			bool yes = 1;
			if (a1 && ((a[2][3] && b1) || (a[3][3] && b2))) yes = 0;
			if (a2 && ((a[1][3] && b1) || (a[2][3] && b2) || (a[3][3] && b3))) yes = 0;
			if (a3 && ((a[1][3] && b2) || (a[2][3] && b3))) yes = 0;
			if (yes) b[i][j] = 1;
		}
		cnt[i] = a1 + a2 + a3;
	}
	for (int i = 0; i < 8; i++)
		if (!b[0][i]) for (int j = 0; j < 8; j++) b[i][j] = 0;
	
	int limit = 1, l_size = 0; while (limit <= 3 * n) limit <<= 1, l_size++;
	get_an(limit, l_size); 
	for (int i = 0; i < 8; i++)
		if (b[0][i]) {
			p[i][cnt[i]] = 1;
			NTT(p[i], 1, limit);
		}
	
	T.n = T.m = 8;
	for (int t = 0; t < limit; t++) {
		memset(T.a, 0, sizeof(T.a));
		for (int i = 0; i < 8; i++)
			for (int j = 0; j < 8; j++)
				if (b[i][j]) T.a[i][j] = p[j][t];
		T = jzksm(T, n);
		for (int i = 0; i < 8; i++) f[i][t] = T.a[0][i];
	}
	for (int i = 0; i < 8; i++) NTT(f[i], -1, limit);
	int ans = 0; for (int i = 0; i < 8; i++) ans = add(ans, f[i][m]);
	printf("%d", ans);
	
	return 0;
}
posted @ 2023-02-02 08:38  あおいSakura  阅读(37)  评论(0编辑  收藏  举报