【luogu P4726】【模板】多项式指数函数(多项式 exp)(NTT)
【模板】多项式指数函数(多项式 exp)
题目链接:luogu P4726
题目大意
给你一个 n-1 次多项式,要你求一个 mod x^n 下的多项式使得它等于 e 的这个多项式次方。
思路
给出 \(A(x)\),要你找到 \(B(x)\) 使得 \(B(x)\equiv e^{A(x)}\pmod {x^n}\)
考虑到 \(e^x\) 是 \(ln x\) 的逆运算,我们考虑一个东西叫做牛迭。
设 \(G(B(x))=\ln B(x)-A(x)=0\)
然后 \(G'(B(x))=\dfrac{1}{B(x)}\)。
这里特别说一下,首先减法的求导是 \((F(x)-G(x))'=F'(x)-G'(x)\)。
然后右边的 \(A(x)\) 是常数所以其实可以直接去掉,因为这里的未知数是 \(B(x)\) 而不是 \(x\),\(A(x)\) 里面是 \(x\) 所以跟 \(B(x)\) 其实没有关系,所以是常数。
然后用牛迭的式子:
\(B(x)=B_*(x)-\dfrac{G(B_*(x))}{G'(B_*(x))}\)
\(B(x)=B_*(x)-G(B_*(x))B_*(x)\)
\(B(x)=(1-G(B_*(x)))B_*(x)\)
\(B(x)=(1-\ln B_*(x)+A(x))B_*(x)\)
然后就可以倍增搞了,复杂度也是 \(O(n\log n)\)。
然后要保证 \([x^0]A(x)=1\),这个时候 \([x^0]B(x)=1\),否则定义式里面会有无穷求和。
代码
#include<cstdio>
#include<cstring>
#include<algorithm>
#define mo 998244353
#define clr(f, n) memset(f, 0, (n) * sizeof(int))
#define cpy(f, g, n) memcpy(f, g, (n) * sizeof(int))
using namespace std;
const int N = 100000 * 8 + 1;
int n, m, f[N], g[N], an[N], inv[N], G, Gv;
int jia(int x, int y) {return x + y >= mo ? x + y - mo : x + y;}
int jian(int x, int y) {return x - y < 0 ? x - y + mo : x - y;}
int cheng(int x, int y) {return 1ll * x * y % mo;}
int ksm(int x, int y) {int re = 1; while (y) {if (y & 1) re = cheng(re, x); x = cheng(x, x); y >>= 1;} return re;}
void Init() {
G = 3; Gv = ksm(G, mo - 2);
inv[0] = inv[1] = 1; for (int i = 2; i < N; i++) inv[i] = cheng(inv[mo % i], mo - mo / i);
}
void get_an(int limit, int l_size) {
for (int i = 0; i < limit; i++)
an[i] = (an[i >> 1] >> 1) | ((i & 1) << (l_size - 1));
}
void NTT(int *f, int limit, int op) {
for (int i = 0; i < limit; i++) if (an[i] < i) swap(f[an[i]], f[i]);
for (int mid = 1; mid < limit; mid <<= 1) {
int Wn = ksm(op == 1 ? G : Gv, (mo - 1) / (mid << 1));
for (int R = (mid << 1), j = 0; j < limit; j += R) {
int w = 1;
for (int k = 0; k < mid; k++, w = cheng(w, Wn)) {
int x = f[j | k], y = cheng(w, f[j | mid | k]);
f[j | k] = jia(x, y); f[j | mid | k] = jian(x, y);
}
}
}
if (op == -1) {
int limv = ksm(limit, mo - 2);
for (int i = 0; i < limit; i++) f[i] = cheng(f[i], limv);
}
}
void px(int *f, int *g, int limit) {
for (int i = 0; i < limit; i++)
f[i] = cheng(f[i], g[i]);
}
void times(int *f, int *g, int n, int m) {
static int tmp[N];
int limit = 1, l_size = 0; while (limit < n + n) limit <<= 1, l_size++;
cpy(tmp, g, n); clr(tmp + n, limit - n);
get_an(limit, l_size);
NTT(f, limit, 1); NTT(tmp, limit, 1);
px(f, tmp, limit); NTT(f, limit, -1);
clr(f + m, limit - m); clr(tmp, limit);
}
void invp(int *f, int n) {
static int w[N], r[N], tmp[N];
w[0] = ksm(f[0], mo - 2);
int limit = 1, l_size = 0;
for (int len = 2; (len >> 1) <= n; len <<= 1) {
limit = len; l_size++; get_an(limit, l_size);
cpy(r, w, len >> 1);
cpy(tmp, f, limit); NTT(tmp, limit, 1);
NTT(r, limit, 1); px(r, tmp, limit);
NTT(r, limit, -1); clr(r, limit >> 1);
cpy(tmp, w, len); NTT(tmp, limit, 1);
NTT(r, limit, 1); px(r, tmp, limit);
NTT(r, limit, -1);
for (int i = (len >> 1); i < len; i++)
w[i] = jian(cheng(w[i], 2), r[i]);
}
cpy(f, w, n); clr(w, limit); clr(r, limit); clr(tmp, limit);
}
void dao(int *f, int n) {
for (int i = 1; i < n; i++)
f[i - 1] = cheng(f[i], i);
f[n - 1] = 0;
}
void jifen(int *f, int n) {
for (int i = n; i >= 1; i--)
f[i] = cheng(f[i - 1], inv[i]);
f[0] = 0;
}
void mof(int *f, int n, int *g, int m) {
static int f_[N], g_[N];
int L = n - m + 1;
reverse(f, f + n); cpy(f_, f, L); reverse(f, f + n);
reverse(g, g + m); cpy(g_ , g, L); reverse(g, g + m);
invp(g_, L); times(g_, f_, L, L); reverse(g_, g_ + L);
times(g, g_, n, n);
for (int i = 0; i < m - 1; i++) g[i] = jian(f[i], g[i]);
clr(g + m - 1, L);
cpy(f, g_, L); clr(f + L, n - L);
}
void lnp(int *f, int n) {
static int g[N];
cpy(g, f, n); dao(g, n);
invp(f, n); times(f, g, n, n);
jifen(f, n - 1); clr(g, n);
}
void exp(int *f, int n) {
static int w[N], ww[N];
ww[0] = 1;
int len;
for (len = 2; (len >> 1) <= n; len <<= 1) {
cpy(w, ww, len >> 1); lnp(w, len);
for (int i = 0; i < len; i++)
w[i] = jian(f[i], w[i]);
w[0] = jia(w[0], 1);
times(ww, w, len, len);
}
len >>= 1;
cpy(f, ww, n); clr(ww, len); clr(w, len);
}
int main() {
Init();
scanf("%d", &n);
for (int i = 0; i < n; i++) scanf("%d", &f[i]);
exp(f, n);
for (int i = 0; i < n; i++) printf("%d ", f[i]);
return 0;
}