「多项式快速幂」
前置知识
基本问题
给定一个 \(n-1\) 次多项式 \(A(x)\)(共 \(n\) 项系数)和指数 \(k\),求多项式 \(B(x)\) 满足
\[B(x)\equiv A^k(x) \pmod{x^n} \]
对等式两边同时取 \(\ln\)
\[\ln B(x)=\ln A^k(x)
\]
\[\ln B(x)=k\times \ln A(x)
\]
\[B(x)=e^{k\times \ln A(x)}
\]
相当于是对多项式 \(\ln\) 和 \(\exp\) 的综合应用
不过这种利用自然常数 \(e\) 的做法比较局限,只能在 \(a_0=1\)(即 \(A(x)\) 的常数项为 \(1\))的情况下使用,因为多项式 \(\ln\) 要求常数项为 \(1\)
代码
#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
// Shorthand names for the wide integer types.
typedef long long ll;
typedef unsigned long long ull;

using namespace std;

// Capacity of every coefficient / scratch array declared in this file.
const int maxn = 300000 + 50;
// Large sentinel value (not referenced by the code visible in this file).
const int INF = 0x3f3f3f3f;
// Modulus used for all coefficient arithmetic.
const int mod = 998244353;
// Multiplicative inverse of 3 modulo `mod` (3 * inv3 % mod == 1).
// 3 and inv3 are the values passed to NTT as `opt` for the forward and
// inverse transforms respectively.
const int inv3 = 332748118;
// Reads the next decimal integer from stdin and returns it reduced modulo
// `mod`, always as a value in [0, mod).
//
// Non-digit characters before the number are skipped; if a '-' is seen among
// them the value is negated before the final reduction. The digits are folded
// in modulo `mod` as they are read, so arbitrarily long numbers are accepted.
//
// Returns 0 if end of input is reached before any digit is found.
inline int read () {
    int x = 0, w = 1;
    // Keep the character as int so EOF can be told apart from real bytes;
    // without the EOF test this skip loop would never terminate at end of input.
    int ch = getchar ();
    for (; ch != EOF && (ch < '0' || ch > '9'); ch = getchar ())
        if (ch == '-') w = -1;
    for (; ch >= '0' && ch <= '9'; ch = getchar ())
        x = (1ll * x * 10 + ch - '0') % mod;
    // x * w lies in (-mod, mod); shift it back into [0, mod) so callers never
    // feed a negative residue into the modular arithmetic.
    return (x * w % mod + mod) % mod;
}
// Prints the decimal representation of x to stdout, one character at a time,
// with no trailing separator. Negative values are printed with a leading '-'.
inline void write (int x) {
    if (x < 0) {
        putchar ('-');
        // Work on x / 10 and x % 10 instead of negating x itself, so the most
        // negative int does not overflow. For negative x, x % 10 is in [-9, 0].
        if (x / 10) write (-(x / 10));
        putchar ('0' - x % 10);
        return;
    }
    // Emit the higher-order digits first, then the last digit.
    if (x / 10) write (x / 10);
    putchar (x % 10 + '0');
}
// n: index of the highest coefficient handled (main stores "input n minus 1").
int n;
// k: the exponent, as returned by read() (i.e. already reduced modulo `mod`).
int k;

// a: input coefficients; b: resulting coefficients; c: holds k * ln(a) in Poly_qpow.
int a[maxn];
int b[maxn];
int c[maxn];
// rev: bit-reversal permutation consumed by NTT (filled in by Poly_Inv).
int rev[maxn];

