LOJ6503. 「雅礼集训 2018 Day4」Magic(容斥原理+NTT)

题目链接

https://loj.ac/problem/6503

题解

题中要求本质不同的序列数量,不太好搞。我们考虑给相同颜色的牌加上编号,这样所有牌都不相同。那么如果我们求出了答案,只需要将答案除以 \(\prod a_i!\) 就好了。

“恰好有 \(k\) 对”不能直接求,考虑容斥,如果我们求出了 \(g(x)\) 表示至少有 \(x\) 对的方案数,那么答案即为 \(\sum_\limits{i = k}^{n} (-1)^{i - k}\binom{i}{k}g(i)\)。现在的问题是如何求 \(g(x)\)

我们将这 \(k\) 个魔术对分配到 \(m\) 种颜色中去,即:第 \(i\) 种颜色含有 \(b_i\) 个魔术对(满足 \(b_i < a_i\))。接下来可以做一个 dp,依次考虑每一种颜色,设 \(f_{i, j}\) 表示考虑了 \(i\) 种颜色,且已有 \(j\) 个魔术对的方案数。对于 \(i\),我们需要计算第 \(i\) 种颜色有 \(b_i\) 个魔术对的方案数是多少。由于我们已经视所有牌均不同,因此我们可以从中任选 \(b_i\) 张牌来作为所有魔术对的第二张牌,方案数为 \(\binom{a_i}{b_i}\)。接下来依次将这 \(b_i\) 张牌插入到原序列中去。由于必须插到相同颜色的牌的后面,因此第一张牌有 \(a_i - b_i\) 个位置可以插入,第二张牌有 \(a_i - b_i + 1\) 个位置可以插入......因此总方案数为 \(\binom{a_i}{b_i} \times (a_i - b_i)(a_i - b_i + 1)\cdots(a_i - 1) = \binom{a_i}{b_i} \times \frac{(a_i - 1)!}{(a_i - b_i - 1)!}\)。这样,我们只需枚举 \(b_i\) 然后转移即可。

不过注意这样求得的 \(f_{m, x}\) 并不是最终的 \(g(x)\)。由于构成魔术对的 \(x\) 张牌已经确定,不构成魔术对的 \(n - x\) 张牌可以任意排列,因此 \(g(x) = f_{m, x} \times (n - x)!\)

这样做 dp 的复杂度是 \(O(nm)\) 的,可以过 64 分。

注意到上面的 dp 其实是做了 \(m\) 次卷积,因此我们可以使用 NTT 进行优化。不过直接做 \(m\) 次多项式乘法时间复杂度并不能得到优化,注意到 \(f_{i, j}\)\(j\) 一维的上界是不断合并增大的,因此我们可以使用一个堆,每次取出长度最小的两个数组进行 NTT 合并即可,这样复杂度就有保证了,时间复杂度为 \(O(m \log^2n)\)

代码

#include<bits/stdc++.h>

using namespace std;

#define X first
#define Y second
#define mp make_pair
#define pb push_back
#define debug(...) fprintf(stderr, __VA_ARGS__)

typedef long long ll;
typedef long double ld;
typedef unsigned int uint;
typedef pair<int, int> pii;
typedef unsigned long long ull;

template<typename T> inline void read(T& x) {
  char c = getchar();
  bool f = false;
  for (x = 0; !isdigit(c); c = getchar()) {
    if (c == '-') {
      f = true;
    }
  }
  for (; isdigit(c); c = getchar()) {
    x = x * 10 + c - '0';
  }
  if (f) {
    x = -x;
  }
}

template<typename T, typename... U> inline void read(T& x, U&... y) {
  read(x), read(y...);
}

template<typename T> inline bool checkMax(T& a, const T& b) {
  return a < b ? a = b, true : false;
}

template<typename T> inline bool checkMin(T& a, const T& b) {
  return a > b ? a = b, true : false;
}

const int N = 2e5 + 10, mod = 998244353, G = 3;

inline void add(int& x, int y) {
  x = (x + y) % mod;
}

inline void mul(int& x, int y) {
  x = 1ll * x * y % mod;
}

inline int qpow(int v, int p) {
  int res = 1;
  for (; p; p >>= 1, mul(v, v)) {
    if (p & 1) {
      mul(res, v);
    }
  }
  return res;
}

