LOJ6503. 「雅礼集训 2018 Day4」Magic(容斥原理+NTT)
题目链接
题解
题中要求本质不同的序列数量,不太好搞。我们考虑给相同颜色的牌加上编号,这样所有牌都不相同。那么如果我们求出了答案,只需要将答案除以 \(\prod a_i!\) 就好了。
“恰好有 \(k\) 对”不能直接求,考虑容斥,如果我们求出了 \(g(x)\) 表示至少有 \(x\) 对的方案数,那么答案即为 \(\sum_\limits{i = k}^{n} (-1)^{i - k}\binom{i}{k}g(i)\)。现在的问题是如何求 \(g(x)\)。
我们将这 \(k\) 个魔术对分配到 \(m\) 种颜色中去,即:第 \(i\) 种颜色含有 \(b_i\) 个魔术对(满足 \(b_i < a_i\))。接下来可以做一个 dp,依次考虑每一种颜色,设 \(f_{i, j}\) 表示考虑了 \(i\) 种颜色,且已有 \(j\) 个魔术对的方案数。对于 \(i\),我们需要计算第 \(i\) 种颜色有 \(b_i\) 个魔术对的方案数是多少。由于我们已经视所有牌均不同,因此我们可以从中任选 \(b_i\) 张牌来作为所有魔术对的第二张牌,方案数为 \(\binom{a_i}{b_i}\)。接下来依次将这 \(b_i\) 张牌插入到原序列中去。由于必须插到相同颜色的牌的后面,因此第一张牌有 \(a_i - b_i\) 个位置可以插入,第二张牌有 \(a_i - b_i + 1\) 个位置可以插入......因此总方案数为 \(\binom{a_i}{b_i} \times (a_i - b_i)(a_i - b_i + 1)\cdots(a_i - 1) = \binom{a_i}{b_i} \times \frac{(a_i - 1)!}{(a_i - b_i - 1)!}\)。这样,我们只需枚举 \(b_i\) 然后转移即可。
不过注意这样求得的 \(f_{m, x}\) 并不是最终的 \(g(x)\)。由于构成魔术对的 \(x\) 张牌已经确定,不构成魔术对的 \(n - x\) 张牌可以任意排列,因此 \(g(x) = f_{m, x} \times (n - x)!\)。
这样做 dp 的复杂度是 \(O(nm)\) 的,可以过 64 分。
注意到上面的 dp 其实是做了 \(m\) 次卷积,因此我们可以使用 NTT 进行优化。不过直接做 \(m\) 次多项式乘法时间复杂度并不能得到优化,注意到 \(f_{i, j}\) 中 \(j\) 一维的上界是不断合并增大的,因此我们可以使用一个堆,每次取出长度最小的两个数组进行 NTT 合并即可,这样复杂度就有保证了,时间复杂度为 \(O(m \log^2n)\)。
代码
#include<bits/stdc++.h>
using namespace std;
#define X first
#define Y second
#define mp make_pair
#define pb push_back
#define debug(...) fprintf(stderr, __VA_ARGS__)
typedef long long ll;
typedef long double ld;
typedef unsigned int uint;
typedef pair<int, int> pii;
typedef unsigned long long ull;
template<typename T> inline void read(T& x) {
char c = getchar();
bool f = false;
for (x = 0; !isdigit(c); c = getchar()) {
if (c == '-') {
f = true;
}
}
for (; isdigit(c); c = getchar()) {
x = x * 10 + c - '0';
}
if (f) {
x = -x;
}
}
template<typename T, typename... U> inline void read(T& x, U&... y) {
read(x), read(y...);
}
template<typename T> inline bool checkMax(T& a, const T& b) {
return a < b ? a = b, true : false;
}
template<typename T> inline bool checkMin(T& a, const T& b) {
return a > b ? a = b, true : false;
}
const int N = 2e5 + 10, mod = 998244353, G = 3;
inline void add(int& x, int y) {
x = (x + y) % mod;
}
inline void mul(int& x, int y) {
x = 1ll * x * y % mod;
}
inline int qpow(int v, int p) {
int res = 1;
for (; p; p >>= 1, mul(v, v)) {
if (p & 1) {
mul(res, v);
}
}
return res;
}
int fac[N], invfac[N];
inline int binom(int n, int m) {
return 1ll * fac[n] * invfac[m] % mod * invfac[n - m] % mod;
}
void init(int n) {
fac[0] = invfac[0] = 1;
for (register int i = 1; i <= n; ++i) {
mul(fac[i] = fac[i - 1], i);
}
invfac[n] = qpow(fac[n], mod - 2);
for (register int i = n - 1; i; --i) {
mul(invfac[i] = invfac[i + 1], i + 1);
}
}
int a[N], b[N], l, r[N], S;
inline void ntt_init(int v) {
for (l = 0, S = 1; S <= v; ++l, S <<= 1); --l;
for (register int i = 0; i < S; ++i) {
r[i] = (r[i >> 1] >> 1) | ((i & 1) << l);
}
memset(a, 0, (sizeof (int)) * S);
memset(b, 0, (sizeof (int)) * S);
}
inline void ntt(int* c, int type) {
for (register int i = 0; i < S; ++i) {
if (i < r[i]) {
swap(c[i], c[r[i]]);
}
}
for (register int i = 1; i < S; i <<= 1) {
int x = qpow(G, type == 1 ? (mod - 1) / (i << 1) : mod - 1 - (mod - 1) / (i << 1));
for (register int j = 0; j < S; j += i << 1) {
int y = 1;
for (register int k = 0; k < i; ++k, mul(y, x)) {
int p = c[j + k], q = 1ll * y * c[i + j + k] % mod;
c[j + k] = (p + q) % mod;
c[i + j + k] = (p - q + mod) % mod;
}
}
}
if (type == -1) {
int inv = qpow(S, mod - 2);
for (register int i = 0; i < S; ++i) {
mul(c[i], inv);
}
}
}
int m, n, k, v[N], g[N];
struct State {
int id, v;
State () {}
State (int id, int v): id(id), v(v) {}
bool operator < (const State& a) const {
return v > a.v;
}
};
priority_queue<State> Q;
vector<int> s[N];
int main() {
read(m, n, k);
init(n);
for (register int i = 1; i <= m; ++i) {
read(v[i]), Q.push(State(i, v[i] - 1));
for (register int j = 0; j < v[i]; ++j) {
s[i].pb((int) (1ll * binom(v[i], j) * fac[v[i] - 1] % mod * invfac[v[i] - j - 1] % mod));
}
}
for (register int i = 1; i < m; ++i) {
State x = Q.top(); Q.pop();
State y = Q.top(); Q.pop();
int k = x.v + y.v;
ntt_init(k);
for (register int j = 0; j < s[x.id].size(); ++j) {
a[j] = s[x.id][j];
}
for (register int j = 0; j < s[y.id].size(); ++j) {
b[j] = s[y.id][j];
}
s[x.id].clear(), s[y.id].clear();
ntt(a, 1), ntt(b, 1);
for (register int j = 0; j < S; ++j) {
mul(a[j], b[j]);
}
ntt(a, -1);
for (register int j = 0; j <= k; ++j) {
s[x.id].pb(a[j]);
}
Q.push(State(x.id, k));
}
int inv = 1;
for (register int i = 1; i <= m; ++i) {
mul(inv, fac[v[i]]);
}
inv = qpow(inv, mod - 2);
int p = Q.top().id;
for (register int i = 0; i < s[p].size(); ++i) {
g[i] = 1ll * s[p][i] * fac[n - i] % mod * inv % mod;
}
int ans = 0;
for (register int i = k; i <= n; ++i) {
if (i - k & 1) {
add(ans, mod - 1ll * binom(i, k) * g[i] % mod);
} else {
add(ans, 1ll * binom(i, k) * g[i] % mod);
}
}
printf("%d\n", ans);
return 0;
}