[AHOI2009]中国象棋
简化题意
在一个 \(n \times m\) 的棋盘上最多可以放置多少个炮,试任意的两个炮不能相互打到。
思路
感觉这也不像是一个状压DP,但是标签上就是这么写的..
可以发现,让任意的两个炮不能相互打到就是,任意的一行一列都不能有两个炮。
让 \(f_{i,j,k}\) 表示前 i 行中 j 列有一个炮,k 列有两个炮的最大方案数。
可以发现,可以由一下状态转移而来。
- \(f_{i-1,j, k}\),可以一个都不放,那就是上一个的方案数。
- \(f_{i-1, j-1,k}\),从一个炮都没有放的列转移而来,因为有 \(m - (j - 1) - k\) 个没有放任何炮的列,所以最后要乘上这个系数。
- \(f_{i-1,j+1,k-1}\),这个表示从 j+1 个放了一个炮的,在这些列中任意放上一个,让放一个的变成放两个的,因为那些列可以随便取,所以要乘以 \(j+1\)。
- \(f_{i-1,j - 2, k}\), 当前这一行放两个,都放到没有放过的地方,这个时候没有放过的地方有 \(m - (j - 2) - k\) 个,从这些中任取两个,就是 \({m-(j - 2) - k}\choose 2\),乘上这个系数。
- \(f_{i - 1, j, k-1}\),当前这一行放两个,放到同一列中,让那一列变成两个炮,因为当前没放过的有 \(m - j -(k -1)\)。
- \(f_{i-1,j+2, k - 2}\),从原本有一个炮的列中任意选两个,各放上一个炮,\({j+2} \choose 2\),乘上这个系数。
code
#include <cmath>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
#define ll long long
#define N 100010
#define M 110
using namespace std;
const int mod = 9999973;
const int inf = 0x3f3f3f3f;
int n, m;
ll dp[M][M][M], jc[M];
int read() {
int s = 0, f = 0; char ch = getchar();
while (!isdigit(ch)) f |= (ch == '-'), ch = getchar();
while (isdigit(ch)) s = s * 10 + (ch ^ 48), ch = getchar();
return f ? -s : s;
}
ll q_pow(ll a, ll b) {
ll ans = 1;
while (b) {
if (b & 1) ans = (1ll * ans * a) % mod;
a = (1ll * a * a) % mod;
b >>= 1;
}
return ans;
}
ll C(int x) {
return x * (x - 1) % mod * q_pow(2, mod - 2) % mod;
}
int main() {
n = read(), m = read();
dp[0][0][0] = 1, jc[1] = 1;
for (int i = 2; i <= 100; i++) jc[i] = (1ll * jc[i - 1] * i) % mod;
for (int i = 1; i <= n; i++)
for (int j = 0; j <= m; j++)
for (int k = 0; k <= m - j; k++) {
dp[i][j][k] = dp[i - 1][j][k];
if (j >= 1) dp[i][j][k] = (dp[i][j][k] + dp[i - 1][j - 1][k] * (m - j - k + 1) % mod) % mod;
if (k >= 1) dp[i][j][k] = (dp[i][j][k] + dp[i - 1][j + 1][k - 1] * (j + 1) % mod) % mod;
if (k >= 1) dp[i][j][k] = (dp[i][j][k] + dp[i - 1][j][k - 1] * j * (m - j - k + 1) % mod) % mod;
if (j >= 2) dp[i][j][k] = (dp[i][j][k] + dp[i - 1][j - 2][k] * C(m - j - k + 2) % mod) % mod;
if (k >= 2) dp[i][j][k] = (dp[i][j][k] + dp[i - 1][j + 2][k - 2] * (j + 2) % mod * (j + 1) % mod * q_pow(2, mod - 2) % mod) % mod;
}
ll ans = 0;
for (int i = 0; i <= m; i++) for (int j = 0; j <= m; j++) ans = (ans + dp[n][i][j]) % mod;
cout << ans;
}