[学习笔记] 多项式 ln & exp
前置知识
-
\[\begin{aligned} (\ln x)' &= \frac{1}{x} \\ (\exp x)' &= \exp x \\ \end{aligned} \]
-
复合函数的求导(链式法则)
\[(g\circ f)' (x) = g'(f(x))\,f'(x) \]
-
多项式求逆,分治FFT。
多项式 ln
设 \(g(x) = \ln{f(x)}\)(要求 \(f_0 = 1\),否则 \(\ln f(x)\) 作为形式幂级数没有定义),则有
\[\begin{aligned}
g'(x) = f'(x) \frac{1}{f(x)}
\end{aligned}
\]
先把 \(g'(x)\) 求出来然后再积分就行了。
多项式求导,积分都是 \(O(n)\) 的,多项式乘法为 \(O(n\log n)\),所以总复杂度为 \(O(n\log n)\)。
多项式 exp
普通方法一
设 \(g(x) = e^{f(x)}\)(要求 \(f_0 = 0\),此时 \(g_0 = 1\)),则有
\[\begin{aligned}
g'(x) = g(x)f'(x)
\end{aligned}
\]
写成卷积的形式就是(\(f_i\) 表示 \(f(x)\) 的 \(i\) 次项系数,\(f'_i\)、\(g_i\)、\(g'_i\) 同理)
\[g'_i = \sum_{j = 0}^i g_j f'_{i - j}
\]
然后因为 \(f'_i = (i + 1)f_{i + 1}\),\(g'_i = (i + 1)g_{i + 1}\),所以就有
\[\begin{aligned}
g_{i + 1} &= \frac{1}{i + 1} \sum_{j = 0}^i g_{j} (i + 1 - j) f_{i + 1 - j} \\
g_i &= \frac{1}{i} \sum_{j = 0}^{i - 1} g_{j} (i - j) f_{i - j}
\end{aligned}
\]
可以用分治 FFT 解决,时间复杂度为 \(O(n \log^2 n)\)。
牛顿迭代
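设已求出 \(g_0(x) \equiv e^{f(x)} \pmod{x^{n}}\),对 \(F(g) = \ln g - f\) 做牛顿迭代 \(g = g_0 - \frac{F(g_0)}{F'(g_0)}\) 得到
\[g(x) \equiv g_0(x)\left(1 - \ln g_0(x) + f(x)\right) \pmod{x^{2n}}\]
每轮精度翻倍。时间复杂度为 \(O(n \log n)\) ,但由于实现过程中需要求 \(\ln\),所以实际上快不了多少(至少在洛谷的模板上跑得差不多)。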
代码
\(\ln\)
#include <cstdio>
#include <iostream>
using namespace std;
typedef long long ll;
// Array capacity: 2^18 coefficients plus a little slack.
const int _ = (1 << 18) + 7;
// NTT-friendly prime 998244353 = 119 * 2^23 + 1.
const int mod = 998244353;
// Primitive root of `mod`, used to derive the roots of unity.
const int rt = 3;
int n;
int f[_];

// Binary exponentiation: returns a^p reduced modulo `mod` (p is expected >= 0).
int Pw(int a,int p) {
	ll acc = 1;
	ll base = a;
	for (; p; p >>= 1) {
		if (p & 1) acc = acc * base % mod;
		base = base * base % mod;
	}
	return (int)acc;
}
namespace POLY {
// tot: transform length chosen by Init(); num: bit-reversal permutation for the
// current NTT length; inv[i]: i^{-1} mod `mod` for 1 <= i <= tot;
// pwrt[0][len] / pwrt[1][len]: primitive len-th root of unity and its inverse
// for every power of two len <= tot; tmp[1..4]: scratch buffers.
int tot,num[_],inv[_],pwrt[2][_],tmp[6][_];

// Pick tot = smallest power of two strictly greater than 2n, then fill inv[]
// with the linear recurrence inv[i] = -(mod / i) * inv[mod % i] and build the
// root-of-unity tables.
void Init() {
	tot = 1;
	while (tot <= n + n) tot <<= 1;
	inv[1] = 1;
	for (int i = 2; i <= tot; ++i) inv[i] = (ll)inv[mod % i] * (mod - mod / i) % mod;
	pwrt[0][tot] = Pw(rt,(mod - 1) / tot);
	pwrt[1][tot] = Pw(pwrt[0][tot],mod - 2);
	// A primitive len-th root is the square of a primitive (2*len)-th root.
	for (int len = (tot >> 1); len; len >>= 1) {
		pwrt[0][len] = (ll)pwrt[0][len << 1] * pwrt[0][len << 1] % mod;
		pwrt[1][len] = (ll)pwrt[1][len << 1] * pwrt[1][len << 1] % mod;
	}
}

// In-place iterative NTT of f[0, t), t a power of two with t <= tot.
// ty == 0: forward transform; ty == 1: inverse transform including the 1/t scale.
void NTT(int *f,int t,bool ty) {
	// Bit-reversal permutation (num[0] is never written and stays 0).
	for (int i = 1; i < t; ++i) {
		num[i] = (num[i >> 1] >> 1) | ((i & 1) ? t >> 1 : 0);
		if (i < num[i]) std::swap(f[i],f[num[i]]);
	}
	for (int len = 2; len <= t; len <<= 1) {
		int gap = len >> 1,w1 = pwrt[ty][len];
		for (int i = 0; i < t; i += len) {
			int w = 1;
			for (int j = i; j < i + gap; ++j) {
				// Called x (not tmp) so it does not shadow POLY::tmp.
				int x = (ll)w * f[j + gap] % mod;
				f[j + gap] = (f[j] - x + mod) % mod;
				f[j] = (f[j] + x) % mod;
				w = (ll)w * w1 % mod;
			}
		}
	}
	if (ty) for (int i = 0; i < t; ++i) f[i] = (ll)f[i] * inv[t] % mod;
}

// h = f * g as a cyclic convolution of length tot. The caller must keep the
// product's degree below tot. f and g are copied into tmp[2] / tmp[3] first,
// so h may alias either of them.
void Mul(int *f,int *g,int *h) {
	for (int i = 0; i < tot; ++i) tmp[2][i] = f[i],tmp[3][i] = g[i];
	NTT(tmp[2],tot,0),NTT(tmp[3],tot,0);
	for (int i = 0; i < tot; ++i) h[i] = (ll)tmp[2][i] * tmp[3][i] % mod;
	NTT(h,tot,1);
}

// Power-series inverse by Newton iteration: h = f^{-1} mod x^{tot/2}.
// Requires f[0] != 0 (mod p). h[tot/2, tot) is left zero; uses tmp[1].
// Each round turns an inverse valid mod x^{len/2} into one valid mod x^len via
// h <- h * (2 - f * h), evaluated with an NTT of length 2*len so nothing wraps.
void Inv(int *f,int *h) {
	for (int i = 0; i < tot; ++i) h[i] = tmp[1][i] = 0;
	h[0] = Pw(f[0],mod - 2),tmp[1][0] = f[0],tmp[1][1] = f[1];
	for (int len = 2,t = 4; len < tot; len <<= 1,t = (len << 1)) {
		NTT(h,t,0),NTT(tmp[1],t,0);
		for (int i = 0; i < t; ++i) h[i] = (ll)h[i] * (2 - (ll)h[i] * tmp[1][i] % mod + mod) % mod;
		NTT(h,t,1),NTT(tmp[1],t,1);
		// Keep only the len valid terms of h; extend tmp[1] to f mod x^{2*len}.
		for (int i = len; i < t; ++i) tmp[1][i] = f[i],h[i] = 0;
	}
}

// Formal derivative: h[i] = (i + 1) * f[i + 1] for i < tot - 1 (h[tot - 1] is
// not written). Safe in place (h == f) because reads run ahead of writes.
void Deriv(int *f,int *h) { for (int i = 0; i < tot - 1; ++i) h[i] = (ll)f[i + 1] * (i + 1) % mod; }

// Formal integral: h[i] = f[i - 1] / i for 1 <= i < tot, h[0] = 0. Safe in
// place because it walks downwards. Uses the inv[] table filled by Init()
// instead of one modular exponentiation per coefficient.
void Integ(int *f,int *h) { for (int i = tot - 1; i > 0; --i) h[i] = (ll)f[i - 1] * inv[i] % mod; h[0] = 0; }

// h = ln f = integral of f' / f, valid for the first n coefficients.
// f[0] is expected to be 1 (h[0] is set to 0). f is clobbered as scratch, and
// h may be the same array as f. Uses tmp[1..4].
void Ln(int *f,int *h) {
	// Inv() zeroes its output itself, so tmp[4] needs no preparation.
	Inv(f,tmp[4]);
	Deriv(f,f);
	Mul(f,tmp[4],f);
	Integ(f,h);
}
}
// Read the next non-negative decimal integer from stdin, skipping any leading
// non-digit characters. Returns 0 if input ends before a digit is found.
int gi() {
	// int, not char: getchar() returns EOF out of band from every character.
	int x = 0,c = getchar();
	// Stop at EOF so truncated input cannot make this loop spin forever.
	while (c != EOF && (c < '0' || c > '9')) c = getchar();
	while (c >= '0' && c <= '9') x = x * 10 + (c - '0'),c = getchar();
	return x;
}
// Read n and the n coefficients of f(x) (f_0 presumably 1, as ln requires),
// then print the first n coefficients of ln f(x), space-separated, on one line.
int main() {
	n = gi();
	int *poly = f;
	for (int k = 0; k < n; ++k) poly[k] = gi();
	POLY::Init();
	POLY::Ln(poly,poly);
	for (int k = 0; k < n; ++k) printf("%d ",poly[k]);
	putchar('\n');
	return 0;
}
\(\exp\)(普通方法)
#include <cstdio>
#include <cstring>
#include <iostream>
using namespace std;
typedef long long ll;
// Array capacity: 2^18 coefficients plus a little slack.
const int _ = (1 << 18) + 7;
// NTT-friendly prime and its primitive root.
const int mod = 998244353;
const int rt = 3;
int n;
int g[_];
int f[_];

// Square-and-multiply: a^p mod `mod` (p is expected >= 0).
int Pw(int a,int p) {
	int result = 1;
	for (int sq = a; p != 0; p >>= 1,sq = (int)((ll)sq * sq % mod))
		if (p & 1) result = (int)((ll)result * sq % mod);
	return result;
}
namespace POLY {
int tot,num[_],pwrt[2][_],inv[_],tmp[5][_];
void Init() {
tot = 1; while (tot <= n + n) tot <<= 1;
inv[1] = 1;
for (int i = 2; i <= tot; ++i) inv[i] = (ll)inv[mod % i] * (mod - mod / i) % mod;
pwrt[0][tot] = Pw(rt,(mod - 1) / tot);
pwrt[1][tot] = Pw(pwrt[0][tot],mod - 2);
for (int len = (tot >> 1); len; len >>= 1) {
pwrt[0][len] = (ll)pwrt[0][len << 1] * pwrt[0][len << 1] % mod;
pwrt[1][len] = (ll)pwrt[1][len << 1] * pwrt[1][len << 1] % mod;
}
}
void NTT(int *f,int t,bool ty) {
for (int i = 1; i < t; ++i) {
num[i] = (num[i >> 1] >> 1) | ((i & 1) ? t >> 1 : 0);
if (i < num[i]) swap(f[i],f[num[i]]);
}
for (int len = 2; len <= t; len <<= 1) {
int gap = len >> 1,w1 = pwrt[ty][len];
for (int i = 0,w = 1,tmp; i < t; i += len,w = 1)
for (int j = i; j < i + gap; ++j) {
tmp = (ll)w * f[j + gap] % mod;
f[j + gap] = (f[j] - tmp + mod) % mod;
f[j] = (f[j] + tmp) % mod;
w = (ll)w * w1 % mod;
}
}
if (ty) for (int i = 0; i < t; ++i) f[i] = (ll)f[i] * inv[t] % mod;
}
void Mul(int *f,int *g,int *h,int t) {
memcpy(tmp[1],f,t << 2);
memcpy(tmp[2],g,t << 2);
NTT(tmp[1],t,0),NTT(tmp[2],t,0);
for (int i = 0; i < (t << 1); ++i) h[i] = (ll)tmp[1][i] * tmp[2][i] % mod;
NTT(h,t,1);
}
void dcNTT(int *f,int *g,int t,int l,int r) {
if (t == 1) { f[0] = l ? (ll)f[0] * inv[l] % mod : f[0]; return; }
dcNTT(f,g,t >> 1,l,(l + r) >> 1);
memset(tmp[0] + (t >> 1),0,t << 1);
memcpy(tmp[0],f,t << 1);
Mul(tmp[0],g,tmp[0],t);
for (int i = (t >> 1); i < t; ++i) f[i] = (f[i] + tmp[0][i - 1]) % mod;
dcNTT(f + (t >> 1),g,t >> 1,(l + r) >> 1,r);
}
void Exp(int *f,int *g) { dcNTT(f,g,tot >> 1,1,tot >> 1); }
void Deriv(int *f,int *h) { for (int i = 0; i < tot - 1; ++i) h[i] = (ll)f[i + 1] * (i + 1) % mod; }
}
// Read n and the n coefficients of A(x) (A_0 expected to be 0), then print the
// first n coefficients of exp A(x), space-separated, on one line.
int main() {
	scanf("%d",&n);
	for (int k = 0; k < n; ++k) scanf("%d",g + k);
	POLY::Init();
	// g now holds A'(x), which is what the recurrence consumes.
	POLY::Deriv(g,g);
	// exp(A) has constant term e^{A_0} = 1.
	f[0] = 1;
	POLY::Exp(f,g);
	for (int k = 0; k < n; ++k) printf("%d ",f[k]);
	putchar('\n');
	return 0;
}
\(\exp\) (牛顿迭代)
#include <cstdio>
#include <cstring>
#include <iostream>
using namespace std;
typedef long long ll;
// Array capacity: 2^18 coefficients plus a little slack.
const int _ = (1 << 18) + 7;
// NTT-friendly prime and its primitive root.
const int mod = 998244353;
const int rt = 3;
int n;
int f[_];
int g[_];

// Computes a^p mod `mod` by repeated squaring (p is expected >= 0).
int Pw(int a,int p) {
	ll acc = 1;
	ll cur = a;
	while (p != 0) {
		if (p & 1) acc = acc * cur % mod;
		cur = cur * cur % mod;
		p >>= 1;
	}
	return (int)acc;
}
namespace POLY {
// tot: transform length chosen by Init(); num: bit-reversal permutation;
// pwrt[0][len] / pwrt[1][len]: primitive len-th root of unity and its inverse;
// inv[i]: i^{-1} mod `mod` for 1 <= i <= tot.
int tot,num[_],pwrt[2][_],inv[_];

// Pick tot = smallest power of two strictly greater than 2n; precompute the
// inverse table and the roots of unity for every power-of-two length <= tot.
void Init() {
	tot = 1;
	while (tot <= n + n) tot <<= 1;
	inv[1] = 1;
	for (int i = 2; i <= tot; ++i) inv[i] = (ll)inv[mod % i] * (mod - mod / i) % mod;
	pwrt[0][tot] = Pw(rt,(mod - 1) / tot);
	pwrt[1][tot] = Pw(pwrt[0][tot],mod - 2);
	for (int len = (tot >> 1); len; len >>= 1) {
		pwrt[0][len] = (ll)pwrt[0][len << 1] * pwrt[0][len << 1] % mod;
		pwrt[1][len] = (ll)pwrt[1][len << 1] * pwrt[1][len << 1] % mod;
	}
}

// Zero the first 2*L entries of f (the NTT buffers below have length 2*L).
void Clear(int *f,int L) { memset(f,0,sizeof(int) * (L << 1)); }

// In-place iterative NTT of f[0, L); ty == 1 selects the inverse (with 1/L).
void NTT(int *f,int L,bool ty) {
	for (int i = 1; i < L; ++i) {
		num[i] = (num[i >> 1] >> 1) | ((i & 1) ? L >> 1 : 0);
		if (i < num[i]) std::swap(f[i],f[num[i]]);
	}
	for (int len = 2; len <= L; len <<= 1) {
		int gap = len >> 1,w1 = pwrt[ty][len];
		for (int i = 0,w = 1,x; i < L; i += len,w = 1)
			for (int j = i; j < i + gap; ++j) {
				x = (ll)w * f[j + gap] % mod;
				f[j + gap] = (f[j] - x + mod) % mod;
				f[j] = (f[j] + x) % mod;
				w = (ll)w * w1 % mod;
			}
	}
	if (ty) for (int i = 0; i < L; ++i) f[i] = (ll)f[i] * inv[L] % mod;
}

// Copy the first L entries of f into h.
void Cpy(int *h,int *f,int L) { memcpy(h,f,sizeof(int) * L); }

// h = f^{-1} mod x^L by Newton iteration (L a power of two, 2L <= tot).
// Requires f[0] != 0. Clears h[0, 2L) first; h[L, 2L) stays zero.
void Inv(int *h,int *f,int L) {
	// static: each buffer is ~1 MiB. As locals, these plus the ones in Ln/Exp
	// put about 7 MiB on the stack along the Exp -> Ln -> Inv call chain.
	static int a[_],b[_];
	Clear(h,L),Clear(a,L),Clear(b,L);
	h[0] = Pw(f[0],mod - 2);
	for (int len = 2,t = 4; len <= L; len <<= 1,t <<= 1) {
		// b = h * (2 - f * h) with length-t transforms, so no term wraps.
		Cpy(a,f,len),Cpy(b,h,len),NTT(b,t,0),NTT(a,t,0);
		for (int i = 0; i < t; ++i) b[i] = (ll)b[i] * (2 - (ll)a[i] * b[i] % mod + mod) % mod;
		NTT(b,t,1);
		// The low half of h is already correct; copy in only the new terms.
		for (int i = (len >> 1); i < len; ++i) h[i] = b[i];
	}
}

// Formal derivative: h[i] = (i + 1) * f[i + 1] for i < L - 1.
void Deriv(int *h,int *f,int L) { for (int i = 0; i < L - 1; ++i) h[i] = (ll)f[i + 1] * (i + 1) % mod; }

// Formal integral: h[i] = f[i - 1] / i for 1 <= i < L, h[0] = 0; safe in place.
void Integ(int *h,int *f,int L) { for (int i = L - 1; i; --i) h[i] = (ll)f[i - 1] * inv[i] % mod; h[0] = 0; }

// h = ln f mod x^L = integral of f' * f^{-1}. f[0] is expected to be 1
// (h[0] is set to 0). Uses h[0, 2L); h[L, 2L) is zero on return.
void Ln(int *h,int *f,int L) {
	static int a[_],b[_];
	Clear(h,L),Clear(a,L),Clear(b,L);
	Deriv(a,f,L),Inv(b,f,L);
	NTT(a,L << 1,0),NTT(b,L << 1,0);
	for (int i = 0; i < (L << 1); ++i) h[i] = (ll)a[i] * b[i] % mod;
	NTT(h,L << 1,1);
	Integ(h,h,L);
	// Drop the high half of the product so h holds exactly ln f mod x^L.
	memset(h + L,0,sizeof(int) * L);
}

// h = exp(f) mod x^L by Newton iteration h <- h * (1 - ln h + f), doubling the
// number of valid coefficients each round (L a power of two, 2L <= tot).
// f[0] is expected to be 0. Uses h[0, 2L).
void Exp(int *h,int *f,int L) {
	static int a[_],b[_],c[_];
	Clear(h,L),Clear(a,L),Clear(b,L),Clear(c,L);
	h[0] = 1;
	for (int len = 2; len <= L; len <<= 1) {
		Cpy(c,h,len),Ln(b,h,len),Cpy(a,f,len);
		// Length-len cyclic products suffice. h has degree < len/2, so wrapped
		// terms land in [0, len/2 - 1), and only c[len/2, len) is read back.
		NTT(c,len,0),NTT(b,len,0),NTT(a,len,0);
		for (int i = 0; i < len; ++i) c[i] = (ll)c[i] * (1ll - b[i] + a[i] + mod) % mod;
		NTT(c,len,1);
		for (int i = (len >> 1); i < len; ++i) h[i] = c[i];
	}
}
}
int main() {
scanf("%d",&n);
for (int i = 0; i < n; ++i) scanf("%d",&g[i]);
POLY::Init();
POLY::Exp(f,g,POLY::tot >> 1);
for (int i = 0; i < n; ++i) printf("%d ",f[i]); putchar('\n');
return 0;
}