ABC256G

首先可以想到枚举每条边上的白色石子个数,记为 \(k\)

则最终答案为 \(\sum\limits_{k=0}^{d+1}f(k)\)\(f(x)\) 表示每条边的石子个数为 \(x\) 时的答案。

那么可以想到一个暴力的 dp 状态,设 \(f_{i,j,k}\) 表示考虑了前 \(i\) 条边,最开始的点的颜色是 \(j\),这条边的终点的颜色是 \(k\) 时的方案数,其中 \(1\) 为白色,\(0\) 为黑色。
则有如下转移:

\[f_{i+1,0,0}=f_{i,0,0}\times \binom{D-1}{K}+f_{i,0,1}\times \binom{D-1}{K-1} \]

\[f_{i+1,0,1}=f_{i,0,0}\times \binom{D-1}{K-1}+f_{i,0,1}\times \binom{D-1}{K-2} \]

\[f_{i+1,1,0}=f_{i,1,0}\times \binom{D-1}{K}+f_{i,1,1}\times \binom{D-1}{K-1} \]

\[f_{i+1,1,1}=f_{i,1,0}\times \binom{D-1}{K-1}+f_{i,1,1}\times \binom{D-1}{K-2} \]

由于原图是个环,所以最后只有 \(f_{n,0,0}\)\(f_{n,1,1}\) 有贡献。

那么发现转移和 \(i\) 无关,于是写成矩阵形式,不难发现其实就是 \(\begin{bmatrix} A_{0,0} & A_{0,1}\\ A_{1,0} & A_{1,1} \end{bmatrix}^{n}\),其中 \(A_{i,j}=\dbinom{D-1}{K-i-j}\)

直接矩阵快速幂优化即可。

Code:

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N = 10005, mod = 998244353;
ll n; int D;
int fac[N], inv[N];
struct mat {
	int a[2][2];
	
	mat operator * (const mat &x) const {
		mat res; memset(res.a, 0, sizeof res.a);
		for (int i = 0; i < 2; ++i)
			for (int j = 0; j < 2; ++j)	
				for (int k = 0; k < 2; ++k)
					res.a[i][j] = (res.a[i][j] + 1ll * a[i][k] * x.a[k][j] % mod) % mod;
		return res;
	}
} f;

int qpow(int x, int y) {
	int res = 1;
	while (y) {
		if (y & 1) res = 1ll * res * x % mod;
		x = 1ll * x * x % mod;
		y >>= 1;
	}
	return res;
}

void init(int maxn) {
	fac[0] = 1;
	for (int i = 1; i <= maxn; ++i) fac[i] = 1ll * fac[i - 1] * i % mod;
	inv[maxn] = qpow(fac[maxn], mod - 2);
	for (int i = maxn - 1; ~i; --i) inv[i] = 1ll * inv[i + 1] * (i + 1) % mod;
}

int C(int n, int m) {
	if (n < 0 || m < 0 || n < m) return 0;
	return 1ll * fac[n] * inv[m] % mod * inv[n - m] % mod;
}

mat qpow(mat x, ll y) {
	mat res; res = f;
	while (y) {
		if (y & 1) res = res * x;
		x = x * x;
		y >>= 1;
	}
	return res;
}

int main() {
	init(10000);
	scanf("%lld%d", &n, &D);
	memset(f.a, 0, sizeof f.a);
	f.a[0][0] = f.a[1][1] = 1;
	int ans = 0;
	for (int k = 0; k <= D + 1; ++k) {
		mat tmp;
		for (int i = 0; i < 2; ++i)
			for (int j = 0; j < 2; ++j)
				tmp.a[i][j] = C(D - 1, k - i - j);
		tmp = qpow(tmp, n);
		ans = (ans + tmp.a[0][0]) % mod;
		ans = (ans + tmp.a[1][1]) % mod;
	}
	printf("%d", ans);
	return 0;
}
posted @ 2022-11-08 20:55  Kobe303  阅读(23)  评论(0编辑  收藏  举报