【题解】P3158 [CQOI2011]放棋子
兄弟们,我起了,一日之计在于晨呐。
题意
有一个 \(n\) 行 \(m\) 列的棋盘和 \(c\) 种颜色的棋子,每种棋子有 \(a_i\) 个。
要求不同颜色的棋子不能放在同一行或同一列,问放棋子的方案总数。
对 \(10^9 + 9\) 取模。
\(1 \leq n, m \leq 30, 1 \leq c \leq 10, \sum\limits a_i \leq \max(250, nm)\)
思路
容斥 dp.
可以观察到:对于任意合法方案,交换其中任意两行(或列)依然合法。
所以我们只关心被占用的行数和列数,不需要在意具体选中的行和列。
首先可以想到一个普通的 dp 做法:
设 \(f[i][j][k]\) 表示用前 \(k\) 种颜色占用 \(i\) 行 \(j\) 列的方案总数(注意不一定是 \(i \times j\) 的矩阵)
转移直接枚举最后一种颜色的贡献。
令 \(S(i, j, k)\) 表示用 \(k\) 个同颜色的棋子占用 \(i\) 行 \(j\) 列的方案数,那么转移可以写成:
\(f[i][j][k] = \sum\limits_{p = 1}^i \sum\limits_{q = 1}^j \dbinom{i}{p} \dbinom{j}{q} f[i - p][j - q][k - 1] \cdot S(p, q, a_k)\)
接下来考虑求出 \(S\).
直接做不好做,考虑容斥。
由于行和列实际上是等价的,所以容斥可以这样做:
占满 \(i\) 行 \(j\) 列 = 全部方案 - 至少不占 \(1\) 行/列 + 至少不占 \(2\) 行列 - ...
那么 \(S\) 的转移是:
\(S(i, j, k) = \sum\limits_{p = 0}^i \sum\limits_{q = 0}^j (-1)^{p + q} \dbinom{i}{p} \dbinom{j}{q} \dbinom{(i - p)(j - q)}{k}\)
最后一项的意义是在剩余的 \(i - p\) 行和 \(j - q\) 列中放置 \(k\) 个棋子的方案数,可以直接相乘的原因是平移后可以看成一个 \((i - p) \times (j - q)\) 的矩阵。
于是最终的答案是:
\(\sum\limits_{i = 1}^n \sum\limits_{j = 1}^m \dbinom{n}{i} \dbinom{m}{j} f[i][j][c]\)
时间复杂度 \(O(n^3 m^3 + n^2 m^2 c)\)
代码
#include <cstdio>
using namespace std;
#define int long long
const int maxn = 35;
const int maxm = 35;
const int maxc = 15;
const int sz = 905;
const int mod = 1e9 + 9;
inline int min(const int &a, const int &b) { return (a <= b ? a : b); }
inline int max(const int &a, const int &b) { return (a >= b ? a : b); }
int n, m, c;
int a[maxc];
int fac[sz], inv[sz];
int f[maxn][maxm][maxc], s[maxn][maxm][sz];
int C(int n, int m)
{
if ((n < 0) || (m < 0) || (n < m)) return 0;
return fac[n] * inv[m] % mod * inv[n - m] % mod;
}
void init(int lim)
{
fac[0] = inv[0] = inv[1] = 1;
for (int i = 1; i <= lim; i++) fac[i] = fac[i - 1] * i % mod;
for (int i = 2; i <= lim; i++) inv[i] = inv[mod % i] * (mod - mod / i) % mod;
for (int i = 2; i <= lim; i++) inv[i] = inv[i] * inv[i - 1] % mod;
}
int S(int i, int j, int k)
{
if ((k < max(i, j)) || (k > i * j)) return 0;
if (s[i][j][k]) return s[i][j][k];
for (int p = 0; p <= i; p++)
for (int q = 0; q <= j; q++)
{
int res = C(i, p) * C(j, q) % mod * C((i - p) * (j - q), k) % mod;
if ((p + q) & 1) s[i][j][k] = (s[i][j][k] - res + mod) % mod;
else s[i][j][k] = (s[i][j][k] + res) % mod;
}
return s[i][j][k];
}
signed main()
{
scanf("%lld%lld%lld", &n, &m, &c);
init(n * m);
for (int i = 1; i <= c; i++) scanf("%lld", &a[i]);
f[0][0][0] = 1;
for (int k = 1; k <= c; k++)
for (int i = 1; i <= n; i++)
for (int j = 1; j <= m; j++)
for (int p = 1; p <= i; p++)
for (int q = 1; q <= j; q++)
f[i][j][k] = (f[i][j][k] + C(i, p) * C(j, q) % mod * f[i - p][j - q][k - 1] % mod * S(p, q, a[k]) % mod) % mod;
int ans = 0;
for (int i = 1; i <= n; i++)
for (int j = 1; j <= m; j++)
ans = (ans + C(n, i) * C(m, j) % mod * f[i][j][c] % mod) % mod;
ans = (ans + mod) % mod;
printf("%lld\n", ans);
return 0;
}