HDU 6143 Killer Names (组合数学+DP)

Description

字母表的长度为\(m\),用表中的字母构造长度为\(2n\)的字符串,要求同一种字母能同时出现在前\(n\)个字符中和后\(n\)个字符中。输出方案数,结果模\(10^9+7\)

Input

第一行给出用例组数\(T\),每组用例给出两个整数\(n\)\(m\)\(1 \leqslant n,m \leqslant 2000\)

Output

对于每组用例,输出方案数。

Sample Input

2
3 2
2 3

Sample Output

2 
18

Solution

字母表中的每中字符有三种情况,在前\(n\)个字符中出现,在后\(n\)个字符中出现和不出现。设有\(p\)中字符在前\(n\)个字符中出现,\(q\)种字符在后\(n\)个字符中出现,则这种分配方案数量为

\[\binom{m}{p,q,m-p-q}=\frac{m!}{p! \cdot q! \cdot (m-p-q)!} \]

现在只需要解决,\(n\)个位置,\(m\)种字符,每种字符至少用一次的方案数。
\(dp[i][j]\)表示前\(i\)个位置,\(j\)种字符每种至少用一种的方案数。
考虑前\(i-1\)个位置:

  1. 如果只用了\(j-1\)种字符,那么第\(i\)个位置就必须用最后一种字符,但最后一种字符具体是那个是不确定的,因此有\(j\)种情况。
  2. 如果j种字符全都出现过,那么第\(i\)个位置放任意一个字符都可以,同样有\(j\)种情况。
    综上,

\[dp[i][j]=(dp[i-1][j-1]+dp[i-1][j]) \cdot j \]

最终答案为

\[ans = \sum_{p,q \leqslant n,p+q \leqslant m}{\binom{m}{p,q,m-p-q} \cdot dp[n][p] \cdot dp[n][q]} \]

简单化简整理得

\[ans = m! \cdot \sum_{p,q \leqslant n,p+q \leqslant m}{dp[n][p] \cdot dp[n][q] \cdot inv(p!) \cdot inv(q!) \cdot inv((m-p-q)!)} \]

其中\(inv(i)\)表示\(i\)的逆元。\(O(n^2)\)的时间得到\(dp\)数组,\(O(n)\)的时间预处理\(0\)\(2000\)的阶乘,再用\(O(n)\)的时间得到\(0\)\(2000\)阶乘的逆元,枚举每个\(p\)\(q\),累加即为答案。

Code

#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
typedef long long ll;
const int INF = 0x3f3f3f3f;
const ll mod = 1e9 + 7;
const int N = 2e3 + 10;

ll power(ll a, ll b, ll mod)
{
	ll ans = 1;
	while (b)
	{
		if (b & 1) ans = ans * a % mod;
		a = a * a % mod;
		b >>= 1;
	}
	return ans;
}

ll dp[N][N], fac[N], ifac[N];

void init(int n)
{
	fac[0] = 1;
	for (int i = 1; i <= n; i++) fac[i] = fac[i - 1] * i % mod;
	ifac[0] = 1;
	for (int i = 1; i <= n; i++) ifac[i] = power(fac[i], mod - 2, mod);
	for (int i = 1; i <= n; i++) dp[i][1] = 1, dp[i][i] = fac[i];
	for (int i = 3; i <= n; i++)
		for (int j = 2; j < i; j++)
			dp[i][j] = (dp[i - 1][j - 1] + dp[i - 1][j]) * j % mod;
}

int main()
{
	init(2000);
	int T;
	scanf("%d", &T);
	while (T--)
	{
		int n, m;
		scanf("%d%d", &n, &m);
		ll ans = 0;
		for (int p = 1; p <= n && p <= m; p++)
			for (int q = 1; q <= n && p + q <= m; q++)
			{
				ll s = dp[n][p] % mod * dp[n][q] % mod;
				s = s * ifac[p] % mod * ifac[q] % mod * ifac[m - p - q] % mod;
				ans = (ans + s) % mod;
			}
		ans = ans * fac[m] % mod;
		printf("%lld\n", ans);
	}
	return 0;
}

http://acm.hdu.edu.cn/showproblem.php?pid=6143

posted @ 2017-08-17 21:46  达达Mr_X  阅读(170)  评论(0编辑  收藏  举报