int fac[N], invfac[N];

inline int binom(int n, int m) {
  return 1ll * fac[n] * invfac[m] % mod * invfac[n - m] % mod;
}

void init(int n) {
  fac[0] = invfac[0] = 1;
  for (register int i = 1; i <= n; ++i) {
    mul(fac[i] = fac[i - 1], i);
  }
  invfac[n] = qpow(fac[n], mod - 2);
  for (register int i = n - 1; i; --i) {
    mul(invfac[i] = invfac[i + 1], i + 1);
  }
}

int a[N], b[N], l, r[N], S;

inline void ntt_init(int v) {
  for (l = 0, S = 1; S <= v; ++l, S <<= 1); --l;
  for (register int i = 0; i < S; ++i) {
    r[i] = (r[i >> 1] >> 1) | ((i & 1) << l);
  }
  memset(a, 0, (sizeof (int)) * S);
  memset(b, 0, (sizeof (int)) * S);
}

inline void ntt(int* c, int type) {
  for (register int i = 0; i < S; ++i) {
    if (i < r[i]) {
      swap(c[i], c[r[i]]);
    }
  }
  for (register int i = 1; i < S; i <<= 1) {
    int x = qpow(G, type == 1 ? (mod - 1) / (i << 1) : mod - 1 - (mod - 1) / (i << 1));
    for (register int j = 0; j < S; j += i << 1) {
      int y = 1;
      for (register int k = 0; k < i; ++k, mul(y, x)) {
        int p = c[j + k], q = 1ll * y * c[i + j + k] % mod;
        c[j + k] = (p + q) % mod;
        c[i + j + k] = (p - q + mod) % mod;
      }
    }
  }
  if (type == -1) {
    int inv = qpow(S, mod - 2);
    for (register int i = 0; i < S; ++i) {
      mul(c[i], inv);
    }
  }
}

int m, n, k, v[N], g[N];

struct State {
  int id, v;
  State () {}
  State (int id, int v): id(id), v(v) {}
  bool operator < (const State& a) const {
    return v > a.v;
  }
};

priority_queue<State> Q;

vector<int> s[N];

int main() {
  read(m, n, k);
  init(n);
  for (register int i = 1; i <= m; ++i) {
    read(v[i]), Q.push(State(i, v[i] - 1));
    for (register int j = 0; j < v[i]; ++j) {
      s[i].pb((int) (1ll * binom(v[i], j) * fac[v[i] - 1] % mod * invfac[v[i] - j - 1] % mod));
    }
  }
  for (register int i = 1; i < m; ++i) {
    State x = Q.top(); Q.pop();
    State y = Q.top(); Q.pop();
    int k = x.v + y.v;
    ntt_init(k);
    for (register int j = 0; j < s[x.id].size(); ++j) {
      a[j] = s[x.id][j];
    }
    for (register int j = 0; j < s[y.id].size(); ++j) {
      b[j] = s[y.id][j];
    }
    s[x.id].clear(), s[y.id].clear();
    ntt(a, 1), ntt(b, 1);
    for (register int j = 0; j < S; ++j) {
      mul(a[j], b[j]);
    }
    ntt(a, -1);
    for (register int j = 0; j <= k; ++j) {
      s[x.id].pb(a[j]);
    }
    Q.push(State(x.id, k));
  }
  int inv = 1;
  for (register int i = 1; i <= m; ++i) {
    mul(inv, fac[v[i]]);
  }
  inv = qpow(inv, mod - 2);
  int p = Q.top().id;
  for (register int i = 0; i < s[p].size(); ++i) {
    g[i] = 1ll * s[p][i] * fac[n - i] % mod * inv % mod;
  }
  int ans = 0;
  for (register int i = k; i <= n; ++i) {
    if (i - k & 1) {
      add(ans, mod - 1ll * binom(i, k) * g[i] % mod);
    } else {
      add(ans, 1ll * binom(i, k) * g[i] % mod);
    }
  }
  printf("%d\n", ans);
  return 0;
}
posted @ 2018-10-15 21:28  ImagineC  阅读(883)  评论(0编辑  收藏  举报