HDU 6270 Marriage (2017 CCPC 杭州赛区 G题,生成函数 + 容斥 + 分治NTT)

题目链接  2017 CCPC Hangzhou Problem G

题意描述很清晰。

考虑每个家庭有且仅有$k$对近亲的方案数:

$C(a, k) * C(b, k) * k!$

那么如果在第$1$个家庭里面选出$k_{1}$对近亲,在第$2$个家庭里面选出$k_{2}$对近亲......在第$n$个家庭里面选出$k_{n}$对近亲,

剩下那些人自由组合的话,那么最后这种方案至少会有$∑k$对近亲。

说是至少,因为同一个家庭里面没被强行选择的男女还是可能被组到了一起。

那么考虑如何求至少有$k$对近亲的方案数。

构造$n$个多项式,对于每个家庭,这个多项式为

$c_{0} + c_{1}x + c_{2}x^{2} + c_{3}x^{3} + c_{4}x^{4} + ... + c_{p}x^{p}$, $p = min(a, b)$

其中$c_{i}$这个系数为在这个家庭里面选出$i$对近亲的方案数。

那么只要把这$n$个多项式乘起来,得到的结果里面$x^{k}$的系数就是至少有$k$对近亲的方案数。

把$n$个多项式求出来用分治NTT即可,我用了启发式合并。

因为是至少,所以还要考虑容斥。

最后的答案就是$(-1)^{k}a_{k} * (m - k)!$,$m$为总人数

时间复杂度$O(nlog^{2}n)$

 

#include <bits/stdc++.h>

using namespace std;

#define rep(i, a, b)    for (int i(a); i <= (b); ++i)
#define dec(i, a, b)    for (int i(a); i >= (b); --i)
#define MP		make_pair
#define fi		first
#define se		second


typedef long long LL;

const int  N = 1e5 + 10;
const LL mod = 998244353;
const LL g   = 3;

vector <LL> v[N << 1];
LL x1[N << 1], x2[N << 1];
LL fac[N];
LL ans, flag;
int T;
int n, all, cnt;
int sz;

inline LL Pow(LL a, LL b, LL mod){
	LL ret(1);
	for (; b; b >>= 1, (a *= a) %= mod) if (b & 1) (ret *= a) %= mod;
	return ret;
}

inline LL C(LL n, LL m){ return m > n ? 0 : fac[n] * Pow(fac[m] * fac[n - m] % mod, mod - 2, mod) % mod; }

struct cmp{
	bool operator ()(int a, int b){
		return v[a].size() > v[b].size();
	}
};

priority_queue <LL, vector <LL>, cmp> q;

void change(LL y[], int len){
	for (int i = 1, j = len / 2; i < len - 1; i++){
		if (i < j) swap(y[i], y[j]);
		int k = len / 2;
		while (j >= k){
			j -= k;
			k /= 2;
		}
		if (j < k) j += k;
	}
}

void ntt(LL y[], int len, int on){
	change(y, len);
	for (int h = 2; h <= len; h <<= 1){
		LL wn = Pow(g, (mod - 1) / h, mod);
		if (on == -1) wn = Pow(wn, mod - 2, mod);
		for (int j = 0; j < len; j += h){
			LL w = 1ll;
			for (int k = j; k < j + h / 2; k++){
				LL u = y[k];
				LL t = w * y[k + h / 2] % mod;
				y[k] = (u + t) % mod;
				y[k + h / 2] = (u - t + mod) % mod;
				w = w * wn % mod;
			}
		}
	}

	if (on == -1){
		LL t = Pow(len, mod - 2, mod);
		rep(i, 0, len - 1) y[i] = y[i] * t % mod;
	}
}

void mul(vector <LL> &a, vector <LL> &b, vector <LL> &c){
	int len = 1;
	int sz1 = a.size(), sz2 = b.size();

	while (len <= sz1 + sz2 - 1) len <<= 1;

	rep(i, 0, sz1 - 1) x1[i] = a[i];
	rep(i, sz1, len)   x1[i] = 0;

	rep(i, 0, sz2 - 1) x2[i] = b[i];
	rep(i, sz2, len)   x2[i] = 0;

	ntt(x1, len, 1);
	ntt(x2, len, 1);

	rep(i, 0, len - 1) x1[i] = x1[i] * x2[i];

	ntt(x1, len, -1);

	vector <LL>().swap(c);
	rep(i, 0, sz1 + sz2 - 2) c.push_back(x1[i]);
}


int main(){

	fac[0] = 1;
	rep(i, 1, 1e5 + 3) fac[i] = fac[i - 1] * i % mod;

	scanf("%d", &T);

	while (T--){
		scanf("%d", &n);
		rep(i, 0, n + 1) vector <LL>().swap(v[i]);

		while (!q.empty()) q.pop();

		all = 0;

		rep(i, 1, n){
			int x, y;
			scanf("%d%d", &x, &y);
			v[i].resize(min(x, y) + 1);
			rep(k, 0, min(x, y)) v[i][k] = C(x, k) * C(y, k) % mod * fac[k] % mod;
			q.push(i);
			all += x;
		}

		cnt = n;
		rep(i, 1, n - 1){
			int x = q.top(); q.pop();
			int y = q.top(); q.pop();

			mul(v[x], v[y], v[++cnt]);

			vector <LL>().swap(v[x]);
			vector <LL>().swap(v[y]);

			q.push(cnt);
		}

		ans  = 0;
		flag = 1;
		sz   = (int)v[cnt].size();
		rep(i, 0, sz - 1){
			ans = ans + flag * fac[all - i] % mod * v[cnt][i] % mod;
			ans = (ans + mod) % mod;
			flag = -flag;
		}

		printf("%lld\n", ans);
	}

	return 0;
}

 

posted @ 2018-06-20 10:43  cxhscst2  阅读(1067)  评论(0编辑  收藏  举报