[单位根反演] LibreOJ 6485 LJJ 学二项式定理

题目大意

\(T\) 组数据,每次给出 \(n,s,a_0,a_1,a_2,a_3\),求以下式子的值:

\[\left[\sum_{i=0}^n\left(\binom{n}{i}\cdot s^i\cdot a_{i\mathrm{\ mod\ } 4}\right)\right]\mathrm{\ mod\ } 998244353 \]

\(1\leq n\leq 10^{18},1\leq s,a_0,a_1,a_2,a_3\leq 10^8\)

题解

首先观察式子,式子中带着组合数,可能需要使用二项式定理来消除组合数。乘上的 \(a_i\)\(i\mathrm{\ mod\ } 4\) 的值有关,可以枚举 \(i\mathrm{\ mod\ } 4\) 的值来分别计算每个 \(a_i\) 的贡献。然后就来推以下式子,推的过程中尽可能凑出 \(\binom{n}{i}x^iy^{n-i}\) 这样的二项式定理形式。

\[\sum_{i=0}^n\left(\binom{n}{i}\cdot s^i\cdot a_{i\mathrm{\ mod\ } 4}\right)=\sum_{k=0}^3\sum_{i=0}^n\binom{n}{i}s^ia_k[4|i-k]\\ =\frac{1}{4}\sum_{k=0}^3\sum_{i=0}^n\binom{n}{i}s^ia_k\sum_{j=0}^3\omega_4^{(i-k)j}\\ =\frac{1}{4}\sum_{k=0}^3a_k\sum_{j=0}^3\sum_{i=0}^n\binom{n}{i}s^i\omega_4^{ij}\omega_4^{-kj}\\ =\frac{1}{4}\sum_{k=0}^3a_k\sum_{j=0}^3\omega_4^{-kj}\sum_{i=0}^n\binom{n}{i}s^i(\omega_4^{j})^i\\ =\frac{s^n}{4}\sum_{k=0}^3a_k\sum_{j=0}^3\omega_4^{-kj}\sum_{i=0}^n\binom{n}{i}\left(\frac{1}{s}\right)^{n-i}(\omega_4^{j})^i\\ =\frac{s^n}{4}\sum_{k=0}^3a_k\sum_{j=0}^3\omega_4^{-kj}\left(\frac{1}{s}+\omega_4^j\right)^n \]

\(998244353\) 的原根是 \(3\),所以可以用 \(3^{\frac{998244353-1}{4}}\) 来代替 \(\omega_4\),然后直接计算即可。

时间复杂度 \(O(\log n)\)

Code

#include <bits/stdc++.h>
using namespace std;

#define RG register int
#define LL long long

template<typename elemType>
inline void Read(elemType& T) {
    elemType X = 0, w = 0; char ch = 0;
    while (!isdigit(ch)) { w |= ch == '-';ch = getchar(); }
    while (isdigit(ch)) X = (X << 3) + (X << 1) + (ch ^ 48), ch = getchar();
    T = (w ? -X : X);
}

const LL MOD = 998244353;

LL qpow(LL b, LL n) {
    LL x = 1, Power = b % MOD;
    while (n) {
        if (n & 1) x = x * Power % MOD;
        Power = Power * Power % MOD;
        n >>= 1;
    }
    return x;
}

const LL g = qpow(3, (MOD - 1) / 4);
LL n, s, a[4];
int T;

LL calc() {
    LL sinv = qpow(s, MOD - 2), res = 0;
    for (int k = 0;k < 4;++k) {
        LL sum = 0;
        for (int j = 0;j < 4;++j) {
            LL temp = qpow(sinv + qpow(g, j), n) * qpow(qpow(g, k * j), MOD - 2) % MOD;
            sum = (sum + temp) % MOD;
        }
        sum = sum * a[k] % MOD;
        res = (res + sum) % MOD;
    }
    res = res * qpow(s, n) % MOD * qpow(4, MOD - 2) % MOD;
    return res;
}

int main() {
    Read(T);
    while (T--) {
        Read(n);Read(s);
        for (int i = 0;i < 4;++i) Read(a[i]);
        LL ans = calc();
        printf("%lld\n", ans);
    }
    return 0;
}
posted @ 2021-08-06 14:27  AE酱  阅读(63)  评论(0编辑  收藏  举报