2017 CCPC 杭州 HDU 6270 Marriage (NTT,容斥)
题目:传送门
题意
有 n 个家庭,每个家庭有 ai 个男孩和 bi 个女孩,n 个家庭总的男孩等于总的女孩。对于来自 i 家庭的男孩他只能和不来自 i 家庭的女孩结婚,也就是来自同个家庭的男孩女孩不能结婚。问有多少种方案,使得这 这些男孩女孩都能成功结婚。
思路
参考博客:戳
对于一个有 x 个男孩和 y 个女孩的家庭来说,有且仅有 k 对来自这个家庭的男孩女孩结婚(近亲结婚)的方案数是:
C(x, k) * C(y,k) * k!
那么如果在第一个家庭选 k1 对近亲结婚,第二个家庭 k2 对......第 n 个家庭 kn 对,剩下的自由组合,最后这种方案至少有 k1+k2...+kn 对近亲结婚。
那我们对每个家庭构造一个多项式:
c0 + c1*x + c2*x^2 + .... + cm*x^m (m = min(x, y))
把这 n 个多项式乘起来,得到的多项式的 x^k 的系数 ck 代表的就是至少有 k 对近亲结婚的方案数。
因为代表的是至少,所以最后还需要容斥一下。
n个多项式相乘,复杂度跟多项式的长度有很大关系,n个多项式的长度就是所有男孩的总数;所以复杂度其实是 o(nlognlogn)的
#include <bits/stdc++.h> #define LL long long #define ULL unsigned long long #define UI unsigned int #define mem(i, j) memset(i, j, sizeof(i)) #define rep(i, j, k) for(int i = j; i <= k; i++) #define dep(i, j, k) for(int i = k; i >= j; i--) #define pb push_back #define make make_pair #define INF 0x3f3f3f3f #define inf LLONG_MAX #define PI acos(-1) #define fir first #define sec second #define lb(x) ((x) & (-(x))) #define dbg(x) cout<<#x<<" = "<<x<<endl; using namespace std; const int N = 1e6 + 5; const LL mod = 998244353; const LL g = 3; int n, all, cnt; LL fac[N]; LL x1[N], x2[N]; vector < LL > a[N]; LL ksm(LL a, LL b) { LL res = 1LL; while(b) { if(b & 1) res = res * a % mod; a = a * a % mod; b >>= 1; } return res; } LL C(int n, int m) { return m > n ? 0 : fac[n] * ksm(fac[m] * fac[n - m] % mod, mod - 2) % mod; } struct cmp{ bool operator()(int A, int B) { return a[A].size() > a[B].size(); } }; priority_queue <LL, vector<LL>, cmp> Q; void change(LL y[], int len){ for (int i = 1, j = len / 2; i < len - 1; i++){ if (i < j) swap(y[i], y[j]); int k = len / 2; while (j >= k){ j -= k; k /= 2; } if (j < k) j += k; } } void ntt(LL y[], int len, int on){ change(y, len); for (int h = 2; h <= len; h <<= 1){ LL wn = ksm(g, (mod - 1) / h); if (on == -1) wn = ksm(wn, mod - 2); for (int j = 0; j < len; j += h){ LL w = 1ll; for (int k = j; k < j + h / 2; k++){ LL u = y[k]; LL t = w * y[k + h / 2] % mod; y[k] = (u + t) % mod; y[k + h / 2] = (u - t + mod) % mod; w = w * wn % mod; } } } if (on == -1){ LL t = ksm(len, mod - 2); rep(i, 0, len - 1) y[i] = y[i] * t % mod; } } void mul(vector <LL> &a, vector <LL> &b, vector <LL> &c){ int len = 1; int sz1 = a.size(), sz2 = b.size(); while (len <= sz1 + sz2 - 1) len <<= 1; rep(i, 0, sz1 - 1) x1[i] = a[i]; rep(i, sz1, len) x1[i] = 0; rep(i, 0, sz2 - 1) x2[i] = b[i]; rep(i, sz2, len) x2[i] = 0; ntt(x1, len, 1); ntt(x2, len, 1); rep(i, 0, len - 1) x1[i] = x1[i] * x2[i]; ntt(x1, len, -1); vector <LL>().swap(c); rep(i, 0, sz1 + sz2 - 2) c.push_back(x1[i]); } void solve() { scanf("%d", &n); rep(i, 0, n) vector < LL >().swap(a[i]); while(!Q.empty()) Q.pop(); all = 0; rep(i, 1, n) { int x, y; scanf("%d %d", &x, &y); a[i].resize(min(x,y)+1); rep(j, 0, min(x,y)) a[i][j] = C(x, j) * C(y, j) % mod * fac[j] % mod; Q.push(i); all += x; } cnt = n; rep(i, 1, n - 1) { int pos1 = Q.top(); Q.pop(); int pos2 = Q.top(); Q.pop(); mul(a[pos1], a[pos2], a[++cnt]); vector < LL >().swap(a[pos1]); vector < LL >().swap(a[pos2]); Q.push(cnt); } LL ans = 0LL, flag = 1LL; rep(i, 0, (int)(a[cnt].size()) - 1) { ans = ans + flag * fac[all - i] * a[cnt][i] % mod; ans = (ans + mod) % mod; flag *= -1; } printf("%lld\n", ans); } int main() { fac[0] = 1LL; rep(i, 1, N - 5) fac[i] = 1LL * i * fac[i - 1] % mod; int _; scanf("%d", &_); while(_--) solve(); // solve(); return 0; }
一步一步,永不停息