【模板】多项式全家桶(多项式初等函数)

代码实现在最底下。

【模板】多项式初等函数

同时作为 https://github.com/caijianhong/template-poly 的 document。

杂项

数域为 F998244353,所以定义了 mintmodint<998244353>

poly 是多项式的类型,从 std::vector<mint> 继承而来。poly 的构造函数如下:

poly();
explicit poly(int n); // n 为项数,n 个 0
poly(const vector<mint>& vec); // vec[0] 为常数项,vec[1] 为一次项,以此类推
poly(initializer_list<mint> il);

以及

  • poly& poly::cut(int lim); 效果等同于截断到 lim 项或补零至 lim 项后返回自己。

  • istream& operator>>(istream& is, poly& a); 输入一个多项式,输入恰好 a.size() 个 modint。

  • ostream& operator<<(ostream& os, const poly& a); 输出一个多项式,以空格分割,最后没有换行。

  • poly operator<<(poly a, const int& k);poly operator>>(poly a, const int& k); 分别是乘 xk 和除 xk<k 次项舍弃)。有对应的 operator<<=operator>>=

  • void poly::ntt(int op); 是 NTT:op=1 是 DFT,op=1 是 IDFT。实现是纯暴力。

  • poly concalc(int n, vector<poly> vec, const function<mint(vector<mint>)>& func); 这个接口主要用于实现牛顿迭代,n 是最高次数,vec 是若干多项式,func 是一个计算的回调函数,如计算多项式乘法是这样的:

    • concalc(len, {a, b}, [](vector<mint> vec) { return vec[0] * vec[1]; });
    • 即计算 aba,b 都是多项式。

多项式单点求值

问题

给出有限项的多项式 F(x)x0,求 F(x0)

mint poly::operator()(const mint& x) const;

solution

秦九韶算法。即从最高项开始,每次做形如 ans=ans×x+ai 的工作。O(n)

多项式加法、减法、数乘

问题

给出有限项的多项式 F(x),G(x)λ,求 H(x)=F(x)±G(x)H(x)=λF(x)

poly operator+(poly a, const poly& b);
poly operator-(poly a, const poly& b);
poly operator*(poly a, const mint& k);
poly operator*(const mint& k, poly a);
poly operator/(poly a, const mint& k);

有对应的 operator+=operator-=operator*=operator/=

solution

对应位相加、相减、数乘。O(n)

多项式乘法

问题

给出多项式 F(x),G(x),求 H(x)=F(x)G(x)

poly operator*(const poly& a, const poly& b);

有对应的 operator*=

solution

https://www.cnblogs.com/caijianhong/p/template-fft.htmlO(nlogn)

code

便于记背。

typedef modint<998244353> mint;
int glim(const int& x) { return 1 << (32 - __builtin_clz(x - 1)); }
int bitctz(const int& x) { return __builtin_ctz(x); }
void poly::ntt(int op) {
  static bool wns_flag = false;
  static vector<mint> wns;
  if (!wns_flag) {
    wns_flag = true;
    for (int j = 1; j <= 23; j++) {
      wns.push_back(qpow(mint(3), raw(mint(-1)) >> j));
    }
  }
  vector<mint>& a = *this;
  int n = a.size();
  for (int i = 1, r = 0; i < n; i++) {
    r ^= n - (1 << (bitctz(n) - bitctz(i) - 1));
    if (i < r) std::swap(a[i], a[r]);
  }
  vector<mint> w(n);
  for (int k = 1, len = 2; len <= n; k <<= 1, len <<= 1) {
    mint wn = wns[bitctz(k)];
    for (int i = raw(w[0] = 1); i < k; i++) w[i] = w[i - 1] * wn;
    for (int i = 0; i < n; i += len) {
      for (int j = 0; j < k; j++) {
        mint x = a[i + j], y = a[i + j + k] * w[j];
        a[i + j] = x + y, a[i + j + k] = x - y;
      }
    }
  }
  if (op == -1) {
    mint iz = mint(1) / n;
    for (int i = 0; i < n; i++) a[i] *= iz;
    reverse(a.begin() + 1, a.end());
  }
}

多项式乘法逆

问题

