[AHOI2009]中国象棋

洛谷

简化题意

在一个 \(n \times m\) 的棋盘上最多可以放置多少个炮,试任意的两个炮不能相互打到。


思路

感觉这也不像是一个状压DP,但是标签上就是这么写的..

可以发现,让任意的两个炮不能相互打到就是,任意的一行一列都不能有两个炮。

\(f_{i,j,k}\) 表示前 i 行中 j 列有一个炮,k 列有两个炮的最大方案数。

可以发现,可以由一下状态转移而来。

  • \(f_{i-1,j, k}\),可以一个都不放,那就是上一个的方案数。
  • \(f_{i-1, j-1,k}\),从一个炮都没有放的列转移而来,因为有 \(m - (j - 1) - k\) 个没有放任何炮的列,所以最后要乘上这个系数。
  • \(f_{i-1,j+1,k-1}\),这个表示从 j+1 个放了一个炮的,在这些列中任意放上一个,让放一个的变成放两个的,因为那些列可以随便取,所以要乘以 \(j+1\)
  • \(f_{i-1,j - 2, k}\), 当前这一行放两个,都放到没有放过的地方,这个时候没有放过的地方有 \(m - (j - 2) - k\) 个,从这些中任取两个,就是 \({m-(j - 2) - k}\choose 2\),乘上这个系数。
  • \(f_{i - 1, j, k-1}\),当前这一行放两个,放到同一列中,让那一列变成两个炮,因为当前没放过的有 \(m - j -(k -1)\)
  • \(f_{i-1,j+2, k - 2}\),从原本有一个炮的列中任意选两个,各放上一个炮,\({j+2} \choose 2\),乘上这个系数。

code

#include <cmath> 
#include <cstdio> 
#include <cstring> 
#include <iostream> 
#include <algorithm> 
#define ll long long 
#define N 100010 
#define M 110

using namespace std;
const int mod = 9999973;
const int inf = 0x3f3f3f3f;
int n, m;
ll dp[M][M][M], jc[M];

int read() {
	int s = 0, f = 0; char ch = getchar();
	while (!isdigit(ch)) f |= (ch == '-'), ch = getchar();
	while (isdigit(ch)) s = s * 10 + (ch ^ 48), ch = getchar();
	return f ? -s : s;
}

ll q_pow(ll a, ll b) {
	ll ans = 1;
	while (b) {
		if (b & 1) ans = (1ll * ans * a) % mod;
		a = (1ll * a * a) % mod;
		b >>= 1;
	}
	return ans;
}

ll C(int x) {
	return x * (x - 1) % mod * q_pow(2, mod - 2) % mod;
}

int main() {
	n = read(), m = read();
	dp[0][0][0] = 1, jc[1] = 1;
	for (int i = 2; i <= 100; i++) jc[i] = (1ll * jc[i - 1] * i) % mod;
	for (int i = 1; i <= n; i++) 
    for (int j = 0; j <= m; j++) 
      for (int k = 0; k <= m - j; k++) {
				dp[i][j][k] = dp[i - 1][j][k];
				if (j >= 1) dp[i][j][k] = (dp[i][j][k] + dp[i - 1][j - 1][k] * (m - j - k + 1) % mod) % mod;
				if (k >= 1) dp[i][j][k] = (dp[i][j][k] + dp[i - 1][j + 1][k - 1] * (j + 1) % mod) % mod;
				if (k >= 1) dp[i][j][k] = (dp[i][j][k] + dp[i - 1][j][k - 1] * j * (m - j - k + 1) % mod) % mod;
				if (j >= 2) dp[i][j][k] = (dp[i][j][k] + dp[i - 1][j - 2][k] * C(m - j - k + 2) % mod) % mod;
				if (k >= 2) dp[i][j][k] = (dp[i][j][k] + dp[i - 1][j + 2][k - 2] * (j + 2) % mod * (j + 1) % mod * q_pow(2, mod - 2) % mod) % mod;
			}
	ll ans = 0;
	for (int i = 0; i <= m; i++) for (int j = 0; j <= m; j++) ans = (ans + dp[n][i][j]) % mod;
	cout << ans;
}
posted @ 2020-11-12 08:22  Kersen  阅读(93)  评论(0)    收藏  举报