ABC256G
首先可以想到枚举每条边上的白色石子个数,记为 \(k\)。
则最终答案为 \(\sum\limits_{k=0}^{d+1}f(k)\),\(f(x)\) 表示每条边的石子个数为 \(x\) 时的答案。
那么可以想到一个暴力的 dp 状态,设 \(f_{i,j,k}\) 表示考虑了前 \(i\) 条边,最开始的点的颜色是 \(j\),这条边的终点的颜色是 \(k\) 时的方案数,其中 \(1\) 为白色,\(0\) 为黑色。
则有如下转移:
\[f_{i+1,0,0}=f_{i,0,0}\times \binom{D-1}{K}+f_{i,0,1}\times \binom{D-1}{K-1}
\]
\[f_{i+1,0,1}=f_{i,0,0}\times \binom{D-1}{K-1}+f_{i,0,1}\times \binom{D-1}{K-2}
\]
\[f_{i+1,1,0}=f_{i,1,0}\times \binom{D-1}{K}+f_{i,1,1}\times \binom{D-1}{K-1}
\]
\[f_{i+1,1,1}=f_{i,1,0}\times \binom{D-1}{K-1}+f_{i,1,1}\times \binom{D-1}{K-2}
\]
由于原图是个环,所以最后只有 \(f_{n,0,0}\) 和 \(f_{n,1,1}\) 有贡献。
那么发现转移和 \(i\) 无关,于是写成矩阵形式,不难发现其实就是 \(\begin{bmatrix} A_{0,0} & A_{0,1}\\ A_{1,0} & A_{1,1} \end{bmatrix}^{n}\),其中 \(A_{i,j}=\dbinom{D-1}{K-i-j}\)。
直接矩阵快速幂优化即可。
Code:
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N = 10005, mod = 998244353;
ll n; int D;
int fac[N], inv[N];
struct mat {
int a[2][2];
mat operator * (const mat &x) const {
mat res; memset(res.a, 0, sizeof res.a);
for (int i = 0; i < 2; ++i)
for (int j = 0; j < 2; ++j)
for (int k = 0; k < 2; ++k)
res.a[i][j] = (res.a[i][j] + 1ll * a[i][k] * x.a[k][j] % mod) % mod;
return res;
}
} f;
int qpow(int x, int y) {
int res = 1;
while (y) {
if (y & 1) res = 1ll * res * x % mod;
x = 1ll * x * x % mod;
y >>= 1;
}
return res;
}
void init(int maxn) {
fac[0] = 1;
for (int i = 1; i <= maxn; ++i) fac[i] = 1ll * fac[i - 1] * i % mod;
inv[maxn] = qpow(fac[maxn], mod - 2);
for (int i = maxn - 1; ~i; --i) inv[i] = 1ll * inv[i + 1] * (i + 1) % mod;
}
int C(int n, int m) {
if (n < 0 || m < 0 || n < m) return 0;
return 1ll * fac[n] * inv[m] % mod * inv[n - m] % mod;
}
mat qpow(mat x, ll y) {
mat res; res = f;
while (y) {
if (y & 1) res = res * x;
x = x * x;
y >>= 1;
}
return res;
}
int main() {
init(10000);
scanf("%lld%d", &n, &D);
memset(f.a, 0, sizeof f.a);
f.a[0][0] = f.a[1][1] = 1;
int ans = 0;
for (int k = 0; k <= D + 1; ++k) {
mat tmp;
for (int i = 0; i < 2; ++i)
for (int j = 0; j < 2; ++j)
tmp.a[i][j] = C(D - 1, k - i - j);
tmp = qpow(tmp, n);
ans = (ans + tmp.a[0][0]) % mod;
ans = (ans + tmp.a[1][1]) % mod;
}
printf("%d", ans);
return 0;
}