给出 F(x),求 H(x)modxlim 满足 H(x)F(x)1(modxlim)

注意,此处 H(x) 是无限项的多项式,我们只需要 H(x)modxlim。另外需要保证 F(0)0,否则逆元在 F998244353 上不存在(如果是其他的域,要求 F(0) 存在逆元)。

poly getInv(const poly& a, int lim);

Newton's Method

给出 G(H(x)),我们需要找到 H(x) 使得 G(H(x))=0

n 为偶数,已经知道了 H(x)=H(x)modxn/2 满足 G(H(x))0(modxn/2)H(0) 需要特殊计算)。想知道 H(x)modxn

H(x)=H(x) 处对 G(H(x)) 作泰勒展开。

G(H(x))=i=0+G(i)(H(x))i!(H(x)H(x))i=0

上式,若两边 modxn,因为 H(x)H(x) 的前 n/2 项系数全零,所以 (H(x)H(x))ii2 时是零。

G(H(x))i=0+G(i)(H(x))i!(H(x)H(x))i0(modxn)

G(H(x))G(H(x))+G(H(x))(H(x)H(x))0(modxn)

所以

G(H(x))+G(H(x))H(x)G(H(x))H(x)(modxn)

H(x)H(x)G(H(x))G(H(x))(modxn)

注意这个 G 是一个导数,我们最好指明它导的是 H(x) 而不是 x。这意味着与 H(x) 无关的项应视作常数。

H(x)H(x)G(H(x))ddH(x)G(H(x))(modxn)

solution

需要找到 H(x),使得 G(H(x))=1H(x)F(x)=0。(注意这个构造)

H(x)H(x)1H(x)F(x)1H2(x)(modxn)

这里 F(x)H(x) 无关,是常数,求导时消失了。

H(x)2H(x)H2(x)F(x)(modxn)

时间复杂度 T(n)=T(n/2)+O(nlogn)=O(nlogn)

多项式除法与取模(整除)

问题

