[单位根反演] LibreOJ 6485 LJJ 学二项式定理
题目大意
\(T\) 组数据,每次给出 \(n,s,a_0,a_1,a_2,a_3\),求以下式子的值:
\[\left[\sum_{i=0}^n\left(\binom{n}{i}\cdot s^i\cdot a_{i\mathrm{\ mod\ } 4}\right)\right]\mathrm{\ mod\ } 998244353
\]
\(1\leq n\leq 10^{18},1\leq s,a_0,a_1,a_2,a_3\leq 10^8\)
题解
首先观察式子,式子中带着组合数,可能需要使用二项式定理来消除组合数。乘上的 \(a_i\) 与 \(i\mathrm{\ mod\ } 4\) 的值有关,可以枚举 \(i\mathrm{\ mod\ } 4\) 的值来分别计算每个 \(a_i\) 的贡献。然后就来推以下式子,推的过程中尽可能凑出 \(\binom{n}{i}x^iy^{n-i}\) 这样的二项式定理形式。
\[\sum_{i=0}^n\left(\binom{n}{i}\cdot s^i\cdot a_{i\mathrm{\ mod\ } 4}\right)=\sum_{k=0}^3\sum_{i=0}^n\binom{n}{i}s^ia_k[4|i-k]\\
=\frac{1}{4}\sum_{k=0}^3\sum_{i=0}^n\binom{n}{i}s^ia_k\sum_{j=0}^3\omega_4^{(i-k)j}\\
=\frac{1}{4}\sum_{k=0}^3a_k\sum_{j=0}^3\sum_{i=0}^n\binom{n}{i}s^i\omega_4^{ij}\omega_4^{-kj}\\
=\frac{1}{4}\sum_{k=0}^3a_k\sum_{j=0}^3\omega_4^{-kj}\sum_{i=0}^n\binom{n}{i}s^i(\omega_4^{j})^i\\
=\frac{s^n}{4}\sum_{k=0}^3a_k\sum_{j=0}^3\omega_4^{-kj}\sum_{i=0}^n\binom{n}{i}\left(\frac{1}{s}\right)^{n-i}(\omega_4^{j})^i\\
=\frac{s^n}{4}\sum_{k=0}^3a_k\sum_{j=0}^3\omega_4^{-kj}\left(\frac{1}{s}+\omega_4^j\right)^n
\]
\(998244353\) 的原根是 \(3\),所以可以用 \(3^{\frac{998244353-1}{4}}\) 来代替 \(\omega_4\),然后直接计算即可。
时间复杂度 \(O(\log n)\)。
Code
#include <bits/stdc++.h>
using namespace std;
#define RG register int
#define LL long long
template<typename elemType>
inline void Read(elemType& T) {
elemType X = 0, w = 0; char ch = 0;
while (!isdigit(ch)) { w |= ch == '-';ch = getchar(); }
while (isdigit(ch)) X = (X << 3) + (X << 1) + (ch ^ 48), ch = getchar();
T = (w ? -X : X);
}
const LL MOD = 998244353;
LL qpow(LL b, LL n) {
LL x = 1, Power = b % MOD;
while (n) {
if (n & 1) x = x * Power % MOD;
Power = Power * Power % MOD;
n >>= 1;
}
return x;
}
const LL g = qpow(3, (MOD - 1) / 4);
LL n, s, a[4];
int T;
LL calc() {
LL sinv = qpow(s, MOD - 2), res = 0;
for (int k = 0;k < 4;++k) {
LL sum = 0;
for (int j = 0;j < 4;++j) {
LL temp = qpow(sinv + qpow(g, j), n) * qpow(qpow(g, k * j), MOD - 2) % MOD;
sum = (sum + temp) % MOD;
}
sum = sum * a[k] % MOD;
res = (res + sum) % MOD;
}
res = res * qpow(s, n) % MOD * qpow(4, MOD - 2) % MOD;
return res;
}
int main() {
Read(T);
while (T--) {
Read(n);Read(s);
for (int i = 0;i < 4;++i) Read(a[i]);
LL ans = calc();
printf("%lld\n", ans);
}
return 0;
}