AtCoder Beginner Contest 245 Ex - Product Modulo 2

这个题解是基于官方题解,但是官方题解我看了好久才看懂,所以在官方题解的基础上增加一些解释。

建议也看看B站里dls的讲解视频。

拆分

借助CRT可以将题目拆分成多个子问题,最后再合并,就可以得到原问题的答案。

CRT的式子\(x \equiv a_i \mod m_i\),且\(m_i\)两两互素。

\(x\)看成\(N\)\(m_i = {p_{i}}^{e_i}\),这样的话CRT的条件还是满足的。这里\(M = \prod_{i} {p_i}^{e_i}\)且对于每一个\(i\)都是一个子问题。

然后可以通过一个类似exLucas的过程,将子问题的答案拼成原问题的答案。

由于各个子问题之间是相互独立的,所以将所有子问题的答案乘起来就是原问题的答案。

解决子问题

记子问题为\(f(p, e, K, N)\)

现在的问题就是要解决子问题,可以用生成函数结合快速幂来做。

首先,就是说\(N_1\)\(N_2\)的质因子分解中,\(p_i\)的指数相同,那么\(f(p, e, K, N_1) = f(p, e, K, N_2)\),这个可以用归纳法证明。所以可以将指数相同的数合并成一类,这样方便计算,最后要计算答案的时候再除以类的大小,就能得到原本的答案。

特别的,可以认为\(N = 0\)包含无穷大个\(p\)

注意到\(p^c \mod p^e\)\(c > e\)的时候都为\(0\),所以可以将\(c > e\)的归为一类。

容易证明至多有\(O(\log N)\)个类的问题。

然后,包含\(K\)个数的答案可以由包含\(K - 1\)个数的答案推导得到,就是说如果\(k = i + j\),那么包含\(k\)\(p\)的答案可以由包含\(i\)\(p\)的和包含\(j\)\(p\)的答案合并得到。

但是一个一个推太慢了,所以使用快速幂来加速,因为是线性组合,所以成立。

其实就是说构造一个生成函数\(f(x) = \sum_{i = 0}^{e} a_i x^i\),表示\(a_i\)表示包含\(i\)\(p\)的方案数,特别的\(a_e\)表示包含大于等于\(e\)\(p\)的方案数,则\(f^{K}(x)\)就是\(K\)个数相乘之后结果的方案数。

只有一个数的时候,可以方便的计算出\(a_i\)的值,就是\(a_{e} = 0, a_{e - 1} = p - 1, a_{i} = a_{i + 1} \times p\)

注意

官方题解中也有提到,就是说可能出现除\(0\)的情况,但是因为\(M \le 10^{12}\),所以这个时候指数\(e = 1\),这个可以特判一下。

\(N = 0\),那么就是\(K\)个数中任意一个为零即可,用所有方案减去全不为零的方案就是答案,即\(p^{K} - (p-1)^{K}\)

\(N \ne 0\),这个时候,前\(K - 1\)个数可以是任意非零元素,而最后一个元素是唯一确定的,即\(a_K = N \times (\prod_{i = 1}^{K - 1} a_i)^{-1}\)。所以方案数是\((p-1)^{K-1}\)

AC代码

// Problem: Ex - Product Modulo 2
// Contest: AtCoder - AtCoder Beginner Contest 245
// URL: https://atcoder.jp/contests/abc245/tasks/abc245_h
// Memory Limit: 1024 MB
// Time Limit: 2000 ms
//
// Powered by CP Editor (https://cpeditor.org)

#include <bits/stdc++.h>

#define CPPIO \
  std::ios::sync_with_stdio(false), std::cin.tie(0), std::cout.tie(0);
#ifdef BACKLIGHT
#include "debug.h"
#else
#define logd(...) ;
#endif

using i64 = int64_t;
using u64 = uint64_t;

void solve_case(int Case);

int main() {
  CPPIO;
  int T = 1;
  // std::cin >> T;
  for (int t = 1; t <= T; ++t) {
    solve_case(t);
  }
  return 0;
}

template <typename ValueType, ValueType mod_, typename SupperType = int64_t>
class Modular {
 private:
  ValueType value_;

  ValueType normalize(SupperType value) const {
    if (value >= 0 && value < mod_)
      return value;
    value %= mod_;
    if (value < 0)
      value += mod_;
    return value;
  }

  ValueType power(ValueType value, size_t exponent) const {
    ValueType result = 1;
    ValueType base = value;
    while (exponent) {
      if (exponent & 1)
        result = SupperType(result) * base % mod_;
      base = SupperType(base) * base % mod_;
      exponent >>= 1;
    }
    return result;
  }