给定一个 n 次多项式 F(x) 和一个 m 次多项式 G(x) ,请求出多项式 Q(x), R(x),满足以下条件:

  • Q(x) 次数为 nmR(x) 次数小于 m(钦定,R(x) 的次数恰好为 m1
  • F(x)=Q(x)G(x)+R(x)
poly operator/(poly a, poly b);
poly operator%(const poly& a, const poly& b);

有对应的 operator/=operator%=

solution

R(x)=0,则直接使用多项式求逆解决。不妨,定义

FR(x)=xnF(1x)

注意,nF(x) 最高次项的次数。发现这样 x0 项系数成为 xn 项系数,xn 项系数成为 x0 项系数,系数的顺序翻转了。而且有 (FR)R(x)=F(x)

F(1/x)=Q(1/x)G(1/x)+R(1/x)

xnF(1/x)=xnmQ(1/x)xmG(1/x)+xnm+1xm1R(1/x)

FR(x)=QR(x)GR(x)+xnm+1RR(x)

FR(x)QR(x)GR(x)+xnm+1RR(x)(modxnm+1)

FR(x)QR(x)GR(x)(modxnm+1)

消除了 R(x) 的影响,一次求逆和一次乘法可以求出 Q(x)。然后有 R(x)=F(x)Q(x)G(x)

O(nlogn)

多项式形式导数与不定积分

问题

给出 F(x),求 H(x) 满足 H(x)=ddxF(x)(形式导数)或者 H(x)=F(x)dx(形式不定积分)。

这里默认形式不定积分的常数项为 0,尽管应该是任意常数。

poly getDev(poly a); // 形式导数
poly getInt(poly a); // 形式不定积分

solution

ddxxk=kxk1xkdx=1k+1xk+1

求导和不定积分都有线性性,每项分开计算。O(n)

多项式对数函数(ln)

问题

给出 F(x),求 H(x)modxlim 满足 H(x)lnF(x)(modxlim)

注意,此处 H(x) 是无限项的多项式,我们只需要 H(x)modxlim。另外需要保证 F(0)=1,否则在 F998244353 上不存在。

poly getLn(const poly& a, int lim);

级数表示

给出 ln 的麦克劳林级数:

ln(1+x)=i=1+(1)i+1ixi

ln(1x)=i=1+xii

要点:ln(1+x) 的关于 x 的导数是复合函数求导,是 11+x,二阶导也是复合函数,是 1(1+x)2

ln(x)=i=1+(1)i+1i(x1)i

链式法则

复合函数求导的链式法则:(F(G(x)))=F(G(x))G(x)。具体来说是

ddxF(G(x))=ddG(x)F(G(x))ddxG(x)

solution

lnF(x) 求导再积分得到

ddxlnF(x)=ddxF(x)F(x)lnF(x)=ddxlnF(x)dx=ddxF(x)F(x)dx

lnF(x)=F(x)F(x)dx

注意不要把 dx 约掉

多项式对数函数(exp)

问题

给出 F(x),求 H(x)modxlim 满足 H(x)expF(x)(modxlim)

注意,此处 H(x) 是无限项的多项式,我们只需要 H(x)modxlim。另外需要保证 F(0)=0,否则在 F998244353 上不存在。

poly getExp(const poly& a, int lim);

级数表示

给出 exp 的麦克劳林级数:

expx=i=0+xii!

solution

需要找到 H(x),使得 G(H(x))=lnH(x)F(x)=0。应用 Newton's Method。

H(x)H(x)lnH(x)F(x)1H(x)(modxn)

H(x)H(x)(1lnH(x)+F(x))(modxn)

时间复杂度 T(n)=T(n/2)+O(nlogn)=O(nlogn)

多项式快速幂

问题

给出 F(x)k,求 H(x)modxlim 满足 H(x)Fk(x)(modxlim)。注意,此处 H(x) 是有限项但真的很多项的多项式,我们只需要 H(x)modxlim

poly qpow(const poly& a, string k, int lim); // k 是高精度数
poly qpow(const poly& a, LL k, int lim);

solution

H(x)=Fk(x)=exp(lnFk(x))=exp(klnF(x))

需要通过微调使得 F(0)=1,首先将 F(x)x 直到有常数项,然后所有系数除掉 F(0)。最后再搞回去。这里注意的是取模问题,有两个定理帮助我们:

定理一

p 是质数。来源:https://www.luogu.com.cn/article/gkt0ryc3

Fp(x)F(xp)(modp)

证明略。反正可以得出,因为 n<p,所以 Fp(x)F(xp)F(0)1(modxn)(在 Fp 上)。

所以在对 lnF(x) 点乘 k 时,k 可以对 p 取模,这里的 p=998244353

定理二

p 是质数,a0(modp),费马小定理:

ap11(modp)

所以将常数项乘回去时,k 要对 p1 取模,即最终的多项式要乘 Fkmod(p1)(0) 再乘上 x 的若干次方。

多项式开根

问题

给出 F(x),求 H(x)modxlim 满足 H2(x)F(x)(modxlim)。注意,此处 H(x) 是无限项的多项式,我们只需要 H(x)modxlim。另外需要保证 F(0) 在这个域上有二次剩余。

mint sqrt(const mint& c); // 一个数求二次剩余,无解抛出异常
poly getSqrt(const poly& a, int lim);

solution

需要找到 H(x),使得 G(H(x))=H2(x)F(x)=0。应用 Newton's Method。

H(x)H(x)H2(x)F(x)2H(x)(modxn)

H(x)H(x)2+F(x)2H(x)(modxn)

时间复杂度 T(n)=T(n/2)+O(nlogn)=O(nlogn)

拉格朗日插值

问题

n 个点 (xi,yi) 可以唯一地确定一个 n1 次多项式 f(x)。现在,给定这 n 个点,请你确定这个多项式。

poly lagrange(const vector<pair<mint, mint>>& a);

solution

定义

j(x)=ijxxixjxi

f(x)=ii(x)yi

正确性比较显然。具体实现时先暴力求出 i(xxi),然后每次 O(n) 的除掉一个二项式。O(n2)

Berlekamp–Massey 算法(最短线性递推式)

问题

给出一个数列 P0 开始的前 n{P0,P1,P2,,Pn1}。求序列 Pmod 998244353 下的最短线性递推式。注意需要保证最短线性递推式长度 n/2(意思是:如果这个序列不是线性递推的,会返回长度为 n/2 的假的递推式)。

poly BM(poly a);

solution

P 下标从 1 开始,递推式下标从 0 开始。记 P[1...i] 的最短线性递推式为 Ri,特别地有 R0={}。已知 R[0...(i1)],怎么求 Ri

首先计算一个 Δ(Ri1,i)=deli=j=0Ri1,jaij1ai。如果 Δ(Ri1,i)=0 那么说明 Ri=Ri1,结束。否则就失配了。第一种情况是 Ri1={},只需要使得 Rii0,即强制使得前面这 i 个数是初始给定不用管的。

第二种情况,我们实际上是要考虑找出一个 R={R0,R1,,Rk1} 使得对于 k<p<i 都有 j=0k1Rjapj1=0 而且唯独有 j=0k1Rjaij1=deli,然后 Ri=Ri1+R 就好了,这里加法是对应位相加。这个事情就很严重。解决方法是任选一个 w,然后回忆起 Rw 和当时的 delw+1=delta,构造 R={0,0,,0,deli/delta,deli/delta×Rw},后面这个 ×Rw 是这个序列去数乘上前面的数再接上前面的序列,然后有 iw2 个零在前面。然后我们验证一下它合法。

k=iw1+|Rw|。对于 k<p<i

Δ(R,p)+ap=j=0k1Rjapj1=delidelta(api+w+1j=0Rwapji+w)=delideltaΔ(Rw,pi+w+1).

因为 p<i,所以 pi+w+1w,根据定义,Δ(Rw,pi+w+1)=0。满足条件。

对于 p=i

Δ(R,i)+ai=j=0k1Rjaij1=delidelwΔ(Rw,w+1).

因为这里 Δ(Rw,w+1)=delw+1=delta,所以这玩意等于 deli,全对。

为了使得 Ri 最短,选使 R 最短的 w 即可。

code

具体代码有 0 下标细节等,可以在这里看一下。注意最终返回递推式下标从 1 开始。

poly BM(poly a) {
  poly ans, lst;
  int w = 0;
  mint delta = 0;
  for (int i = 0; i < a.size(); i++) {
    mint tmp = -a[i];
    for (int j = 0; j < ans.size(); j++) tmp += ans[j] * a[i - j - 1];
    if (tmp == 0) continue;
    if (ans.empty()) {
      w = i;
      delta = tmp;
      ans = vector<mint>(i + 1, 0);
    } else {
      auto now = ans;
      mint mul = -tmp / delta;
      if (ans.size() < lst.size() + i - w) ans.resize(lst.size() + i - w);
      ans[i - w - 1] -= mul;
      for (size_t j = 0; j < lst.size(); j++) ans[i - w + j] += lst[j] * mul;
      if (now.size() <= lst.size() + i - w) { // 注意此时无符号数溢出,注意移项
        w = i;
        lst = now;
        delta = tmp;
      }
    }
  }
  return ans << 1;
}

Bostan-Mori 算法(求分式第 n 项)

问题

给出多项式 F(x),G(x)n,求 [xn]F(x)/G(x)

template <class T>
mint divide_at(poly f, poly g, T n);

solution

[xn]F(x)G(x)=[xn]F(x)G(x)G(x)G(x).

因为 G(x)G(x) 是一个偶函数(函数 H(x) 为偶函数当且仅当 H(x)=H(x),如绝对值函数,如 H(x)=x2 等只有偶次项有值的多项式函数),所以它只有偶次项,不妨直接记作 v(x2)=G(x)G(x)。为了适应之,我们将 F(x)G(x) 按照奇偶次项分裂,分为 c0(x2)+xc1(x2)c0(x2) 是只拿偶次项,xc1(x) 是只拿奇次项,这样,c0(x) 就是很连续的东西,c1(x) 也是。

=[xn]c0(x2)+xc1(x2)v(x2)=[xn]c0(x2)v(x2)+[xn1]c1(x2)v(x2)={[xn/2]c0(x)v(x),2n,[x(n1)/2]c1(x)v(x),2n.

于是 n 的规模减少一半,这样只需要 O(logn) 次操作就能到达 n=0 的情况,答案为 F(0)/G(0)。同时每次减少问题规模,多项式 F(x),G(x) 的长度都不变。这样时间复杂度为 O(mlogmlogn) 其中 mmax(degF(x),degG(x))

常系数齐次线性递推

问题

求一个满足 k 阶齐次线性递推数列 ai 的第 n 项,即:

an=i=1kfi×ani

给出的是 a0,a1,,ak1f1,f2,,fk 注意下标。

template <class T>
mint linear_rec(poly a, poly f, T n);

solution

我们只需要构造 F(x),G(x) 使得 [xn]F(x)G(x)=an 即可。以下记 Fi=[xi]F(x)

为了使得 F(x) 的项数足够小,考虑钦定 Fi=0ik 时。那么根据定义:

F(x)=G(x)a(x)0=j=0iGjaijG0ai=j=1iGjaij

使得 G0=1,Gi=fi(1ik),其他项都是零。

又因为 F(x)=G(x)a(x)modxk,于是暴力计算 F(x)

然后应用 Bostan-Mori 算法即可。即我们想要算的是以下东西:

[xn]a(1f)modxk1f.

注意 f 下标从 1 开始。

你说的对所以代码在哪里

#include <bits/stdc++.h>
using namespace std;
#ifdef LOCAL
#define debug(...) fprintf(stderr, ##__VA_ARGS__)
#else
#define endl "\n"
#define debug(...) void(0)
#endif
typedef long long LL;
template <unsigned umod>
struct modint {
  static constexpr int mod = umod;
  unsigned v;
  modint() : v(0) {}
  template <class T, enable_if_t<is_integral<T>::value>* = nullptr>
  modint(T x) {
    x %= mod;
    if (x < 0) x += mod;
    v = x;
  }
  modint(const string& str) {
    v = 0;
    size_t i = 0;
    if (str.front() == '-') i += 1;
    while (i < str.size()) {
      assert(isdigit(str[i]));
      v = (v * 10ull % umod + str[i] - '0') % umod;
      i += 1;
    }
    if (str.front() == '-' && v) v = umod - v;
  }
  modint operator+() const { return *this; }
  modint operator-() const { return modint() - *this; }
  friend int raw(const modint& self) { return self.v; }
  friend istream& operator>>(istream& is, modint& self) {
    string str;
    is >> str;
    self = str;
    return is;
  }
  friend ostream& operator<<(ostream& os, const modint& self) {
    return os << raw(self);
  }
  modint& operator+=(const modint& rhs) {
    v += rhs.v;
    if (v >= umod) v -= umod;
    return *this;
  }
  modint& operator-=(const modint& rhs) {
    v -= rhs.v;
    if (v >= umod) v += umod;
    return *this;
  }
  modint& operator*=(const modint& rhs) {
    v = static_cast<unsigned>(1ull * v * rhs.v % umod);
    return *this;
  }
  modint& operator/=(const modint& rhs) {
    static constexpr size_t ilim = 1 << 20;
    static modint inv[ilim + 10];
    static int sz = 0;
    assert(rhs.v);
    if (rhs.v > ilim) return *this *= qpow(rhs, mod - 2);
    if (!sz) inv[1] = sz = 1;
    while (sz < (int)rhs.v) {
      for (int i = sz + 1; i <= sz << 1; i++) inv[i] = -mod / i * inv[mod % i];
      sz <<= 1;
    }
    return *this *= inv[rhs.v];
  }
  template <class T>
  friend modint qpow(modint a, T b) {
    modint r = 1;
    for (; b; b >>= 1, a *= a)
      if (b & 1) r *= a;
    return r;
  }
  friend modint operator+(modint lhs, const modint& rhs) { return lhs += rhs; }
  friend modint operator-(modint lhs, const modint& rhs) { return lhs -= rhs; }
  friend modint operator*(modint lhs, const modint& rhs) { return lhs *= rhs; }
  friend modint operator/(modint lhs, const modint& rhs) { return lhs /= rhs; }
  friend bool operator==(const modint& lhs, const modint& rhs) {
    return lhs.v == rhs.v;
  }
  friend bool operator!=(const modint& lhs, const modint& rhs) {
    return lhs.v != rhs.v;
  }
};
typedef modint<998244353> mint;
int glim(const int& x) { return 1 << (32 - __builtin_clz(x - 1)); }
int bitctz(const int& x) { return __builtin_ctz(x); }
struct poly : vector<mint> {
  poly() {}
  explicit poly(int n) : vector<mint>(n) {}
  poly(const vector<mint>& vec) : vector<mint>(vec) {}
  poly(initializer_list<mint> il) : vector<mint>(il) {}
  mint operator()(const mint& x) const;
  poly& cut(int lim);
  void ntt(int op);
};
void print(const poly& a) {
  for (size_t i = 0; i < a.size(); i++) debug("%d, ", raw(a[i]));
  debug("\n");
}
istream& operator>>(istream& is, poly& a) {
  for (auto& x : a) is >> x;
  return is;
}
ostream& operator<<(ostream& os, const poly& a) {
  bool flag = false;
  for (auto& x : a) {
    if (flag)
      os << " ";
    else
      flag = true;
    os << x;
  }
  return os;
}
mint poly::operator()(const mint& x) const {
  const auto& a = *this;
  mint res = 0;
  for (int i = (int)a.size() - 1; i >= 0; i--) {
    res = res * x + a[i];
  }
  return res;
}
poly& poly::cut(int lim) {
  resize(lim);
  return *this;
}
void poly::ntt(int op) {
  static bool wns_flag = false;
  static vector<mint> wns;
  if (!wns_flag) {
    wns_flag = true;
    for (int j = 1; j <= 23; j++) {
      wns.push_back(qpow(mint(3), raw(mint(-1)) >> j));
    }
  }
  vector<mint>& a = *this;
  int n = a.size();
  for (int i = 1, r = 0; i < n; i++) {
    r ^= n - (1 << (bitctz(n) - bitctz(i) - 1));
    if (i < r) std::swap(a[i], a[r]);
  }
  vector<mint> w(n);
  for (int k = 1, len = 2; len <= n; k <<= 1, len <<= 1) {
    mint wn = wns[bitctz(k)];
    for (int i = raw(w[0] = 1); i < k; i++) w[i] = w[i - 1] * wn;
    for (int i = 0; i < n; i += len) {
      for (int j = 0; j < k; j++) {
        mint x = a[i + j], y = a[i + j + k] * w[j];
        a[i + j] = x + y, a[i + j + k] = x - y;
      }
    }
  }
  if (op == -1) {
    mint iz = mint(1) / n;
    for (int i = 0; i < n; i++) a[i] *= iz;
    reverse(a.begin() + 1, a.end());
  }
}
poly concalc(int n, vector<poly> vec,
             const function<mint(vector<mint>)>& func) {
  int lim = glim(n);
  int m = vec.size();
  for (auto& f : vec) f.resize(lim), f.ntt(1);
  vector<mint> tmp(m);
  poly ret(lim);
  for (int i = 0; i < lim; i++) {
    for (int j = 0; j < m; j++) tmp[j] = vec[j][i];
    ret[i] = func(tmp);
  }
  ret.ntt(-1);
  return ret;
}
poly getInv(const poly& a, int lim) {
  poly b{1 / a[0]};
  for (int len = 2; len <= glim(lim); len <<= 1) {
    poly c = vector<mint>(a.begin(), a.begin() + min(len, (int)a.size()));
    b = concalc(len << 1, {b, c}, [](vector<mint> vec) {
          return vec[0] * (2 - vec[0] * vec[1]);
        }).cut(len);
  }
  return b.cut(lim);
}
poly operator+=(poly& a, const poly& b) {
  if (a.size() < b.size()) a.resize(b.size());
  for (size_t i = 0; i < b.size(); i++) a[i] += b[i];
  return a;
}
poly operator-=(poly& a, const poly& b) {
  if (a.size() < b.size()) a.resize(b.size());
  for (size_t i = 0; i < b.size(); i++) a[i] -= b[i];
  return a;
}
poly operator*=(poly& a, const mint& k) {
  if (k == 1) return a;
  for (size_t i = 0; i < a.size(); i++) a[i] *= k;
  return a;
}
poly operator/=(poly& a, const mint& k) { return a *= 1 / k; }
poly operator<<=(poly& a, const int& k) {
  // mnltiple by x^k
  a.insert(a.begin(), k, 0);
  return a;
}
poly operator>>=(poly& a, const int& k) {
  // divide by x^k
  a.erase(a.begin(), a.begin() + min(k, (int)a.size()));
  return a;
}
poly operator*(const poly& a, const poly& b) {
  if (a.empty() || b.empty()) return {};
  int rlen = a.size() + b.size() - 1;
  int len = glim(rlen);
  if (1ull * a.size() * b.size() <= 1ull * len * bitctz(len)) {
    poly ret(rlen);
    for (size_t i = 0; i < a.size(); i++)
      for (size_t j = 0; j < b.size(); j++) ret[i + j] += a[i] * b[j];
    return ret;
  } else {
    return concalc(len, {a, b},
                   [](vector<mint> vec) { return vec[0] * vec[1]; })
        .cut(rlen);
  }
}
poly operator/(poly a, poly b) {
  if (a.size() < b.size()) return {};
  int rlen = a.size() - b.size() + 1;
  reverse(a.begin(), a.end());
  reverse(b.begin(), b.end());
  a = (a * getInv(b, rlen)).cut(rlen);
  reverse(a.begin(), a.end());
  return a;
}
poly operator-(poly a, const poly& b) { return a -= b; }
poly operator%(const poly& a, const poly& b) {
  return (a - (a / b) * b).cut(b.size() - 1);
}
poly operator*=(poly& a, const poly& b) { return a = a * b; }
poly operator/=(poly& a, const poly& b) { return a = a / b; }
poly operator%=(poly& a, const poly& b) { return a = a % b; }
poly operator+(poly a, const poly& b) { return a += b; }
poly operator*(poly a, const mint& k) { return a *= k; }
poly operator*(const mint& k, poly a) { return a *= k; }
poly operator/(poly a, const mint& k) { return a /= k; }
poly operator<<(poly a, const int& k) { return a <<= k; }
poly operator>>(poly a, const int& k) { return a >>= k; }
poly getDev(poly a) {
  a >>= 1;
  for (size_t i = 1; i < a.size(); i++) a[i] *= i + 1;
  return a;
}
poly getInt(poly a) {
  a <<= 1;
  for (size_t i = 1; i < a.size(); i++) a[i] /= i;
  return a;
}
poly getLn(const poly& a, int lim) {
  assert(a[0] == 1);
  return getInt(getDev(a) * getInv(a, lim)).cut(lim);
}
poly getExp(const poly& a, int lim) {
  assert(a[0] == 0);
  poly b{1};
  for (int len = 2; len <= glim(lim); len <<= 1) {
    poly c = vector<mint>(a.begin(), a.begin() + min(len, (int)a.size()));
    b = concalc(len << 1, {b, getLn(b, len), c}, [](vector<mint> vec) {
          return vec[0] * (1 - vec[1] + vec[2]);
        }).cut(len);
  }
  return b.cut(lim);
}
poly qpow(const poly& a, string k, int lim) {
  size_t i = 0;
  while (i < a.size() && a[i] == 0) i += 1;
  if (i == a.size() || (i > 0 && k.size() >= 9) ||
      1ull * i * raw(mint(k)) >= 1ull * lim)
    return poly(lim);
  lim -= i * raw(mint(k));
  return getExp(getLn(a / a[i] >> i, lim) * k, lim) *
             qpow(a[i], raw(modint<mint::mod - 1>(k)))
         << i * raw(mint(k));
}
poly qpow(const poly& a, LL k, int lim) {
  size_t i = 0;
  while (i < a.size() && a[i] == 0) i += 1;
  if (i == a.size() || (i > 0 && k >= 1e9) ||
      1ull * i * k >= 1ull * lim)
    return poly(lim);
  lim -= i * k;
  return getExp(getLn(a / a[i] >> i, lim) * k, lim) *
             qpow(a[i], raw(modint<mint::mod - 1>(k)))
         << i * k;
}
mint sqrt(const mint& c) {
  static const auto check = [](mint c) {
    return qpow(c, (mint::mod - 1) >> 1) == 1;
  };
  if (raw(c) <= 1) return 1;
  if (!check(c)) throw "No solution!";
  static mt19937 rng{random_device{}()};
  mint a = rng();
  while (check(a * a - c)) a = rng();
  typedef pair<mint, mint> number;
  const auto mul = [=](number x, number y) {
    return make_pair(x.first * y.first + x.second * y.second * (a * a - c),
                     x.first * y.second + x.second * y.first);
  };
  const auto qpow = [=](number a, int b) {
    number r = {1, 0};
    for (; b; b >>= 1, a = mul(a, a))
      if (b & 1) r = mul(r, a);
    return r;
  };
  mint ret = qpow({a, 1}, (mint::mod + 1) >> 1).first;
  return min(raw(ret), raw(-ret));
}
poly getSqrt(const poly& a, int lim) {
  poly b{sqrt(a[0])};
  for (int len = 2; len <= glim(lim); len <<= 1) {
    poly c = vector<mint>(a.begin(), a.begin() + min(len, (int)a.size()));
    b = (c * getInv(b * 2, len) + b / 2).cut(len);
  }
  return b.cut(lim);
}
template <class T>
mint divide_at(poly f, poly g, T n) {
  for (; n; n >>= 1) {
    poly r = g;
    for (size_t i = 1; i < r.size(); i += 2) r[i] *= -1;
    f *= r;
    g *= r;
    for (size_t i = n & 1; i < f.size(); i += 2) f[i >> 1] = f[i];
    f.resize((f.size() + 1) >> 1);
    for (size_t i = 0; i < g.size(); i += 2) g[i >> 1] = g[i];
    g.resize((g.size() + 1) >> 1);
  }
  return f.empty() ? 0 : f[0] / g[0];
}
template <class T>
mint linear_rec(poly a, poly f, T n) {
  // a[n] = sum_i f[i] * a[n - i]
  a.resize(f.size() - 1);
  f = poly{1} - f;
  poly g = a * f;
  g.resize(a.size());
  return divide_at(g, f, n);
}
poly BM(poly a) {
  poly ans, lst;
  int w = 0;
  mint delta = 0;
  for (size_t i = 0; i < a.size(); i++) {
    mint tmp = -a[i];
    for (size_t j = 0; j < ans.size(); j++) tmp += ans[j] * a[i - j - 1];
    if (tmp == 0) continue;
    if (ans.empty()) {
      w = i;
      delta = tmp;
      ans = vector<mint>(i + 1, 0);
    } else {
      auto now = ans;
      mint mul = -tmp / delta;
      if (ans.size() < lst.size() + i - w) ans.resize(lst.size() + i - w);
      ans[i - w - 1] -= mul;
      for (size_t j = 0; j < lst.size(); j++) ans[i - w + j] += lst[j] * mul;
      if (now.size() <= lst.size() + i - w) {
        w = i;
        lst = now;
        delta = tmp;
      }
    }
  }
  return ans << 1;
}
poly lagrange(const vector<pair<mint, mint>>& a) {
  poly ans(a.size()), product{1};
  for (size_t i = 0; i < a.size(); i++) {
    product *= poly{-a[i].first, 1};
  }
  auto divide2 = [&](poly a, mint b) {
    poly res(a.size() - 1);
    for (size_t i = (int)a.size() - 1; i >= 1; i--) {
      res[i - 1] = a[i];
      a[i - 1] -= a[i] * b;
    }
    return res;
  };
  for (size_t i = 0; i < a.size(); i++) {
    mint denos = 1;
    for (size_t j = 0; j < a.size(); j++) {
      if (i != j) denos *= a[i].first - a[j].first;
    }
    poly numes = divide2(product, -a[i].first);
    ans += a[i].second / denos * numes;
  }
  return ans;
}

posted @   caijianhong  阅读(301)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· 地球OL攻略 —— 某应届生求职总结
· 周边上新:园子的第一款马克杯温暖上架
· Open-Sora 2.0 重磅开源!
· 提示词工程——AI应用必不可少的技术
· .NET周刊【3月第1期 2025-03-02】
点击右上角即可分享
微信分享提示