【知识总结】多项式全家桶(二)(ln和exp)
上一篇:【知识总结】多项式全家桶(一)(NTT、加减乘除和求逆)
一、对数函数\(\ln(A)\)
求一个多项式\(B(x)\),满足\(B(x)=\ln(A(x))\)。
这里需要一些最基本的微积分知识(不会?戳我(暂时戳不动):【知识总结】微积分初步挖坑待填)。
另外,\(n\)次多项式\(A(x)\)可以看成关于\(x\)的\(n\)次函数,可以对其求导。显然,\(A(x)=\sum\limits_{i=0}^{n-1}a_ix^i\)的导数是\(A'(x)=\sum\limits_{i=0}^{n-2}a_{i+1}x^i(i+1)\),积分是\(\int A(x)\mathrm{d} x=\sum\limits_{i=1}^{n}\frac{a_{i-1}}{i}x^i\)。可以写出如下代码(非常简单):
void derivative(const int *a, int *b, const int n)
{
for (int i = 1; i < n; i++)
b[i - 1] = (ll)a[i] * i % p;
b[n - 1] = 0;
}
void integral(const int *a, int *b, const int n)
{
for (int i = n - 1; i >= 0; i--)
b[i + 1] = (ll)a[i] * inv(i + 1) % p;
b[0] = 0;
}
\(f(x)=\ln(x)\)的导数是\(f'(x)=\frac{1}{x}\)。回到原问题,对两边同时求导,得到(要用一下链式法则\(g(f(x))\)的导数是\(g'(f(x))f'(x)\)):
求个\(A(x)\)的逆元(多项式求逆)然后乘上\(A'(x)\),最后把\(B(x)\)积分回去就好了。
至于代码……下面算多项式指数函数的时候要算对数函数,所以暂时省略。
二、指数函数\(\exp(x)\)
求多项式\(B(x)\)满足\(B(x)=e^{A(x)}\)。
首先,这个式子相当于求\(\ln B(x)=A(x)\)即\(\ln B(x)-A(x)=0\)
设关于多项式的函数\(F(B(x))=\ln B(x)-A(x)\),那么问题就是求这个函数的零点(\(A(x)\)是给定的,视作常数)。
求函数零点的方法之一是牛顿迭代,公式如下(\(i\)是迭代次数,\(x\)是自变量,\(F(x)\)是要求零点的函数,\(F'(x_0)\)是\(F(x)\)在\(x_0\)处的导数):
把\(F(B(x))=\ln B(x)-A(x)\)求导,得到\(F'(B(x))=\frac{1}{B(x)}\)(注意自变量是\(B(x)\)不是\(x\)。这不是一个\(F(x)\)和\(B(x)\)的复合函数)。然后代入上面的公式:
由于多项式乘法的存在,每迭代一次\(B\)的有效长度会增加一倍。
下一篇:【知识总结】多项式全家桶(三)(任意模数NTT)
代码(洛谷4726):
#include <cstdio>
#include <algorithm>
#include <cctype>
#include <cstring>
#undef i
#undef j
#undef k
#undef true
#undef false
#undef min
#undef max
#undef swap
#undef sort
#undef if
#undef for
#undef while
#undef printf
#undef scanf
#undef putchar
#undef getchar
#define _ 0
using namespace std;
namespace zyt
{
template<typename T>
inline bool read(T &x)
{
char c;
bool f = false;
x = 0;
do
c = getchar();
while (c != EOF && c != '-' && !isdigit(c));
if (c == EOF)
return false;
if (c == '-')
f = true, c = getchar();
do
x = x * 10 + c - '0', c = getchar();
while (isdigit(c));
if (f)
x = -x;
return true;
}
template<typename T>
inline void write(T x)
{
static char buf[20];
char *pos = buf;
if (x < 0)
putchar('-'), x = -x;
do
*pos++ = x % 10 + '0';
while (x /= 10);
while (pos > buf)
putchar(*--pos);
}
typedef long long ll;
const int N = 1e5 + 10, LEN = (N << 2), p = 998244353, g = 3;
namespace Polynomial
{
inline int power(int a, int b)
{
int ans = 1;
while (b)
{
if (b & 1)
ans = (ll)ans * a % p;
a = (ll)a * a % p;
b >>= 1;
}
return ans;
}
inline int inv(const int a)
{
return power(a, p - 2);
}
int omega[LEN], winv[LEN], rev[LEN];
void init(const int n, const int lg2)
{
int w = power(g, (p - 1) / n), wi = inv(w);
omega[0] = winv[0] = 1;
for (int i = 1; i < n; i++)
{
omega[i] = (ll)omega[i - 1] * w % p;
winv[i] = (ll)winv[i - 1] * wi % p;
}
for (int i = 0; i < n; i++)
rev[i] = ((rev[i >> 1] >> 1) | ((i & 1) << (lg2 - 1)));
}
void ntt(int *a, const int *w, const int n)
{
for (int i = 0; i < n; i++)
if (i < rev[i])
swap(a[i], a[rev[i]]);
for (int l = 1; l < n; l <<= 1)
for (int i = 0; i < n; i += (l << 1))
for (int k = 0; k < l; k++)
{
int tmp = (a[i + k] - (ll)w[n / (l << 1) * k] * a[i + l + k] % p + p) % p;
a[i + k] = (a[i + k] + (ll)w[n / (l << 1) * k] * a[i + l + k] % p) % p;
a[i + l + k] = tmp;
}
}
void mul(const int *a, const int *b, int *c, const int n)
{
static int x[LEN], y[LEN];
int m = 1, lg2 = 0;
while (m < (n << 1) - 1)
m <<= 1, ++lg2;
init(m, lg2);
memcpy(x, a, sizeof(int[n]));
memset(x + n, 0, sizeof(int[m - n]));
memcpy(y, b, sizeof(int[n]));
memset(y + n, 0, sizeof(int[m - n]));
ntt(x, omega, m), ntt(y, omega, m);
for (int i = 0; i < m; i++)
x[i] = (ll)x[i] * y[i] % p;
ntt(x, winv, m);
int invm = inv(m);
for (int i = 0; i < m; i++)
x[i] = (ll)x[i] * invm % p;
memcpy(c, x, sizeof(int[n]));
}
void _inv(const int *a, int *b, const int n)
{
if (n == 1)
b[0] = inv(a[0]);
else
{
static int tmp[LEN];
_inv(a, b, (n + 1) >> 1);
int m = 1, lg2 = 0;
while (m < (n << 1) + 1)
m <<= 1, ++lg2;
init(m, lg2);
memcpy(tmp, a, sizeof(int[n]));
memset(tmp + n, 0, sizeof(int[m - n]));
memset(b + ((n + 1) >> 1), 0, sizeof(int[m - ((n + 1) >> 1)]));
ntt(tmp, omega, m);
ntt(b, omega, m);
for (int i = 0; i < m; i++)
b[i] = (b[i] * 2LL % p - (ll)tmp[i] * b[i] % p * b[i] % p + p) % p;
ntt(b, winv, m);
int invm = inv(m);
for (int i = 0; i < m; i++)
b[i] = (ll)b[i] * invm % p;
memset(b + n, 0, sizeof(int[m - n]));
}
}
void inv(const int *a, int *b, const int n)
{
static int tmp[LEN];
memcpy(tmp, a, sizeof(int[n]));
_inv(tmp, b, n);
}
void derivative(const int *a, int *b, const int n)
{
for (int i = 1; i < n; i++)
b[i - 1] = (ll)a[i] * i % p;
b[n - 1] = 0;
}
void integral(const int *a, int *b, const int n)
{
for (int i = n - 1; i >= 0; i--)
b[i + 1] = (ll)a[i] * inv(i + 1) % p;
b[0] = 0;
}
void ln(const int *a, int *b, const int n)
{
static int tmp[LEN], inva[LEN];
derivative(a, tmp, n);
inv(a, inva, n - 1);
mul(inva, tmp, b, n - 1);
integral(b, b, n - 1);
}
void _exp(const int *a, int *b, const int n)
{
if (n == 1)
b[0] = 1;
else
{
static int tmp[LEN];
_exp(a, b, (n + 1) >> 1);
ln(b, tmp, n);
for (int i = 0; i < n; i++)
tmp[i] = (-tmp[i] + a[i] + p) % p;
tmp[0] = (tmp[0] + 1) % p;
mul(b, tmp, b, n);
}
}
void exp(const int *a, int *b, const int n)
{
static int tmp[LEN];
memcpy(tmp, a, sizeof(int[n]));
_exp(tmp, b, n);
}
}
int work()
{
static int a[LEN];
int n;
read(n);
for (int i = 0; i < n; i++)
read(a[i]);
Polynomial::exp(a, a, n);
for (int i = 0; i < n; i++)
write(a[i]), putchar(' ');
return (0^_^0);
}
}
int main()
{
return zyt::work();
}