ABC267Ex - Odd Sum

分治NTT

Ex - Odd Sum (atcoder.jp)

题意

给一个长度为 \(n\;(1\le n\le 10^5)\) 的数组 \(A\;(1\le A[i]\le 10)\),给定 \(M\;(1\le M\le 10^6)\),求在 \(A\) 中选出奇数个数、使它们的和为 \(M\) 的方案数(对 \(998244353\) 取模)

思路

  1. 先不考虑要选奇数个数,根据生成函数,\(F=\prod (1+x^{A[i]})\),其中因子 \(1+x^{A[i]}\) 的两项分别表示不选 / 选第 \(i\) 个数

    \(ans=[x^M]F\)

  2. 若要求奇数个,就是 \(\prod (1+x^{A[i]})\) 中只能有奇数个选择 \(x^{A[i]}\) 这一项

  3. \(G=\prod(1-x^{A[i]})\),

    \([x^M]F\) 等于「选了奇数个项的方案数」与「选了偶数个项的方案数」之和,\([x^M]G\) 等于「选了偶数个项的方案数」减去「选了奇数个项的方案数」(每选一项 \(-x^{A[i]}\) 贡献一个 \(-1\))

  4. \(ans=[x^M]\frac {F-G}2\)

  5. 用分治NTT计算 F, G 的复杂度为 \(O(n\log^2 n)\), 不一定能过,可以考虑优化

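  6. 观察题目条件,\(1\le A[i]\le 10\), 说明 F,G 中不同的因子 \(1\pm x^{A[i]}\) 各只有 10 种,可记录 \(cnt[i]\) 表示 \(A\) 中值为 \(i\) 的元素个数,即因子 \(1\pm x^i\) 的次数,再用二项式定理展开 \((1\pm x^i)^{cnt[i]}=\sum_{j=0}^{cnt[i]}(\pm 1)^j\binom{cnt[i]}{j}x^{ij}\),最后把这 10 个多项式乘起来,即求 F,G 的复杂度为 \(O(n\log n\cdot\log 10)\)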

代码

#include <bits/stdc++.h>
using namespace std;
// Modulus used by every modular-arithmetic helper and by the NTT in this file.
const int md = 998244353;

// In-place modular addition: x becomes (x + y) reduced by at most one
// subtraction of md. Both operands are expected to already lie in [0, md).
inline void add(int &x, int y) {
  const int sum = x + y;
  x = sum < md ? sum : sum - md;
}

// In-place modular subtraction: x becomes (x - y), lifted by md once when the
// difference is negative. Both operands are expected to already lie in [0, md).
inline void sub(int &x, int y) {
  const int diff = x - y;
  x = diff >= 0 ? diff : diff + md;
}

// Returns x * y modulo md; the product is formed in 64 bits so it cannot
// overflow before the reduction.
inline int mul(int x, int y) {
  const long long product = static_cast<long long>(x) * y;
  return static_cast<int>(product % md);
}

// Returns x raised to the power y modulo md using binary exponentiation.
// The loop runs while y is non-zero, shifting y right by one bit per step.
inline int power(int x, int y) {
  int acc = 1;
  int cur = x;  // x^(2^k) for the bit currently being examined
  while (y) {
    if (y & 1) {
      acc = mul(acc, cur);
    }
    cur = mul(cur, cur);
    y >>= 1;
  }
  return acc;
}

// Returns the multiplicative inverse of a modulo md, computed with the
// extended Euclidean algorithm. The argument is first normalised into
// [0, md); when it normalises to 0 the loop is skipped and 0 is returned.
inline int inv(int a) {
  int cur = (a % md + md) % md;  // normalised argument, shrinks to 0
  int prev = md;                 // previous remainder, ends as the gcd
  int coef_prev = 0;             // Bezout coefficient paired with prev
  int coef_cur = 1;              // Bezout coefficient paired with cur
  while (cur != 0) {
    const int q = prev / cur;
    const int next = prev - q * cur;
    const int coef_next = coef_prev - q * coef_cur;
    prev = cur;
    cur = next;
    coef_prev = coef_cur;
    coef_cur = coef_next;
  }
  return coef_prev < 0 ? coef_prev + md : coef_prev;
}

// Number-theoretic transform over the integers modulo md, plus polynomial
// multiplication and power-series inversion built on top of it.
namespace ntt {
// base:     the tables `rev` and `roots` currently cover transforms of up to
//           2^base points.
// root:     element of multiplicative order exactly 2^max_base (set by init()).
// max_base: exponent of 2 in md - 1; -1 means init() has not run yet.
int base = 1, root = -1, max_base = -1;
// rev:   bit-reversal permutation for `base` bits.
// roots: twiddle factors; roots[i + k] is the factor used in dft() for a
//        butterfly of half-width i at offset k.
vector<int> rev = {0, 1}, roots = {0, 1};

// Computes max_base and searches upward from 2 for the first value whose
// order modulo md is exactly 2^max_base.
void init() {
  int temp = md - 1;
  max_base = 0;
  // Count how many times 2 divides md - 1.
  while (temp % 2 == 0) {
    temp >>= 1;
    ++max_base;
  }
  root = 2;
  while (true) {
    // root^(2^max_base) == 1 but root^(2^(max_base-1)) != 1 means the order
    // of root is exactly 2^max_base.
    if (power(root, 1 << max_base) == 1 && power(root, 1 << (max_base - 1)) != 1) {
      break;
    }
    ++root;
  }
}

// Grows the `rev` and `roots` tables so that transforms of 2^nbase points are
// supported. Does nothing when the tables are already large enough; asserts
// that nbase does not exceed max_base.
void ensure_base(int nbase) {
  if (max_base == -1) {
    init();
  }
  if (nbase <= base) {
    return;
  }
  assert(nbase <= max_base);
  rev.resize(1 << nbase);
  // Rebuild the whole bit-reversal table for nbase bits: rev[i] is rev[i / 2]
  // shifted down one bit, with the lowest bit of i moved to the top position.
  for (int i = 0; i < 1 << nbase; ++i) {
    rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (nbase - 1));
  }
  roots.resize(1 << nbase);
  // Extend the twiddle table one level at a time; each new level is derived
  // from the previous one (even slots copy, odd slots are multiplied by z).
  while (base < nbase) {
    int z = power(root, 1 << (max_base - 1 - base));
    for (int i = 1 << (base - 1); i < 1 << base; ++i) {
      roots[i << 1] = roots[i];
      roots[i << 1 | 1] = mul(roots[i], z);
    }
    ++base;
  }
}

// In-place forward transform of `a`. The number of trailing zero bits of
// a.size() is used as log2 of the length, so the size is assumed to be a
// power of two.
void dft(vector<int> &a) {
  int n = a.size(), zeros = __builtin_ctz(n);
  ensure_base(zeros);
  // rev was built for `base` bits; shifting right turns it into a
  // `zeros`-bit reversal.
  int shift = base - zeros;
  for (int i = 0; i < n; ++i) {
    if (i < rev[i] >> shift) {
      swap(a[i], a[rev[i] >> shift]);
    }
  }
  // Butterfly passes with half-width i = 1, 2, 4, ...
  for (int i = 1; i < n; i <<= 1) {
    for (int j = 0; j < n; j += i << 1) {
      for (int k = 0; k < i; ++k) {
        int x = a[j + k], y = mul(a[j + k + i], roots[i + k]);
        a[j + k] = (x + y) % md;
        a[j + k + i] = (x + md - y) % md;
      }
    }
  }
}

// Returns the product of polynomials a and b (coefficient vectors, lowest
// degree first) with a.size() + b.size() - 1 coefficients.
// NOTE(review): with both inputs empty `need` becomes -1 and the final
// resize would request an enormous size - callers appear to pass non-empty
// vectors; confirm.
vector<int> multiply(vector<int> a, vector<int> b) {
  int need = a.size() + b.size() - 1, nbase = 0;
  // Smallest power of two that can hold the full product.
  while (1 << nbase < need) {
    ++nbase;
  }
  ensure_base(nbase);
  int sz = 1 << nbase;
  a.resize(sz);
  b.resize(sz);
  // When both padded inputs are identical, one transform is reused.
  bool equal = a == b;
  dft(a);
  if (equal) {
    b = a;
  } else {
    dft(b);
  }
  int inv_sz = inv(sz);
  // Pointwise product, pre-scaled by 1 / sz for the inverse transform.
  for (int i = 0; i < sz; ++i) {
    a[i] = mul(mul(a[i], b[i]), inv_sz);
  }
  // Inverse transform: reverse every element except the first, then apply
  // the forward transform again.
  reverse(a.begin() + 1, a.end());
  dft(a);
  a.resize(need);
  return a;
}

// Returns the first a.size() coefficients of the power-series inverse of a.
// Recurses on the first ceil(n / 2) coefficients, then refines the result b
// to b * (2 - a * b), evaluated pointwise in the transform domain.
// NOTE(review): an empty input never reaches the n == 1 base case; callers
// presumably pass a non-empty vector with an invertible a[0] - confirm.
vector<int> inverse(vector<int> a) {
  int n = a.size(), m = (n + 1) >> 1;
  if (n == 1) {
    return vector<int>(1, inv(a[0]));
  } else {
    vector<int> b = inverse(vector<int>(a.begin(), a.begin() + m));
    int need = n << 1, nbase = 0;
    while (1 << nbase < need) {
      ++nbase;
    }
    ensure_base(nbase);
    int sz = 1 << nbase;
    a.resize(sz);
    b.resize(sz);
    dft(a);
    dft(b);
    int inv_sz = inv(sz);
    for (int i = 0; i < sz; ++i) {
      // (2 - a*b) * b, scaled by 1 / sz for the inverse transform below.
      a[i] = mul(mul(md + 2 - mul(a[i], b[i]), b[i]), inv_sz);
    }
    // Same inverse-transform trick as in multiply().
    reverse(a.begin() + 1, a.end());
    dft(a);
    a.resize(n);
    return a;
  }
}
}

using ntt::multiply;
using ntt::inverse;

// Coefficient-wise modular addition of polynomial b into a. `a` is first
// extended with zeros when it is shorter than b. Returns a reference to a.
vector<int>& operator += (vector<int> &a, const vector<int> &b) {
  const size_t len = b.size();
  if (len > a.size()) {
    a.resize(len);
  }
  size_t pos = 0;
  while (pos != len) {
    add(a[pos], b[pos]);
    ++pos;
  }
  return a;
}

// Returns the coefficient-wise modular sum of polynomials a and b as a new
// vector; neither argument is modified.
vector<int> operator + (const vector<int> &a, const vector<int> &b) {
  vector<int> sum(a);
  sum += b;
  return sum;
}

// Coefficient-wise modular subtraction of polynomial b from a. `a` is first
// extended with zeros when it is shorter than b. Returns a reference to a.
vector<int>& operator -= (vector<int> &a, const vector<int> &b) {
  const size_t len = b.size();
  if (len > a.size()) {
    a.resize(len);
  }
  size_t pos = 0;
  while (pos != len) {
    sub(a[pos], b[pos]);
    ++pos;
  }
  return a;
}

// Returns the coefficient-wise modular difference a - b as a new vector;
// neither argument is modified.
vector<int> operator - (const vector<int> &a, const vector<int> &b) {
  vector<int> diff(a);
  diff -= b;
  return diff;
}

// Multiplies polynomial a by polynomial b in place (coefficients modulo md,
// lowest degree first) and returns a reference to a.
//
// When the shorter operand has fewer than 128 coefficients the schoolbook
// O(|a| * |b|) product is used; otherwise the work is delegated to
// ntt::multiply. If either operand is empty the result is the empty (zero)
// polynomial.
vector<int>& operator *= (vector<int> &a, const vector<int> &b) {
  // Zero polynomial: also avoids the unsigned underflow of
  // a.size() + b.size() - 1 when both operands are empty.
  if (a.empty() || b.empty()) {
    a.clear();
    return a;
  }
  if (min(a.size(), b.size()) < 128) {
    // Accumulate into a separate buffer and only then replace `a`, so that
    // the call stays correct when b refers to the same vector as a (a *= a).
    vector<int> c(a.size() + b.size() - 1, 0);
    for (size_t i = 0; i < a.size(); ++i) {
      for (size_t j = 0; j < b.size(); ++j) {
        add(c[i + j], mul(a[i], b[j]));
      }
    }
    a.swap(c);
  } else {
    // multiply takes its arguments by value, so aliasing is harmless here.
    a = multiply(a, b);
  }
  return a;
}

// Returns the product of polynomials a and b as a new vector; neither
// argument is modified.
vector<int> operator * (const vector<int> &a, const vector<int> &b) {
  vector<int> product(a);
  product *= b;
  return product;
}

// Divide and conquer product of the polynomials f[l], f[l+1], ..., f[r]
// (inclusive range). A single polynomial is returned as a copy; otherwise the
// range is halved and the two partial products are multiplied together.
vector<int> fz_ntt(vector<vector<int> > &f, int l, int r)
{
    if (l != r)
    {
        const int mid = (l + r) >> 1;
        vector<int> left_product = fz_ntt(f, l, mid);
        vector<int> right_product = fz_ntt(f, mid + 1, r);
        return left_product * right_product;
    }
    return f[l];
}

typedef long long ll;
// Capacity of the factorial tables below.
const int N = 1e5 + 10;
// cnt[v] = how many of the input values equal v (filled in main()).
int cnt[12];
// fac[i] = i! mod md and finv[i] = inverse of i! mod md, filled by presolve().
ll fac[N], finv[N];
// Returns a raised to the power b modulo md by repeated squaring.
// The intermediate products a * a and result * a are formed in 64 bits, so a
// is expected to be smaller than md in magnitude.
ll qmi(ll a, ll b)
{
    ll result = 1;
    for (; b; b >>= 1)
    {
        if (b & 1)
            result = result * a % md;
        a = a * a % md;
    }
    return result;
}

// Fills fac[0..n] with factorials modulo md and finv[0..n] with their modular
// inverses. finv[n] is obtained with one modular exponentiation (exponent
// md - 2) and the remaining inverses follow from finv[k-1] = finv[k] * k.
void presolve(int n)
{
    fac[0] = 1;
    finv[0] = 1;
    int i = 1;
    while (i <= n)
    {
        fac[i] = fac[i-1] * i % md;
        ++i;
    }
    finv[n] = qmi(fac[n], md - 2);
    for (int k = n; k > 1; --k)
        finv[k-1] = finv[k] * k % md;
}

// Binomial coefficient "n choose m" modulo md, read from the fac / finv
// tables. Returns 0 when m is negative or exceeds n.
ll C(int n, int m)
{
    const bool outside = (m < 0) || (n - m < 0);
    if (outside)
        return 0;
    ll value = fac[n] * finv[m] % md;
    value = value * finv[n-m] % md;
    return value;
}

int main()
{
	int n, m;
	scanf("%d%d", &n, &m);
    presolve(n);
	for (int i = 1; i <= n; i++)
    {
        int x;
        scanf("%d", &x);
        cnt[x]++;
    }
    vector<vector<int> > pa, pb;
    for (int i = 1; i <= 10; i++)
    {
        vector<int> a(cnt[i] * i + 1, 0);
        vector<int> b(cnt[i] * i + 1, 0);
        for (int j = 0; j <= cnt[i]; j++)
        {
            a[j * i] = C(cnt[i], j);
            b[j * i] = (j % 2 == 0 ? a[j * i] : (md - a[j * i]));
        }
        pa.push_back(a);
        pb.push_back(b);
    }
    auto ans1 = fz_ntt(pa, 0, 9);
    auto ans2 = fz_ntt(pb, 0, 9);
    if (ans1.size() - 1 < m)
    {
        puts("0");
        return 0;
    }
    int ans = (ll)(ans1[m] - ans2[m]) * (md + 1) / 2 % md;
    if (ans < 0)
        ans += md;
    printf("%d\n", ans);
	return 0;
}
posted @ 2022-10-05 10:57  hzy0227  阅读(105)  评论(0)  编辑  收藏  举报