Solution -「NOI Simu.」逆天题
\(\mathscr{Description}\)
对于 \(r=0,1,\cdots,n-1\), 设 \(\{1,2,\cdots,nm\}\) 中有 \(f_r\) 个子集满足子集内元素之和 \(\bmod n=r\), 求出 \(\sum_{r=0}^{n-1}f_r^2\bmod 998244353\).
\(n,m\le10^{18}\).
\(\mathscr{Solution}\)
我们知道 \(f\) 的 GF (指数按 \(\bmod n\) 循环):
\[F(x)=\sum_{r=0}^{n-1}f_rx^r\equiv\prod_{i=1}^{nm}(1+x^i)\pmod{x^n-1}.\]
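于是我们需要用 \(\omega_n\) 对 \(F\) 做 DFT. 不妨设 \(\hat{f_k}=\operatorname{DFT}(\boldsymbol f)_k\), 那么
\[\hat{f_k}=F(\omega_n^k)=\prod_{i=1}^{nm}(1+\omega_n^{ik})=\left(\prod_{i=0}^{n-1}(1+\omega_n^{ik})\right)^m.\]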
看过 3B1B 的视频, 我们已知了对乘积式 \(\prod_{i=0}^{n-1}(1+\omega_n^i)\) 的处理技巧:
注意到 \(z^n-1=0\Leftrightarrow z=\omega_n^k\), 那么 \(z^n-1=\prod_{i=0}^{n-1}(z-\omega_n^i)\). 代入 \(z=-1\), 左边为 \((-1)^n-1\), 右边为 \((-1)^n\prod_{i=0}^{n-1}(1+\omega_n^i)\), 于是我们有 \(\prod_{i=0}^{n-1}(1+\omega_n^i)=2[2\nmid n]\).
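指标乘上常数 \(k\), 无非让 \(\omega_n\) 转圈圈的速度发生了一些变化. 令 \(d=\gcd(k,n)\), 则 \(\omega_n^k\) 是 \(n/d\) 次本原单位根, 每个 \(n/d\) 次单位根恰被取到 \(d\) 次, 容易得到
\[\hat{f_k}=\left(\prod_{i=0}^{n/d-1}(1+\omega_{n/d}^{i})\right)^{dm}=2^{dm}\left[2\nmid\frac{n}{d}\right].\]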
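答案应当为 \(\boldsymbol f^2\), 也即是 \(\sum_{i=0}^{n-1}f_i^2\). 注意到 \(\gcd(k,n)=\gcd(n-k,n)\), 故 \(\hat{f_k}=\hat{f}_{n-k}\), 于是 \(f_i=f_{n-i}\) (下标 \(\bmod n\)), 所以只需要求
\[\sum_{i=0}^{n-1}f_if_{(n-i)\bmod n}=[x^0]\left(F^2(x)\bmod(x^n-1)\right).\]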
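如果能算 \(\hat{\boldsymbol f}\), 我们就能立马得到 \(\operatorname{DFT}(F^2)\), 接下来只需要提取 IDFT 矩阵的第一行 (带上 \(\frac{1}{n}\) 的系数) 与 \(F^2\) 的 DFT 向量做点乘就能得到答案. 由于满足 \(\gcd(k,n)=d\) 的 \(k\in[0,n)\) 恰有 \(\varphi(n/d)\) 个, 答案即为
\[\frac{1}{n}\sum_{k=0}^{n-1}\hat{f_k}^2=\frac{1}{n}\sum_{d\mid n,\ 2\nmid n/d}\varphi\!\left(\frac{n}{d}\right)2^{2dm}.\]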
如何算 \(\hat{\boldsymbol f}\)? 将 \(n\) 素因数分解, 暴力枚举 \(d\) 算贡献即可. 使用 Pollard-Rho, 动态维护 \(\varphi(n/d)\), 预处理 \(2\) 的光速幂, 可以做到 \(\mathcal O(\sqrt P+Td(n))\), 其中 \(P=998244353\), \(d(n)\) 为 \(n\) 的因子个数 (其自然比 Pollard-Rho 的 \(n^{1/4}\) 高阶).
\(\mathscr{Code}\)
/*+Rainybunny+*/
// #include <bits/stdc++.h>
#include <bits/extc++.h>
// rep(i, l, r): ascending inclusive loop over int i in [l, r]. The upper bound
// is evaluated once and stored in a hidden variable named rep<i> (for a loop
// variable called `i` that is `repi`), which loop bodies may read.
#define rep(i, l, r) for (int i = l, rep##i = r; i <= rep##i; ++i)
// per(i, r, l): descending inclusive loop over int i from r down to l; the
// lower bound is stored once in a hidden variable named per<i>.
#define per(i, r, l) for (int i = r, per##i = l; i >= per##i; --i)
typedef long long LL;
// Shorthands for the two members of a pair-like value.
#define fi first
#define se second
// Modulus used by the int-based helpers mul / addeq / mpow.
const int MOD = 998244353;
// Parameters of the current test case; read in main(), m is also read by calc().
LL n, m;
// Factorisation of n as prime -> exponent; rebuilt per test case in main()
// via Factor::factor and traversed by calc().
__gnu_pbds::gp_hash_table<LL, int> fct;
namespace Factor {
inline LL add(LL u, const LL v, const LL m) {
return (u += v) < m ? u : u - m;
}
inline LL mul(const LL u, const LL v, const LL m) {
return __int128(u) * v % m;
}
inline LL mpow(LL u, LL v, LL m) {
LL ret = 1;
for (; v; u = mul(u, u, m), v >>= 1) ret = mul(ret, v & 1 ? u : 1, m);
return ret;
}
inline LL gcd(const LL u, const LL v) { return v ? gcd(v, u % v) : u; }
inline LL labs(const LL u) { return u < 0 ? -u : u; }
inline bool millerRabin(const LL x, const LL b) {
LL k = x - 1; while (!(k & 1)) k >>= 1;
static LL pwr[70]; pwr[0] = 1, pwr[1] = mpow(b, k, x);
while (k != x - 1) {
pwr[pwr[0] + 1] = mul(pwr[pwr[0]], pwr[pwr[0]], x);
++pwr[0], k <<= 1;
}
per (i, pwr[0], 1) {
if (pwr[i] != 1 && pwr[i] != x - 1) return false;
if (pwr[i] == x - 1) return true;
}
return true;
}
inline bool isprime(const LL x) {
if (x == 2 || x == 3 || x == 5 || x == 7 || x == 11) return true;
if (x == 61 || x == 127) return true;
if (!(x % 2) || !(x % 3) || !(x % 5) || !(x % 7) || !(x % 11))return false;
if (!(x % 61) || !(x % 127)) return false;
return millerRabin(x, 2) && millerRabin(x, 61) && millerRabin(x, 127);
}
inline LL pollardRho(const LL x) {
static std::mt19937 emt(time(0) ^ 20120712);
for (LL a = emt() % (x - 1) + 1, len = 1, st = 0, ed = 0; ;
len <<= 1, st = ed) {
LL prd = 1;
rep (stp, 1, len) {
prd = mul(prd, labs(st - (ed = add(mul(ed, ed, x), a, x))), x);
if (!(stp & 127) && gcd(prd, x) > 1) return gcd(prd, x);
}
if (gcd(prd, x) > 1) return gcd(prd, x);
}
}
inline void factor(LL x, __gnu_pbds::gp_hash_table<LL, int>& res,
const int k = 1) {
if (x == 1) return ;
if (isprime(x)) return void(res[x] += k);
LL d = pollardRho(x); int cnt = 0;
while (!(x % d)) x /= d, ++cnt;
factor(x, res, k), factor(d, res, k * cnt);
}
} // namespace Factor
/// Returns u * v mod MOD; the product is widened to 64 bits before reduction.
inline int mul(const int u, const int v) {
    const long long wide = static_cast<long long>(u) * v;
    return static_cast<int>(wide % MOD);
}
/// In-place addition u += v followed by a single conditional subtraction of
/// MOD, so the result is fully reduced when both inputs were in [0, MOD).
inline void addeq(int& u, const int v) {
    u += v;
    if (u >= MOD) u -= MOD;
}
/// Binary exponentiation modulo MOD: returns u^v for v >= 0.
inline int mpow(int u, int v) {
    int acc = 1;
    while (v) {
        // Multiply by 1 on a clear bit so acc is passed through mul() on
        // every step, exactly one call per bit of v.
        const int factor = (v & 1) ? u : 1;
        acc = mul(acc, factor);
        u = mul(u, u);
        v >>= 1;
    }
    return acc;
}
namespace Power2 {
const int SM = 31596;
int sma[SM + 1], gnt[SM + 1];
inline void init() {
sma[0] = 1;
rep (i, 1, SM) addeq(sma[i] = sma[i - 1], sma[i - 1]);
gnt[0] = 1, gnt[1] = sma[SM];
rep (i, 2, SM) gnt[i] = mul(gnt[i - 1], gnt[1]);
}
inline int power(const int u) {
return mul(sma[u % SM], gnt[u / SM]);
}
} // namespace Power2
/// Enumerates divisors of n by walking the prime factors stored in `fct`,
/// starting at `it`.
///
/// `d` is the partial divisor built so far and `phi` is kept equal to
/// Euler's totient of n / d restricted to the primes already processed
/// (main() seeds it with phi(n) and d = 1). For the prime 2 only the choice
/// that puts the full power of 2 into d is explored, so n / d stays odd.
/// At the end of the walk the term phi * 2^(2 * d * m) is returned; the
/// exponent is reduced modulo MOD - 1. The result is the sum of all such
/// terms, modulo MOD.
inline int calc(const __gnu_pbds::gp_hash_table<LL, int>
  ::iterator&& it, LL d, LL phi) {
    if (it == fct.end()) {
        const LL order = MOD - 1;
        const LL expo = 2 * d % order * (m % order) % order;
        return mul(phi % MOD, Power2::power(expo));
    }

    const LL prime = it->first;
    const int maxExp = it->second;
    int total = 0;
    for (int used = 0; used <= maxExp; ++used) {
        const bool full = (used == maxExp);
        if (full || prime != 2) {
            addeq(total, calc(std::next(it), d, phi));
        }
        if (full) break;
        d *= prime;
        // Removing the last copy of the prime from n / d divides the totient
        // by (prime - 1); every earlier copy divides it by prime.
        const LL divisor = (used + 1 == maxExp) ? prime - 1 : prime;
        phi /= divisor;
    }
    return total;
}
/// Reads T test cases (n, m) from ntt.in and writes one answer per line to
/// ntt.out: calc() summed over the divisors of n, multiplied by the modular
/// inverse of n.
int main() {
    freopen("ntt.in", "r", stdin);
    freopen("ntt.out", "w", stdout);
    Power2::init();
    // Initialise T and check every read so malformed or truncated input ends
    // the run instead of using an indeterminate count or stale n, m.
    int T = 0;
    if (scanf("%d", &T) != 1) return 0;
    while (T-- > 0) {
        if (scanf("%lld %lld", &n, &m) != 2) break;
        fct.clear(), Factor::factor(n, fct);
        // Euler's totient of n from its prime factorisation.
        LL phi = n;
        for (const auto& p: fct) phi = phi / p.first * (p.first - 1);
        // NOTE(review): when n is a multiple of MOD, n % MOD is 0, mpow
        // returns 0 and the printed answer is 0, because the inverse does not
        // exist. Confirm the input constraints rule this case out.
        const int invN = mpow(n % MOD, MOD - 2);
        printf("%d\n", mul(invN, calc(fct.begin(), 1, phi)));
    }
    return 0;
}