Live2D

Solution -「NOI Simu.」逆天题

\(\mathscr{Description}\)

  对于 \(r=0,1,\cdots,n-1\), 设 \(\{1,2,\cdots,nm\}\) 中有 \(f_r\) 个子集满足子集内元素之和 \(\bmod n=r\), 求出

\[\sum_{r=0}^{n-1}f_r^2\bmod 998244353. \]

  \(n,m\le10^{18}\).

\(\mathscr{Solution}\)

  我们知道 \(f\) 的 GF:

\[F(x)=\left(\prod_{i=0}^{n-1}(1+x^i)\right)^m\bmod (x^n-1). \]

于是我们需要用 \(\omega_n\) 对 \(F\) 做 DFT. 不妨设 \(\hat{f_k}=\operatorname{DFT}(\boldsymbol f)_k\), 那么

\[\hat{f_k}=\left(\prod_{i=0}^{n-1}(1+\omega_n^{ik})\right)^m. \]

  看过 3B1B 的视频, 我们已知了对乘积式 \(\prod_{i=0}^{n-1}(1+\omega_n^i)\) 的处理技巧:

  注意到 \(z^n-1=0\Leftrightarrow z=\omega_n^k\ (k=0,1,\cdots,n-1)\), 那么 \(z^n-1=\prod_{i=0}^{n-1}(z-\omega_n^i)\). 代入 \(z=-1\), 左边为 \((-1)^n-1\), 右边为 \((-1)^n\prod_{i=0}^{n-1}(1+\omega_n^i)\), 于是 \(\prod_{i=0}^{n-1}(1+\omega_n^i)=1-(-1)^n=2[2\nmid n]\).

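指标乘上常数 \(k\), 无非让 \(\omega_n\) 转圈圈的速度发生了一些变化. 令 \(d=\gcd(k,n)\), 则当 \(i\) 取遍 \(0,1,\cdots,n-1\) 时, \(\omega_n^{ik}=\omega_{n/d}^{i(k/d)}\) 恰好把每个 \(n/d\) 次单位根各取到 \(d\) 次, 故 \(\prod_{i=0}^{n-1}(1+\omega_n^{ik})=\left(2[2\nmid (n/d)]\right)^d\), 容易得到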

\[\hat{f_k}=[2\nmid (n/d)]2^{dm}. \]

  答案应当为 \(\lVert\boldsymbol f\rVert^2\), 也即是 \(\sum_{i=0}^{n-1}f_i^2\). 注意到 \(\gcd(k,n)=\gcd(n-k,n)\), 于是 \(\hat{f_k}=\hat{f_{n-k}}\), 做 IDFT 即得 \(f_i=f_{n-i}\) (下标模 \(n\)), 所以只需要求

\[\sum_{i=0}^{n-1}f_if_{-i\bmod n}=[x^0]\left(F^2(x)\bmod (x^n-1)\right). \]

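如果能算 \(\hat{\boldsymbol f}\), 我们就能立马得到 \(\operatorname{DFT}(F^2)_k=\hat{f_k}^2\). IDFT 矩阵的第一行 (带上 \(\frac{1}{n}\) 的系数) 每一项都是 \(\frac1n\), 将它与 \(F^2\) 的 DFT 向量做点乘就能得到答案. 按 \(d=\gcd(k,n)\) 分组 (满足 \(\gcd(k,n)=d\) 的 \(k\) 恰有 \(\varphi(n/d)\) 个), 得

\[\text{ans}=\frac1n\sum_{k=0}^{n-1}\hat{f_k}^2=\frac1n\sum_{d\mid n,\ 2\nmid(n/d)}\varphi(n/d)\,2^{2dm}. \]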

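  如何算 \(\hat{\boldsymbol f}\)? 将 \(n\) 素因数分解, 暴力枚举 \(d\) 算贡献即可. 使用 Pollard-Rho 分解 \(n\), 枚举时动态维护 \(\varphi(n/d)\), 预处理 \(2\) 的光速幂, 可以做到 \(\mathcal O(\sqrt P+T(n^{1/4}+d(n)))\) (忽略对数因子), 其中 \(P=998244353\), \(d(n)\) 为 \(n\) 的因子个数. 渐近意义下 \(d(n)=n^{o(1)}\) 其实低于 \(n^{1/4}\); 但在 \(n\le10^{18}\) 的范围内 \(d(n)\) 可达 \(10^5\) 量级, 不小于 \(n^{1/4}\approx3\times10^4\), 故两项都要保留. 另注意答案带有 \(\frac1n\): 当 \(P\mid n\) 时 \(n\) 在模 \(P\) 意义下不可逆, 该做法需另行处理.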

\(\mathscr{Code}\)

/*+Rainybunny+*/

// #include <bits/stdc++.h>
#include <bits/extc++.h>

// Ascending inclusive loop with an int index: i runs from l up to r; the bound r
// is evaluated once and stored in the hidden variable rep##i.
#define rep(i, l, r) for (int i = l, rep##i = r; i <= rep##i; ++i)
// Descending inclusive loop with an int index: i runs from r down to l; the bound
// l is evaluated once and stored in the hidden variable per##i.
#define per(i, r, l) for (int i = r, per##i = l; i >= per##i; --i)

typedef long long LL;
// Shorthands for the key/value members of pairs and map entries.
#define fi first
#define se second

// Prime modulus for the printed answer.
const int MOD = 998244353;
// Current test case: subset sums are taken mod n over the universe {1, ..., n*m}.
LL n, m;
// Prime factorization of n (prime -> exponent); cleared and refilled by
// Factor::factor for every test case in main().
__gnu_pbds::gp_hash_table<LL, int> fct;

namespace Factor {

inline LL add(LL u, const LL v, const LL m) {
    return (u += v) < m ? u : u - m;
}

inline LL mul(const LL u, const LL v, const LL m) {
    return __int128(u) * v % m;
}

inline LL mpow(LL u, LL v, LL m) {
    LL ret = 1;
    for (; v; u = mul(u, u, m), v >>= 1) ret = mul(ret, v & 1 ? u : 1, m);
    return ret;
}

inline LL gcd(const LL u, const LL v) { return v ? gcd(v, u % v) : u; }

inline LL labs(const LL u) { return u < 0 ? -u : u; }

inline bool millerRabin(const LL x, const LL b) {
    LL k = x - 1; while (!(k & 1)) k >>= 1;
    static LL pwr[70]; pwr[0] = 1, pwr[1] = mpow(b, k, x);
    while (k != x - 1) {
        pwr[pwr[0] + 1] = mul(pwr[pwr[0]], pwr[pwr[0]], x);
        ++pwr[0], k <<= 1;
    }
    per (i, pwr[0], 1) {
        if (pwr[i] != 1 && pwr[i] != x - 1) return false;
        if (pwr[i] == x - 1) return true;
    }
    return true;
}

inline bool isprime(const LL x) {
    if (x == 2 || x == 3 || x == 5 || x == 7 || x == 11) return true;
    if (x == 61 || x == 127) return true;
    if (!(x % 2) || !(x % 3) || !(x % 5) || !(x % 7) || !(x % 11))return false;
    if (!(x % 61) || !(x % 127)) return false;
    return millerRabin(x, 2) && millerRabin(x, 61) && millerRabin(x, 127);
}

inline LL pollardRho(const LL x) {
    static std::mt19937 emt(time(0) ^ 20120712);
    for (LL a = emt() % (x - 1) + 1, len = 1, st = 0, ed = 0; ;
      len <<= 1, st = ed) {
        LL prd = 1;
        rep (stp, 1, len) {
            prd = mul(prd, labs(st - (ed = add(mul(ed, ed, x), a, x))), x);
            if (!(stp & 127) && gcd(prd, x) > 1) return gcd(prd, x);
        }
        if (gcd(prd, x) > 1) return gcd(prd, x);
    }
}

inline void factor(LL x, __gnu_pbds::gp_hash_table<LL, int>& res,
  const int k = 1) {
    if (x == 1) return ;
    if (isprime(x)) return void(res[x] += k);
    LL d = pollardRho(x); int cnt = 0;
    while (!(x % d)) x /= d, ++cnt;
    factor(x, res, k), factor(d, res, k * cnt);
}

} // namespace Factor

// u * v mod MOD; the product is formed in long long so it cannot overflow int.
inline int mul(const int u, const int v) {
    const long long prod = static_cast<long long>(u) * v;
    return static_cast<int>(prod % MOD);
}
// u = u + v, reduced once by MOD. Intended for 0 <= u, v < MOD, where a single
// conditional subtraction yields the value mod MOD.
inline void addeq(int& u, const int v) {
    u += v;
    if (u >= MOD) u -= MOD;
}
// u^v mod MOD by square-and-multiply over the bits of v (v >= 0).
inline int mpow(int u, int v) {
    int acc = 1;
    while (v) {
        if (v & 1) acc = mul(acc, u);
        u = mul(u, u);
        v >>= 1;
    }
    return acc;
}

namespace Power2 {

const int SM = 31596;
int sma[SM + 1], gnt[SM + 1];

inline void init() {
    sma[0] = 1;
    rep (i, 1, SM) addeq(sma[i] = sma[i - 1], sma[i - 1]);
    gnt[0] = 1, gnt[1] = sma[SM];
    rep (i, 2, SM) gnt[i] = mul(gnt[i - 1], gnt[1]);
}

inline int power(const int u) {
    return mul(sma[u % SM], gnt[u / SM]);
}

} // namespace Power2

inline int calc(const __gnu_pbds::gp_hash_table<LL, int>
  ::iterator&& it, LL d, LL phi) {
    if (it == fct.end()) {
        return mul(phi % MOD, Power2::power(2 * d % (MOD - 1)
          * (m % (MOD - 1)) % (MOD - 1)));
    }

    int ret = 0;
    rep (i, 0, it->se) {
        if (i == repi || it->fi != 2) addeq(ret, calc(std::next(it), d, phi));
        if (i == repi) break;
        d *= it->fi, phi /= it->fi - (i + 1 == repi);
    }
    return ret;
}

int main() {
    // Judge-style file redirection.
    freopen("ntt.in", "r", stdin);
    freopen("ntt.out", "w", stdout);

    Power2::init();
    int T;
    scanf("%d", &T);
    for (; T; --T) {
        scanf("%lld %lld", &n, &m);

        // Factorize n, then start phi at phi(n), the value calc expects for d = 1.
        fct.clear();
        Factor::factor(n, fct);
        LL phi = n;
        for (const auto& entry : fct) phi = phi / entry.fi * (entry.fi - 1);

        // ans = (1/n) * sum of phi(n/d) * 2^(2dm) over divisors d with n/d odd.
        // NOTE(review): if MOD divides n, n % MOD == 0 has no inverse; mpow then
        // returns 0 and the printed value is generally not the true answer.
        const int invN = mpow(n % MOD, MOD - 2);
        printf("%d\n", mul(invN, calc(fct.begin(), 1, phi)));
    }
    return 0;
}

posted @ 2022-08-08 17:17  Rainybunny  阅读(182)  评论(3)  编辑  收藏  举报