[学习笔记] 多项式 ln & exp

前置知识

  1. \[\begin{aligned} (\ln x)' &= \frac{1}{x} \\ (\exp x)' &= \exp x \end{aligned} \]

  2. 复合函数的求导(链式法则)

    \[(g\circ f)'(x) = g'(f(x))\,f'(x) \]

  3. 多项式求逆,分治FFT。

多项式 ln

设 \(g(x) = \ln{f(x)}\)(要求常数项 \(f_0 = 1\)),由链式法则有

\[\begin{aligned} g'(x) = f'(x) \frac{1}{f(x)} \end{aligned} \]

先把 \(g'(x)\) 求出来然后再积分就行了。

多项式求导,积分都是 \(O(n)\) 的,多项式乘法为 \(O(n\log n)\),所以总复杂度为 \(O(n\log n)\)

多项式 exp

普通方法一

设 \(g(x) = e^{f(x)}\)(要求常数项 \(f_0 = 0\)),两边求导得

\[\begin{aligned} g'(x) = g(x)f'(x) \end{aligned} \]

写成卷积的形式就是(\(f_i\) 表示 \(f(x)\)\(i\) 次项系数,\(f'_i\)\(g'_i\) 同理)

\[g'_i = \sum_{j = 0}^i g_j f'_{i - j} \]

然后因为 \(f'_i = (i + 1)f_{i + 1}\)(同理 \(g'_i = (i + 1)g_{i + 1}\)),所以就有

\[\begin{aligned} g_{i + 1} &= \frac{1}{i + 1} \sum_{j = 0}^i g_{j} (i + 1 - j) f_{i + 1 - j} \\ g_i &= \frac{1}{i} \sum_{j = 0}^{i - 1} g_{j} (i - j) f_{i - j} \end{aligned} \]

可以用分治 FFT 解决,时间复杂度为 \(O(n \log^2 n)\)

牛顿迭代

[学习笔记]牛顿迭代

时间复杂度为 \(O(n \log n)\) ,但由于实现过程中需要求 \(\ln\),所以实际上快不了多少(至少在洛谷的模板上跑得差不多)。

代码

\(\ln\)

#include <cctype>
#include <cstdio>
#include <iostream>

using namespace std;

// 64-bit type for intermediate products before reduction modulo `mod`.
using ll = long long;

// Array capacity. Init() indexes inv[]/pwrt[] at tot (smallest power of two
// greater than 2n), so tot must not exceed 1 << 18.
const int _ = (1 << 18) + 7;
// NTT modulus (998244353 = 119 * 2^23 + 1).
const int mod = 998244353;
// Primitive root of `mod`; Init() derives the roots of unity from it.
const int rt = 3;

// n: number of input coefficients; f: input coefficients, overwritten by main
// with ln f.
int n, f[_];

// Binary exponentiation: returns a^p reduced modulo `mod`.
int Pw(int a, int p) {
  ll base = a;
  ll acc = 1;
  for (; p != 0; p >>= 1) {
    if (p & 1) acc = acc * base % mod;
    base = base * base % mod;
  }
  return (int)acc;
}

namespace POLY {
  // tot        : NTT length, the smallest power of two greater than 2n (set by Init).
  // num        : bit-reversal permutation scratch used by NTT.
  // inv[i]     : modular inverse of i, for 1 <= i <= tot.
  // pwrt[0][k] : k-th root of unity for every power of two k <= tot,
  // pwrt[1][k] : and its inverse (used by the inverse transform).
  // tmp        : scratch rows -- tmp[1] for Inv, tmp[2..3] for Mul, tmp[4..5] for Ln.
  int tot,num[_],inv[_],pwrt[2][_],tmp[6][_];

  // Precomputes tot, inv[] and the roots of unity. Must run after n is known.
  void Init() {
    tot = 1; while (tot <= n + n) tot <<= 1;
    inv[1] = 1;
    // Linear-time inverse table: inv[i] = -(mod / i) * inv[mod % i].
    for (int i = 2; i <= tot; ++i) inv[i] = (ll)inv[mod % i] * (mod - mod / i) % mod;
    pwrt[0][tot] = Pw(rt,(mod - 1) / tot);
    pwrt[1][tot] = Pw(pwrt[0][tot],mod - 2);
    // The square of a 2k-th root of unity is a k-th root of unity.
    for (int len = (tot >> 1); len; len >>= 1) {
      pwrt[0][len] = (ll)pwrt[0][len << 1] * pwrt[0][len << 1] % mod;
      pwrt[1][len] = (ll)pwrt[1][len << 1] * pwrt[1][len << 1] % mod;
    }
  }

  // In-place iterative NTT of f[0..t), t a power of two with t <= tot.
  // ty == false: forward transform; ty == true: inverse transform including
  // the division by t.
  void NTT(int *f,int t,bool ty) {
    for (int i = 1; i < t; ++i) {
      num[i] = (num[i >> 1] >> 1) | ((i & 1) ? t >> 1 : 0);
      if (i < num[i]) swap(f[i],f[num[i]]);
    }
    for (int len = 2; len <= t; len <<= 1) {
      int gap = len >> 1,w1 = pwrt[ty][len];
      for (int i = 0; i < t; i += len) {
        int w = 1;
        for (int j = i; j < i + gap; ++j) {
          // Called `x` (not `tmp`) so it does not shadow the scratch array.
          int x = (ll)w * f[j + gap] % mod;
          f[j + gap] = (f[j] - x + mod) % mod;
          f[j] = (f[j] + x) % mod;
          w = (ll)w * w1 % mod;
        }
      }
    }
    if (ty) for (int i = 0; i < t; ++i) f[i] = (ll)f[i] * inv[t] % mod;
  }

  // h[0..tot) = cyclic convolution (length tot) of f[0..tot) and g[0..tot).
  // The inputs are copied first, so h may alias f or g.
  void Mul(int *f,int *g,int *h) {
    for (int i = 0; i < tot; ++i) tmp[2][i] = f[i],tmp[3][i] = g[i];
    NTT(tmp[2],tot,0),NTT(tmp[3],tot,0);
    for (int i = 0; i < tot; ++i) h[i] = (ll)tmp[2][i] * tmp[3][i] % mod;
    NTT(h,tot,1);
  }

  // h = f^{-1} mod x^{tot/2} by Newton iteration h <- h (2 - f h), doubling the
  // precision each round; h[tot/2..tot) ends up zero. Requires f[0] != 0 and
  // h not aliasing f. tmp[1] holds f truncated to the current precision.
  void Inv(int *f,int *h) {
    for (int i = 0; i < tot; ++i) h[i] = tmp[1][i] = 0;
    h[0] = Pw(f[0],mod - 2),tmp[1][0] = f[0],tmp[1][1] = f[1];
    for (int len = 2,t = 4; len < tot; len <<= 1,t = (len << 1)) {
      NTT(h,t,0),NTT(tmp[1],t,0);
      for (int i = 0; i < t; ++i) h[i] = (ll)h[i] * (2 - (ll)h[i] * tmp[1][i] % mod + mod) % mod;
      NTT(h,t,1),NTT(tmp[1],t,1);
      // Keep only the now-exact low len terms; extend tmp[1] to the next precision.
      for (int i = len; i < t; ++i) tmp[1][i] = f[i],h[i] = 0;
    }
  }

  // h = f' truncated to tot coefficients; h may alias f (ascending loop reads
  // f[i + 1] before it is overwritten).
  void Deriv(int *f,int *h) {
    for (int i = 0; i < tot - 1; ++i) h[i] = (ll)f[i + 1] * (i + 1) % mod;
    // The top coefficient would need f[tot], outside the window: zero it
    // instead of leaving a stale value behind.
    if (tot > 0) h[tot - 1] = 0;
  }

  // h = integral of f with zero constant term, truncated to tot coefficients;
  // h may alias f (descending loop reads f[i - 1] before it is overwritten).
  void Integ(int *f,int *h) {
    for (int i = tot - 1; i > 0; --i) h[i] = (ll)f[i - 1] * inv[i] % mod;
    h[0] = 0;
  }

  // h = ln f = integral of f'/f, exact in at least its first tot/2 coefficients.
  // Requires f[0] == 1 and f to have degree below tot/2 (true for main's
  // zero-padded input because tot > 2n), so the product f' * f^{-1} does not
  // wrap around. f itself is left untouched; h may alias f.
  void Ln(int *f,int *h) {
    Inv(f,tmp[4]);               // tmp[4] = 1 / f
    Deriv(f,tmp[5]);             // tmp[5] = f'
    Mul(tmp[5],tmp[4],tmp[5]);   // tmp[5] = f' / f
    Integ(tmp[5],h);             // h = ln f
  }
}

// Reads the next non-negative decimal integer from stdin, skipping any
// non-digit characters in front of it. Returns 0 if EOF is hit before a digit
// instead of spinning forever on a truncated input.
int gi() {
  int x = 0;
  // int, not char: getchar() returns EOF or an unsigned-char value, and only
  // those are valid arguments to isdigit.
  int c = getchar();
  while (c != EOF && !isdigit(c)) c = getchar();
  while (c != EOF && isdigit(c)) {
    x = x * 10 + (c - '0');
    c = getchar();
  }
  return x;
}

int main() {
  // Input: n, then the coefficients f_0 .. f_{n-1}.
  n = gi();
  for (int k = 0; k < n; ++k) f[k] = gi();

  POLY::Init();     // depends on n, so it must follow the read
  POLY::Ln(f,f);    // replace f with ln f

  // Print the first n coefficients, each followed by a space, then a newline.
  for (int k = 0; k < n; ++k) printf("%d ",f[k]);
  putchar('\n');
  return 0;
}

\(\exp\)(普通方法,分治 NTT)

#include <cstdio>
#include <cstring>
#include <iostream>

using namespace std;

// 64-bit type for intermediate products before reduction modulo `mod`.
using ll = long long;

// Array capacity; Init() indexes inv[]/pwrt[] at tot (<= 1 << 18).
const int _ = (1 << 18) + 7;
// NTT modulus and its primitive root.
const int mod = 998244353;
const int rt = 3;

// n: number of input coefficients.
// g: input coefficients, replaced in place by their derivative in main.
// f: coefficients of the exp result.
int n, g[_], f[_];

// Square-and-multiply: returns a^p reduced modulo `mod`.
int Pw(int a, int p) {
  int result = 1;
  int square = a;
  while (p != 0) {
    if (p & 1) result = (int)((ll)result * square % mod);
    square = (int)((ll)square * square % mod);
    p >>= 1;
  }
  return result;
}

namespace POLY {
  // tot        : smallest power of two greater than 2n (set by Init).
  // num        : bit-reversal permutation scratch used by NTT.
  // pwrt[0][k] : k-th root of unity for every power of two k <= tot,
  // pwrt[1][k] : and its inverse.
  // inv[i]     : modular inverse of i, for 1 <= i <= tot.
  // tmp        : scratch rows -- tmp[0] for dcNTT, tmp[1..2] for Mul.
  int tot,num[_],pwrt[2][_],inv[_],tmp[5][_];

  // Precomputes tot, inv[] and the roots of unity. Must run after n is known.
  void Init() {
    tot = 1; while (tot <= n + n) tot <<= 1;
    inv[1] = 1;
    // Linear-time inverse table: inv[i] = -(mod / i) * inv[mod % i].
    for (int i = 2; i <= tot; ++i) inv[i] = (ll)inv[mod % i] * (mod - mod / i) % mod;
    pwrt[0][tot] = Pw(rt,(mod - 1) / tot);
    pwrt[1][tot] = Pw(pwrt[0][tot],mod - 2);
    for (int len = (tot >> 1); len; len >>= 1) {
      pwrt[0][len] = (ll)pwrt[0][len << 1] * pwrt[0][len << 1] % mod;
      pwrt[1][len] = (ll)pwrt[1][len << 1] * pwrt[1][len << 1] % mod;
    }
  }

  // In-place iterative NTT of f[0..t), t a power of two with t <= tot.
  // ty selects the inverse transform, which also divides by t.
  void NTT(int *f,int t,bool ty) {
    for (int i = 1; i < t; ++i) {
      num[i] = (num[i >> 1] >> 1) | ((i & 1) ? t >> 1 : 0);
      if (i < num[i]) swap(f[i],f[num[i]]);
    }
    for (int len = 2; len <= t; len <<= 1) {
      int gap = len >> 1,w1 = pwrt[ty][len];
      for (int i = 0; i < t; i += len) {
        int w = 1;
        for (int j = i; j < i + gap; ++j) {
          // Called `x` (not `tmp`) so it does not shadow the scratch array.
          int x = (ll)w * f[j + gap] % mod;
          f[j + gap] = (f[j] - x + mod) % mod;
          f[j] = (f[j] + x) % mod;
          w = (ll)w * w1 % mod;
        }
      }
    }
    if (ty) for (int i = 0; i < t; ++i) f[i] = (ll)f[i] * inv[t] % mod;
  }

  // h[0..t) = cyclic convolution (length t, a power of two) of f[0..t) and
  // g[0..t). The inputs are copied first, so h may alias f or g.
  void Mul(int *f,int *g,int *h,int t) {
    memcpy(tmp[1],f,t * sizeof(int));
    memcpy(tmp[2],g,t * sizeof(int));
    NTT(tmp[1],t,0),NTT(tmp[2],t,0);
    // Only t evaluation points exist; going further would read stale scratch
    // and write junk past h[t).
    for (int i = 0; i < t; ++i) h[i] = (ll)tmp[1][i] * tmp[2][i] % mod;
    NTT(h,t,1);
  }

  // CDQ divide and conquer for f_i = inv[i] * sum_{j < i} f_j * g_{i-1-j}.
  // f points at the start of the current segment (length t, a power of two).
  // On entry, each f[k] of the segment already holds the contributions of all
  // indices before the segment. l is the global index whose inverse is applied
  // at a leaf. Exp starts with l = 1, which makes the leaf for index 0 use
  // inv[1] = 1 (no change) and every other leaf use its own index.
  void dcNTT(int *f,int *g,int t,int l,int r) {
    if (t == 1) { f[0] = l ? (ll)f[0] * inv[l] % mod : f[0]; return; }
    int half = t >> 1;
    dcNTT(f,g,half,l,(l + r) >> 1);
    // tmp[0] = finished left half of the segment, zero-padded to length t.
    memcpy(tmp[0],f,half * sizeof(int));
    memset(tmp[0] + half,0,half * sizeof(int));
    Mul(tmp[0],g,tmp[0],t);
    // Right-half index i receives sum_{j in left half} f_j * g_{i-1-j}.
    // Degrees stay below 3t/2 - 1, so the length-t cyclic product does not
    // disturb these entries.
    for (int i = half; i < t; ++i) f[i] = (f[i] + tmp[0][i - 1]) % mod;
    dcNTT(f + half,g,half,(l + r) >> 1,r);
  }

  // f = exp(F) in its first tot/2 coefficients, where g must already hold F'
  // and f must hold 1 at index 0 and 0 elsewhere on entry.
  void Exp(int *f,int *g) { dcNTT(f,g,tot >> 1,1,tot >> 1); }

  // h = f' truncated to tot coefficients; h may alias f.
  void Deriv(int *f,int *h) {
    for (int i = 0; i < tot - 1; ++i) h[i] = (ll)f[i + 1] * (i + 1) % mod;
    // The top coefficient would need f[tot]; zero it instead of leaving it stale.
    if (tot > 0) h[tot - 1] = 0;
  }
}

int main() {
  // Input: n, then the coefficients of the polynomial to exponentiate.
  scanf("%d",&n);
  int k = 0;
  while (k < n) {
    scanf("%d",&g[k]);
    ++k;
  }

  POLY::Init();
  POLY::Deriv(g,g);   // Exp consumes the derivative of its argument
  f[0] = 1;           // constant term of the result
  POLY::Exp(f,g);

  // Print the first n coefficients, each followed by a space, then a newline.
  for (k = 0; k < n; ++k) printf("%d ",f[k]);
  putchar('\n');
  return 0;
}

\(\exp\)(牛顿迭代)

#include <cstdio>
#include <cstring>
#include <iostream>

using namespace std;

// 64-bit type for intermediate products before reduction modulo `mod`.
using ll = long long;

// Array capacity; Init() indexes inv[]/pwrt[] at tot (<= 1 << 18).
const int _ = (1 << 18) + 7;
// NTT modulus (998244353 = 119 * 2^23 + 1) and its primitive root.
const int mod = 998244353;
const int rt = 3;

// n: number of input coefficients; f: exp result; g: input coefficients.
int n, f[_], g[_];

// Returns a^p reduced modulo `mod` (right-to-left binary exponentiation).
int Pw(int a, int p) {
  ll power = a;
  ll out = 1;
  while (p != 0) {
    if ((p & 1) != 0) out = out * power % mod;
    power = power * power % mod;
    p >>= 1;
  }
  return (int)out;
}

namespace POLY {
  // tot        : smallest power of two greater than 2n (set by Init).
  // num        : bit-reversal permutation scratch used by NTT.
  // pwrt[0][k] : k-th root of unity for every power of two k <= tot,
  // pwrt[1][k] : and its inverse.
  // inv[i]     : modular inverse of i, for 1 <= i <= tot.
  int tot,num[_],pwrt[2][_],inv[_];

  // Precomputes tot, inv[] and the roots of unity. Must run after n is known.
  void Init() {
    tot = 1; while (tot <= n + n) tot <<= 1;
    inv[1] = 1; for (int i = 2; i <= tot; ++i) inv[i] = (ll)inv[mod % i] * (mod - mod / i) % mod;
    pwrt[0][tot] = Pw(rt,(mod - 1) / tot);
    pwrt[1][tot] = Pw(pwrt[0][tot],mod - 2);
    for (int len = (tot >> 1); len; len >>= 1) {
      pwrt[0][len] = (ll)pwrt[0][len << 1] * pwrt[0][len << 1] % mod;
      pwrt[1][len] = (ll)pwrt[1][len << 1] * pwrt[1][len << 1] % mod;
    }
  }

  // Zeroes f[0..2L). Products of two length-L polynomials are evaluated with
  // length-2L NTTs, so every buffer is cleared to twice its nominal length.
  void Clear(int *f,int L) { memset(f,0,2 * L * sizeof(int)); }

  // In-place iterative NTT of f[0..L), L a power of two with L <= tot.
  // ty selects the inverse transform, which also divides by L.
  void NTT(int *f,int L,bool ty) {
    for (int i = 1; i < L; ++i) {
      num[i] = (num[i >> 1] >> 1) | ((i & 1) ? L >> 1 : 0);
      if (i < num[i]) swap(f[i],f[num[i]]);
    }
    for (int len = 2; len <= L; len <<= 1) {
      int gap = len >> 1,w1 = pwrt[ty][len];
      for (int i = 0; i < L; i += len) {
        int w = 1;
        for (int j = i; j < i + gap; ++j) {
          int x = (ll)w * f[j + gap] % mod;
          f[j + gap] = (f[j] - x + mod) % mod;
          f[j] = (f[j] + x) % mod;
          w = (ll)w * w1 % mod;
        }
      }
    }
    if (ty) for (int i = 0; i < L; ++i) f[i] = (ll)f[i] * inv[L] % mod;
  }

  // Copies f[0..L) into h[0..L).
  void Cpy(int *h,int *f,int L) { memcpy(h,f,L * sizeof(int)); }

  // h[0..L) = f^{-1} mod x^L by Newton iteration h <- h (2 - f h), for L a
  // power of two and f[0] != 0. Clears h[0..2L) first, reads f[0..L), and h
  // must not alias f.
  void Inv(int *h,int *f,int L) {
    // static: each buffer is about 1 MiB (2^18 ints), and Exp -> Ln -> Inv would
    // otherwise nest seven of them on the stack. None of these functions
    // recurses, so one copy per function is safe.
    static int a[_],b[_];
    Clear(h,L),Clear(a,L),Clear(b,L);
    h[0] = Pw(f[0],mod - 2);
    for (int len = 2,t = 4; len <= L; len <<= 1,t <<= 1) {
      Cpy(a,f,len),Cpy(b,h,len),NTT(b,t,0),NTT(a,t,0);
      for (int i = 0; i < t; ++i) b[i] = (ll)b[i] * (2 - (ll)a[i] * b[i] % mod + mod) % mod;
      NTT(b,t,1);
      // The low half is unchanged by Newton's step; copy only the new terms.
      for (int i = (len >> 1); i < len; ++i) h[i] = b[i];
    }
  }

  // h[0..L) = f' truncated to L coefficients.
  void Deriv(int *h,int *f,int L) {
    for (int i = 0; i < L - 1; ++i) h[i] = (ll)f[i + 1] * (i + 1) % mod;
    if (L > 0) h[L - 1] = 0;
  }

  // h[0..L) = integral of f with zero constant term; h may alias f (the loop
  // runs downwards, so f[i - 1] is read before it is overwritten).
  void Integ(int *h,int *f,int L) { for (int i = L - 1; i; --i) h[i] = (ll)f[i - 1] * inv[i] % mod; h[0] = 0; }

  // h[0..L) = ln f mod x^L for f[0] == 1 and L a power of two; reads f[0..L).
  // h[L..2L) is returned zeroed. h must not alias f.
  void Ln(int *h,int *f,int L) {
    static int a[_],b[_];   // static for the same stack-size reason as Inv
    Clear(h,L),Clear(a,L),Clear(b,L);
    Deriv(a,f,L),Inv(b,f,L);
    NTT(a,L << 1,0),NTT(b,L << 1,0);
    for (int i = 0; i < (L << 1); ++i) h[i] = (ll)a[i] * b[i] % mod;
    NTT(h,L << 1,1);
    Integ(h,h,L);
    // The high half holds leftover product terms that are not part of ln f mod x^L.
    for (int i = L; i < (L << 1); ++i) h[i] = 0;
  }

  // h[0..L) = exp f mod x^L by Newton iteration h <- h (1 - ln h + f),
  // doubling the precision each round. Requires f[0] == 0 and L a power of two.
  void Exp(int *h,int *f,int L) {
    static int a[_],b[_],c[_];   // static for the same stack-size reason as Inv
    Clear(h,L),Clear(a,L),Clear(b,L),Clear(c,L);
    h[0] = 1;
    for (int len = 2; len <= L; len <<= 1) {
      // Here h is exact mod x^{len/2}.
      Cpy(c,h,len),Ln(b,h,len),Cpy(a,f,len);
      NTT(c,len,0),NTT(b,len,0),NTT(a,len,0);
      for (int i = 0; i < len; ++i) c[i] = (ll)c[i] * (1ll - b[i] + a[i] + mod) % mod;
      NTT(c,len,1);
      // The length-len cyclic product wraps its terms of degree >= len onto
      // [0, len/2). The low half of h is already exact, so only the upper half
      // is taken.
      for (int i = (len >> 1); i < len; ++i) h[i] = c[i];
    }
  }
}

int main() {
  // Input: n, then the coefficients of the polynomial to exponentiate.
  scanf("%d",&n);
  for (int k = 0; k < n; ++k) scanf("%d",&g[k]);

  POLY::Init();
  const int prec = POLY::tot >> 1;   // working precision, greater than n
  POLY::Exp(f,g,prec);

  // Print the first n coefficients, each followed by a space, then a newline.
  for (int k = 0; k < n; ++k) printf("%d ",f[k]);
  putchar('\n');
  return 0;
}
posted @ 2020-12-03 15:22  BruceW  阅读(397)  评论(0编辑  收藏  举报