 public:
  Modular() : value_(0) {}

  Modular(const SupperType& value) : value_(normalize(value)) {}

  ValueType value() const { return value_; }

  Modular inv() const { return Modular(power(value_, mod_ - 2)); }

  Modular power(size_t exponent) { return Modular(power(value_, exponent)); }

  friend Modular operator+(const Modular& lhs, const Modular& rhs) {
    ValueType result = lhs.value() + rhs.value() >= mod_
                           ? lhs.value() + rhs.value() - mod_
                           : lhs.value() + rhs.value();
    return Modular(result);
  }

  friend Modular operator-(const Modular& lhs, const Modular& rhs) {
    ValueType result = lhs.value() - rhs.value() < 0
                           ? lhs.value() - rhs.value() + mod_
                           : lhs.value() - rhs.value();
    return Modular(result);
  }

  friend Modular operator*(const Modular& lhs, const Modular& rhs) {
    ValueType result = SupperType(1) * lhs.value() * rhs.value() % mod_;
    return Modular(result);
  }

  friend Modular operator/(const Modular& lhs, const Modular& rhs) {
    ValueType result = SupperType(1) * lhs.value() * rhs.inv().value() % mod_;
    return Modular(result);
  }
};
template <typename StreamType,
          typename ValueType,
          ValueType mod,
          typename SupperType = int64_t>
StreamType& operator<<(StreamType& out,
                       const Modular<ValueType, mod, SupperType>& modular) {
  return out << modular.value();
}
template <typename StreamType,
          typename ValueType,
          ValueType mod,
          typename SupperType = int64_t>
StreamType& operator>>(StreamType& in,
                       Modular<ValueType, mod, SupperType>& modular) {
  ValueType value;
  in >> value;
  modular = Modular<ValueType, mod, SupperType>(value);
  return in;
}
// using mint = Modular<int, 1'000'000'007>;
using mint = Modular<int, 998'244'353>;
std::string to_string(mint v) {
  return to_string(v.value());
}

std::vector<std::pair<i64, i64>> factor(i64 n) {
  std::vector<std::pair<i64, i64>> pe;

  for (i64 i = 2; i * i <= n; ++i) {
    if (n % i == 0) {
      i64 e = 0;
      while (n % i == 0) {
        ++e;
        n = n / i;
      }
      pe.emplace_back(i, e);
    }
  }
  if (n > 1)
    pe.emplace_back(n, 1);
  return pe;
}

void solve_case(int Case) {
  i64 k, n, m;
  std::cin >> k >> n >> m;

  auto f = [](i64 p, i64 e, i64 k, i64 n) -> mint {
    auto cp = [](i64 n, i64 p) -> i64 {
      if (n == 0)
        return 1e9;
      i64 e = 0;
      while (n % p == 0) {
        ++e;
        n = n / p;
      }
      return e;
    };

    auto mul = [&](const std::vector<mint>& a,
                   const std::vector<mint>& b) -> std::vector<mint> {
      std::vector<mint> c(e + 1);
      for (i64 i = 0; i <= e; ++i) {
        for (i64 j = 0; j <= e; ++j) {
          c[std::min(e, i + j)] = c[std::min(e, i + j)] + a[i] * b[j];
        }
      }
      return c;
    };

    if (e == 1) {
      i64 pe = 1;
      for (int i = 1; i <= e; ++i)
        pe = pe * p;
      n = n % pe;
      if (n == 0)
        return mint(p).power(k) - mint(p - 1).power(k);
      else
        return mint(p - 1).power(k - 1);
    }

    i64 c = cp(n, p);
    if (c > e)
      c = e;

    std::vector<mint> r(e + 1), x(e + 1);
    r[0] = 1;
    x[e] = 1;
    x[e - 1] = p - 1;
    for (i64 i = e - 2; i >= 0; --i)
      x[i] = x[i + 1] * p;
    auto d = x;
    while (k) {
      if (k & 1)
        r = mul(r, x);
      x = mul(x, x);
      k >>= 1;
    }
    logd(r);

    return r[c] / d[c];
  };

  mint ans(1);
  auto pe = factor(m);
  logd(pe);
  for (auto [p, e] : pe) {
    ans = ans * f(p, e, k, n);
  }

  std::cout << ans.value() << "\n";
}
posted @ 2022-03-31 19:03  _Backl1ght  阅读(135)  评论(0编辑  收藏  举报