// Scratch buffers: res is used by Poly_Inv, tmp and now by Poly_Ln, typ by Poly_Exp.
int res[maxn];
int tmp[maxn];
int now[maxn];
int typ[maxn];
// Modular exponentiation by repeated squaring.
//
// Returns ans * a^b modulo `mod` (so a^b when `ans` keeps its default of 1).
//   a   - base; callers in this file pass values in [0, mod)
//   b   - exponent; must be non-negative (the loop shifts b right until it
//         reaches zero, which a negative value is not expected to do)
//   ans - initial accumulator
inline int qpow (int a, int b, int ans = 1) {
    while (b) {
        // Multiply the accumulator by the current power when the low bit is set.
        if (b & 1) ans = 1ll * ans * a % mod;
        a = 1ll * a * a % mod;
        b >>= 1;
    }
    return ans;
}
// In-place number-theoretic transform of a[0 .. len-1] modulo `mod`.
//
//   len - transform length; the callers in this file always pass a power of two
//   a   - coefficient array, overwritten with the transform
//   opt - base from which the roots of unity are derived: callers pass 3 for
//         the forward transform and inv3 for the inverse one
//
// The global `rev` must already hold the bit-reversal permutation for `len`.
// No 1/len scaling is applied here; callers multiply by the inverse of len
// after an inverse transform.
inline void NTT (int len, int * a, int opt) {
    // Reorder the input into bit-reversed order (each pair swapped once).
    for (int i = 1; i < len; i ++)
        if (i < rev[i]) swap (a[i], a[rev[i]]);
    // Butterfly passes: d is the half-size of the blocks merged in this pass.
    for (int d = 1; d < len; d <<= 1) {
        // Principal root of unity of order 2d derived from opt.
        int w1 = qpow (opt, (mod - 1) / (d << 1));
        for (int i = 0; i < len; i += d << 1) {
            int w = 1;
            for (int j = 0; j < d; j ++, w = 1ll * w * w1 % mod) {
                int x = a[i + j];
                int y = 1ll * w * a[i + j + d] % mod;
                a[i + j] = (x + y) % mod;
                a[i + j + d] = (x - y + mod) % mod;
            }
        }
    }
}
// Computes the first d coefficients of the multiplicative inverse of the
// polynomial a, i.e. b with a * b == 1 (mod x^d), by Newton iteration:
// the solution modulo x^ceil(d/2) is refined with b <- b * (2 - a * b).
//
//   d - number of coefficients wanted; d <= 0 is a no-op
//   a - input coefficients a[0 .. d-1]; a[0] is inverted with qpow, so it is
//       expected to be non-zero modulo `mod`
//   b - output buffer; on return b[0 .. d-1] holds the inverse and the rest of
//       the working range is zeroed
//       NOTE(review): b is assumed to be zero beyond the coefficients produced
//       by the recursive call (Poly_Ln clears it before calling) — confirm for
//       any new caller.
//
// Side effects: overwrites the global scratch array `res` and rebuilds the
// global `rev` for the transform length used at this level.
inline void Poly_Inv (int d, int * a, int * b) {
    // Without this guard (d + 1) >> 1 stays at 0 and the recursion never ends.
    if (d <= 0) return;
    if (d == 1) return b[0] = qpow (a[0], mod - 2), void ();
    Poly_Inv ((d + 1) >> 1, a, b);

    // Smallest power of two >= 2d, large enough for the product a * b * b.
    int len = 1, bit = 0;
    while (len < d << 1) len <<= 1, bit ++;
    for (int i = 0; i < len; i ++) {
        res[i] = 0;
        rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (bit - 1));
    }
    // Only the first d coefficients of a take part at this precision.
    for (int i = 0; i < d; i ++) res[i] = a[i];

    NTT (len, res, 3), NTT (len, b, 3);
    // Pointwise Newton step: b * (2 - a * b), kept in [0, mod).
    for (int i = 0; i < len; i ++)
        b[i] = ((2ll * b[i] % mod - 1ll * res[i] * b[i] % mod * b[i] % mod) % mod + mod) % mod;
    NTT (len, b, inv3);

    // Undo the transform scaling and truncate to d coefficients.
    int inv = qpow (len, mod - 2);
    for (int i = 0; i < d; i ++) b[i] = 1ll * b[i] * inv % mod;
    for (int i = d; i < len; i ++) b[i] = 0;
}
inline void Poly_Ln (register int d, register int * a, register int * b) {
register int len = 1, bit = 0;
while (len < d << 1) len <<= 1, bit ++;
for (register int i = 0; i < len; i ++) now[i] = tmp[i] = b[i] = 0;
for (register int i =0; i < d; i ++) now[i] = 1ll * a[i + 1] * (i + 1) % mod;
Poly_Inv (d, a, tmp), NTT (len, tmp, 3), NTT (len, now, 3);
for (register int i = 0; i < len; i ++) b[i] = 1ll * tmp[i] * now[i] % mod;
NTT (len, b, inv3);
register int inv = qpow (len, mod - 2);
for (register int i = 0; i < len; i ++) b[i] = 1ll * b[i] * inv % mod;
for (register int i = d - 1; i >= 1; i --) b[i] = 1ll * b[i - 1] * qpow (i, mod - 2) % mod;
for (register int i = d; i < len; i ++) b[i] = 0; b[0] = 0;
}
// Computes the first d coefficients of exp(a) by Newton iteration: the
// solution modulo x^ceil(d/2) is refined with b <- b * (1 - ln(b) + a).
//
//   d - number of coefficients wanted; d <= 0 is a no-op
//   a - input coefficients a[0 .. d-1]; the base case sets b[0] = 1, which is
//       exp(a[0]) only when a[0] == 0
//   b - output buffer; on return b[0 .. d-1] holds the result and the rest of
//       the working range is zeroed
//       NOTE(review): b is assumed to be all zero on the outermost call (the
//       base case writes b[0] only) — confirm for any new caller.
//
// Side effects: overwrites the global scratch array `typ` and everything
// Poly_Ln touches (`now`, `tmp`, `res`, `rev`).
inline void Poly_Exp (int d, int * a, int * b) {
    // Without this guard (d + 1) >> 1 stays at 0 and the recursion never ends.
    if (d <= 0) return;
    if (d == 1) return b[0] = 1, void ();
    Poly_Exp ((d + 1) >> 1, a, b);

    int len = 1, bit = 0;
    while (len < d << 1) len <<= 1, bit ++;
    for (int i = 0; i < len; i ++) typ[i] = 0;

    // typ = ln(b). Poly_Ln leaves typ[0] == 0, so the decrement makes it -1 and
    // the subtraction below turns that into the "+1" of (1 - ln(b) + a).
    Poly_Ln (d, b, typ), typ[0] --;
    for (int i = 0; i < d; i ++) typ[i] = ((a[i] - typ[i]) % mod + mod) % mod;

    // rev is already set for this len by the Poly_Inv call inside Poly_Ln.
    NTT (len, typ, 3), NTT (len, b, 3);
    for (int i = 0; i < len; i ++) b[i] = 1ll * b[i] * typ[i] % mod;
    NTT (len, b, inv3);

    // Undo the transform scaling and truncate to d coefficients.
    int inv = qpow (len, mod - 2);
    for (int i = 0; i < d; i ++) b[i] = 1ll * b[i] * inv % mod;
    for (int i = d; i < len; i ++) b[i] = 0;
}
// Raises the polynomial a to the k-th power modulo x^(n+1) via
// b = exp(k * ln(a)), where n is the global highest coefficient index.
//
//   k - exponent, taken modulo `mod`
//   a - input coefficients a[0 .. n]; the ln/exp route is only valid when
//       a[0] == 1
//   b - output buffer receiving b[0 .. n]
//
// Side effects: overwrites the global `c` with k * ln(a), plus all scratch
// arrays used by Poly_Ln and Poly_Exp.
inline void Poly_qpow (int k, int * a, int * b) {
    // Keep the multiplier in [0, mod) so every c[i] stays a non-negative
    // residue before it is handed to the transforms.
    const int e = (k % mod + mod) % mod;
    Poly_Ln (n + 1, a, c);
    for (int i = 0; i <= n; i ++) c[i] = 1ll * c[i] * e % mod;
    Poly_Exp (n + 1, c, b);
}
int main () {
n = read() - 1, k = read();
for (register int i = 0; i <= n; i ++) a[i] = read(); Poly_qpow (k, a, b);
for (register int i = 0; i <= n; i ++) printf ("%d ", b[i]); putchar ('\n');
return 0;
}