d-finite 与 ODE 自动机

机械化求解整式递推,又称 ODE 自动机

最后还是要自己写一个(
用别人的不放心!
你就放心吧,我会写使用说明的(

定义

对函数 \(y(x)\),方程

\[\sum_{i = 0}^n a_i(x) y^{(i)}(x) = C(x) \]

为其的一个 \(n\) 阶线性微分方程。若 \(C(x) = 0\),则其为一个齐次的线性微分方程,并称满足这一方程的函数是微分有限(d-finite)的。若 \(\forall i, \text{deg }a_i(x) = 0\),则其是常系数的,反之则是变系数的。若对应函数是单变量的,就称对应的线性微分方程为常微分方程(ordinary differential equation,ODE)。

对数列 \(\{a_n\}\),方程

\[\sum_{i = 0}^d p_i(n) a_{n - i} = 0 \]

为其的一个 \(d\) 阶整式递推,其中 \(\forall p_i(n)\) 为多项式,并称满足这一方程的函数是可多项式递推(P-recursive)的。

可以证明:

  1. 令数列 \(\{a_n\}\) 的生成函数为 \(A(x)\),那么 \(A(x)\) d-finite,当且仅当 \(\{a_n\}\) P-recursive。

  2. \(f(x), g(x)\) d-finite,则对任意常数 \(c_1, c_2\)\(c_1 f(x) + c_2 g(x)\) d-finite。

  3. \(f(x), g(x)\) d-finite,则 \(f(x)g(x)\) d-finite。

  4. \(f(x)\) d-finite,\(g(x)\) 代数,则 \(f(g(x))\) d-finite。其中,\(g(x)\) 代数,当且仅当存在不可约方程(不可被分解为两个次数较低的多项式之乘积的多项式与 \(0\) 形成的等式)

    \[\sum_{i = 0}^n a_i(x) g^i(x) = 0 \]

    恒成立。

基础函数的构造

  1. \(f(x) = x^{\alpha}\)
    考察 \(x f'(x) = x\times \alpha x^{\alpha - 1} = \alpha x^{\alpha} = \alpha f(x)\),因此有

    \[\alpha f - x f' = 0 \]

  2. \(f(x) = e^x\)
    考察 \(f'(x) = e^x = f(x)\),因此有

    \[f - f' = 0 \]

  3. \(f(x) = {}_pF_q(a_1,\dots, a_p; b_1, \dots, b_q; x)\)
    \(\vartheta = xD\),则有

    \[(\vartheta + a_1)\cdots (\vartheta + a_p) f = D (\vartheta + b_1 - 1)\cdots (\vartheta + b_q - 1) f \]

  4. \(f(x) = c\times g(x)\),其中 \(c\) 为与 \(x\) 无关的常数。则 \(f(x)\) 满足的微分方程和 \(g(x)\) 相同。
    写这个的原因是老是忘记常数咋处理。记住:微分方程不管常数,常数无论如何都能消掉,递推时它们的信息由初值给出。

如何实现?

一般而言,我们想要求得 ODE 的函数都是由上方基础函数加减乘除复合得来,因此我们只需要让机器帮我们做两个 ODE 相加、相乘,ODE 和代数函数复合,就基本上可以解决我们面对的问题。

受到我没看过的论文证明的启发,我们应当维护对应 d-finite 函数 \(f\) 的各阶导数 \(f, f', f'', \dots\) 组成的线性空间。

下面令系数都在模 \(p\) 意义下进行,即属于域 \(F_p\)。令我们要处理的两个 d-finite 函数为 \(f(x), g(x)\),并令

\[f^{(n)}(x) = \sum_{i = 0}^{n - 1} p_i(x) f^{(i)}(x), \qquad g^{(m)}(x) = \sum_{i = 0}^{m - 1} q_i(x) g^{(i)}(x) \]

其中 \(\forall p_i(x), q_i(x) \in F_p(x)\),即系数均属于 \(F_p\) 的有理分式集。此外,不妨令 \(n\le m\)

\(f(x) + g(x)\)

考察 \((f + g)^{(k)}\)。我们需要的就是用 \(f, \dots, f^{(n-1)}, g, \dots, g^{(m-1)}\) 表出每个 \((f + g)^{(k)}\)

\(k < n\) 的时候,知道 \((f + g)^{(k)} = f^{(k)} + g^{(k)}\)
\(k = n\) 的时候,\((f + g)^{(n)} = f^{(n)} + g^{(n)}\),而 \(f^{(n)}\) 可使用微分方程降次。

从而当 \(k \ge n\) 时,若

\[(f + g)^{(k)} = \sum_{i = 0}^{n - 1} a_{k, i}(x) f^{(i)} + \sum_{i = 0}^{n - 1}(x) b_{k, i} g^{(i)} \]

\[\begin{aligned} (f + g)^{(k + 1)} & = \left(\sum_{i = 0}^{n - 1} a_{k, i}(x) f^{(i)} + \sum_{i = 0}^{n - 1}(x) b_{k, i} g^{(i)}\right)' \\ & = \sum_{i = 0}^{n - 1} \left(a_{k, i}'(x) f^{(i)} + a_{k, i}(x) f^{(i + 1)}\right) + \sum_{i = 0}^{n - 1}(x) \left(b_{k, i}'(x) g^{(i)} + b_{k,i}(x) g^{(i+1)}\right) \end{aligned}\]

随后只需将 \(f^{(n)}\) 用微分方程降次即可。这样由归纳法可以知道,我们总能将每个 \((f + g)^{(k)}\)\(f, \dots, f^{(n-1)}, g, \dots, g^{(m-1)}\) 表出。由于向量 \([f, \dots, f^{(n-1)}, g, \dots, g^{(m-1)}] ^{\mathsf T}\) 只有 \(n + m\) 维,则 \((f + g)^{(n + m)}\) 必定能由 \(f + g, (f + g)^{(1)}, \dots, (f + g)^{(n + m - 1)}\) 表出。

因此我们可以用与维护线性基类似的思想,一个个插入 \((f + g)^{(k)}\),最后必定可以得到一个 \((f + g)^{(k_0)}\) 使得其可以被线性表出,且 \(k_0 \le n + m\)

\(f(x)g(x)\)

考察 \((fg)^{(k)}\)\((fg)^{(k + 1)}\)。若

\[(fg)^{(k)} = \sum_{i = 0}^{n - 1} \sum_{j = 0}^{m - 1} a_{i,j}(x) f^{(i)} g^{(j)} \]

\(\left(a_{i,j}(x) f^{(i)} g^{(j)}\right)' = a_{i,j}'(x) f^{(i)} g^{(j)} + a_{i,j}(x) f^{(i + 1)} g^{(j)} + a_{i,j}(x) f^{(i)} g^{(j + 1)}\),并使用微分方程降次可以知道

\[(fg)^{(k + 1)} = \sum_{i = 0}^{n - 1} \sum_{j = 0}^{m - 1} \left(a_{i,j}(x) f^{(i)} g^{(j)}\right)' \]

也符合 \((fg)^{(k)}\) 的形式。这样由归纳法可以知道,我们总能将每个 \((f g)^{(k)}\)\(fg, f^{(1)}g, \dots, f^{(n-1)}g, \dots, f^{(n-2)}g^{(m-1)}, f^{(n-1)}g^{(m-1)}\) 表出。由于以上的分量只有 \(n m\) 维,则 \((f g)^{(n m)}\) 必定能由 \(f g, (f g)^{(1)}, \dots, (f g)^{(n m - 1)}\) 表出。

因此我们可以用与维护线性基类似的思想,一个个插入 \((fg)^{(k)}\),最后必定可以得到一个 \((fg)^{(k_0)}\) 使得其可以被线性表出,且 \(k_0 \le nm\)

\(f \circ g\)

这里 \(g(x)\) 符合一个代数方程 \(P(g) = 0\)。其系数 \(\in F_p(x)\)

\(\text{deg } P = 1\),则 \(g(x)\)\(\in F_p(x)\)。这时令

\[(f\circ g)^{(k)} = \sum_{i = 0}^{n - 1} a_{k,i}(x) \left(f^{(i)} \circ g\right) \]

考察 \(\left(a_{k,i}(x)\left(f^{(i)} \circ g\right)\right)' = a_{k,i}'(x) \left(f^{(i)} \circ g\right) + a_{k,i}(x) \left(f^{(i + 1)} \circ g\right) g'\),当 \(i + 1 = n\) 时可以用微分方程降次,即

\[f^{(n)}\circ g = \sum_{i = 0}^{n - 1} \left(p_i\circ g\right)\left(f^{(i)} \circ g\right) \]

\(\text{deg } P > 1\)\(g(x)\) 不再具有简单的有理分式形式。上面的过程中有两个点无法直接解决:求解 \(g'\) 和求解每个 \(p_i\circ g\)

若我们能将 \(g'\) 表示成关于 \(g\) 的某个整式 \(h(g)\),我们就能直接将其并入前一部分的 \(\left(f^{(i + 1)} \circ g\right)\),因此下面就是要用 \(g\) 表出 \(g'\)

由于 \(P(g) = 0\),其对 \(x\) 求导得到的应为零函数。考察某项 \(a_i(x) g^i(x)\),其导数为 \(\left(a_i(x) g^i(x)\right)' = a_i'(x) g^i(x) + a_i(x) i g^{i - 1}(x) g'(x)\),前半部分求和得到 \(\dfrac{\partial P(t)}{\partial x} \circ g\),后半部分求和得到 \(\left(\dfrac{\partial P(t)}{\partial t}\circ g\right) g'\),其中 \(t\) 为占位元,与 \(x\) 无关。化简得到

\[g' = -\dfrac{\partial_x P}{\partial_g P} \]

这里视 \(P(g, x)\) 为二元代数函数,\(g,x\) 彼此无关。

为降次,把 \(g'\) (对应的 \(h(g)\))对 \(P(g)\) 取模。要算的就是同余方程

\[y\partial_g P \equiv -\partial_x P \pmod{P} \]

的解 \(y\)。可以证明这解必定存在,使用多项式扩展欧几里得即可。

\(p_i\circ g\),也可以对 \(P(g)\) 取模来降次。\(p_i(x) \in F_p(x)\),因此 \(p_i\) 分母 \(\in F_p[x]\)。而由于 \(P(u)\) 代数,其在系数所属的域 \(F_p[x]\) 上无法分解,故 \(P(u)\) 的根都在 \(F_p(x) / F_p[x]\) 上,其和 \(p_i(x)\) 的分母不存在公因式,可以求逆。

实现

实现了一个,写了 25k,全爆炸了。
以 EI 板子为基础构造代码。

一些细节。

有时候实现线性递推并不需要多项式模数,但由于一道题一般只有一个模数,可以和多项式板子共用一个模数。
需要系数在 \(F_p\) 上的多项式 opoly 类维护多项式。
需要 \(F_p(x)\) 上的有理分式,定义 pfrac 类维护。
需要系数在 \(F_(x)\) 上的多项式类维护 ODE。令对应的函数 \(f(x)\) 满足

\[\sum_{i = 0}^{n} p_i(x) f^{(i)} = 0 \]

并用 vector<opoly> 维护多项式列 \(p_i(x)\)。这里如果使用 \(f^{(n)} = \dots\) 的表达形式,则需要维护有理分式列,相对降低直观性。
最后就是用 ODE 返回整式递推,以及用整式递推 \(O(n)\) 推导了。

为了不让过长的前导零降低效率,维护 shrink(T& x) 删除 \(x\) 的前导零。
为统一接口,可以封装一个 \(F_p\) 内元素的类 Z,并在底层采用取模优化。由于此处常数瓶颈并非 ntt 而是线性的取模,预计这样的优化会减小常数。
注意:使用单个常数初始化 opoly 会得到对应常数长度的零多项式,得到常单项式 \(c\) 需要使用语法 opoly(1, c) 或直接使用 {Z(c)}。在隐性类型转换时(如多项式除以 int 等处)尤其要小心。

这里也暴露了一个问题:多项式本体 poly 的系数是 u32 类型的,并非 opolyZ 类型。因此,需要在推导的某个阶段逐渐将 opoly 逐步更换为 poly,便于整式递推的获得与结果的处理。这个阶段难以确定,因此不妨定义成员函数 trans_prec 返回 poly 盛放的整式递推。
目前认为,将 ODE 处理的全部内容(包括 \(+\times\circ\),得到整式递推,执行 \(O(n)\) 递推)均封装为 ODE 的成员函数会相对增加直观性。

使用方法

先把板子放在这里。需要使用多项式类 poly,但功能似乎不是很大。

code
struct FastMod {
   int m; ll b;
   inline void init(int _m = 1) { m = _m; if (m == 0) m = 1; b = ((lll)1 << 64) / m; } 
   FastMod(int _m = 1) { init(_m); }
   inline int operator()(ll a) { ll q = ((lll)a * b) >> 64; a -= q * m; if (a >= m) a -= m; return a; }
} Mod(mod);
struct Z {
   u32 v;
   Z(u32 v = 0) : v(v) {}
   Z(int v) : v(Norm(Mod(v) + mod)) { }
   Z(long long v) : v(Norm(Mod(v) + mod)) { }
   inline friend Z operator+(const Z& lhs, const Z &rhs) { return Norm(lhs.v + rhs.v); }
   inline friend Z operator+(const Z& lhs, const int &rhs) { return Norm(lhs.v + rhs); }
   inline friend Z operator+(const int& lhs, const Z& rhs) { return Norm(lhs + rhs.v); }
   inline friend Z operator-(const Z& lhs, const Z &rhs) { return Norm(lhs.v + mod - rhs.v); }
   inline friend Z operator-(const Z& lhs, const int &rhs) { return Norm(lhs.v - rhs + mod); }
   inline friend Z operator-(const int& lhs, const Z& rhs) { return Norm(lhs - rhs.v + mod); }
   inline Z operator-() const { return Norm(mod - v); }
   inline Z inv() const { return qp(*this, mod - 2); };
   inline friend Z operator*(const Z& lhs, const Z& rhs) { return Mod(1ll * lhs.v * rhs.v); }
   inline friend Z operator*(const Z& lhs, const int &rhs) { return Mod(1ll * lhs.v * rhs); }
   inline friend Z operator*(const int& lhs, const Z& rhs) { return Mod(1ll * lhs * rhs.v); }
   inline friend Z operator/(const Z& lhs, const Z &rhs) { return Mod(1ll * lhs.v * rhs.inv().v); }
   inline friend Z operator/(const Z& lhs, const int &rhs) { return Mod(1ll * lhs.v * qp(rhs, mod - 2)); }
   inline friend Z operator/(const int& lhs, const Z& rhs) { return Mod(1ll * lhs * rhs.inv().v); }
   operator u32() const { return v; }
   inline friend ostream &operator<< (ostream &out, const Z &x) { return out << x.v; }
};
inline Z &operator+=(Z &lhs, const Z &rhs) { return lhs = lhs + rhs; }
inline Z &operator-=(Z &lhs, const Z &rhs) { return lhs = lhs - rhs; }
inline Z &operator*=(Z &lhs, const Z &rhs) { return lhs = lhs * rhs; }
inline Z &operator/=(Z &lhs, const Z &rhs) { return lhs = lhs / rhs; }

struct opoly : vector<Z> {
   opoly(const int& n = 1, const Z& val = 0) : vector(n, val) {}
   opoly(const initializer_list<value_type> &il) : vector(il) {}
   opoly(const vector<Z> &il) : vector(il) {}
   inline int degree() const { return (int) size() - 1; }
   inline bool shrink() { int k = size(); while (k && !at(k - 1)) --k; resize(k); return k; }
   inline void redegree(const size_t& n) { resize(n + 1); }
   inline opoly operator-() const { opoly ret(size()); for (int i = 0; i < size(); ++i) ret[i] = -at(i); return ret; }
   inline operator poly() const { poly ret(size()); for (int i = 0; i < (int)size(); ++ i) ret[i] = u32(at(i)); return ret; }
   inline friend opoly operator*(const opoly &a, const opoly &b) {
      int n = a.degree(), m = b.degree();
      if (n == -1 || m == -1) return opoly();
      opoly c(n + m + 1);
      for (int i = 0; i <= n + m; ++i)
         for (int j = max(0, i - m); j <= min(i, n); ++j)
            c[i] += a[j] * b[i - j];
      return c;
   }
   inline friend opoly operator*(const opoly &a, const Z &z) { opoly c(a); for (Z &x : c) x *= z; return c; }
   inline friend opoly operator*(const Z &z, const opoly &a) { opoly c(a); for (Z &x : c) x *= z; return c; }
   inline friend opoly operator+(const opoly &a, const opoly &b) {
      opoly c(max(a.size(), b.size()));
      for (int i = 0; i < a.size(); ++i) c[i] += a[i];
      for (int i = 0; i < b.size(); ++i) c[i] += b[i];
      return c;
   }
   inline friend opoly operator-(const opoly &a, const opoly &b) { return a + -b; }
   inline friend opoly operator+(const opoly &a, const Z &z) { opoly c(a); c[0] += z; return c; }
   inline friend opoly operator+(const Z &z, const opoly &a) { opoly c(a); c[0] += z; return c; }
   inline friend opoly operator-(const opoly &a, const Z &z) { opoly c(a); c[0] -= z; return c; }
   inline friend opoly operator-(const Z &z, const opoly &a) { opoly c(a); c[0] -= z; return c; }
   inline friend bool operator==(const opoly &a, const opoly &b) { 
      if (a.size() != b.size()) return false; 
      for (int i = 0; i < a.size(); ++ i)  if (a[i] != b[i]) return false;
      return true;
   }
   opoly deri() const {
      if (empty()) return *this;
      opoly a(*this);
      for (int i = 1; i < a.size(); ++i) a[i - 1] = a[i] * i;
      a.pop_back();
      return a;
   }
   Z eval(const Z &z) const {
      Z v = 0;
      for (int i = degree(); i >= 0; --i) v = v * z + at(i);
      return v;
   }
};
template <typename _vi>
inline opoly topoly(const _vi& a) {
   opoly ret(a.size());
   for (int i = 0; i < a.size(); ++ i) 
      ret[i].v = a[i];
   return ret;
}
inline opoly gcd(opoly a, opoly b) {
   if (!a.shrink()) return b;
   if (!b.shrink()) return a;
   if (a.size() < b.size()) swap(a, b);
   while (b.shrink()) {
      Z in = b.back().inv();
      for (Z &x : b) x *= in;
      int n = a.degree(), m = b.degree();
      for (int i = n; i >= m; --i) {
         for (int j = 1; j <= m; ++j)
            a[i - j] -= a[i] * b[m - j];
         a[i] = 0;
      } swap(a, b);
   } return a.shrink(), a;
}
inline opoly div(const opoly& _a, const opoly& _b) {
   opoly a(_a), b(_b);
   Z in = b.back().inv();
   for (Z &x : b) x *= in;
   int n = a.degree(), m = b.degree();
   opoly ret(n - m + 1);
   for (int i = n; i >= m; --i) {
      ret[i - m] = a[i] * b[m];
      for (int j = 1; j <= m; ++j)
         a[i - j] -= a[i] * b[m - j];
   } for (Z &x : ret) x = x * in;
   return ret;
}
inline opoly operator/ (const opoly& _a, const opoly& _b) { return div(_a, _b); }
inline opoly& operator/= (opoly& _a, const opoly& _b) { _a = div(_a, _b); return _a; }

struct pfrac {
	opoly x, y;
	pfrac(const opoly &x = opoly(), const opoly &y = {Z(1)}) : x(x), y(y) {}
	void shrink() {
		y.shrink();
		if (!x.shrink()) { y = opoly{Z(1)}; return; }
		opoly g = gcd(x, y);
		x /= g, y /= g;
	}
	inline pfrac operator+(const pfrac &rhs) const { pfrac ret = pfrac(x * rhs.y + y * rhs.x, y * rhs.y); ret.shrink(); return ret; }
	inline pfrac operator-() const { pfrac ret = pfrac(-x, y); ret.shrink(); return ret; }
	inline pfrac operator-(const pfrac &rhs) const { pfrac ret = *this + -rhs; ret.shrink(); return ret; }
	inline pfrac operator*(const pfrac &rhs) const { pfrac ret = pfrac(x * rhs.x, y * rhs.y); ret.shrink(); return ret; }
	// inline pfrac operator*(const Z &rhs) const { pfrac ret = pfrac(x * rhs, y * rhs); ret.shrink(); return ret; } // ?
	inline pfrac inv() const { pfrac ret = pfrac(y, x); ret.shrink(); return ret; }
	inline pfrac operator/(const pfrac &rhs) const { pfrac ret = *this * rhs.inv(); ret.shrink(); return ret; }
	bool operator==(const pfrac &rhs) const { return x * rhs.y == y * rhs.x; }
	bool operator!=(const pfrac &rhs) const { return !operator==(rhs); }
	pfrac deri() const { pfrac ret = pfrac(x.deri() * y - y.deri() * x, y * y); ret.shrink(); return ret; }
};

struct Q_Basis {
   int dim, id;
   vector<vector<pfrac> > basis, augment;
   Q_Basis(int dim) : dim(dim), id(), basis(dim), augment(dim) {}
   vector<pfrac> insert(vector<pfrac> vec) {
      vector<pfrac> tmp(dim + 1);
      tmp[id++] = pfrac({Z(1)});
      for (int i = 0; i < dim; ++i) {
         if (vec[i] != pfrac()) {
            if (basis[i].empty()) {
               for (int j = i + 1; j < dim; ++j) {
                  vec[j] = vec[j] / vec[i];
                  vec[j].shrink();
               }
               for (int j = 0; j < id; ++j) {
                  tmp[j] = tmp[j] / vec[i];
                  tmp[j].shrink();
               }
               vec[i] = pfrac({Z(1)});
               basis[i] = vec;
               augment[i] = tmp;
               return {};
            } else {
               for (int j = i + 1; j < dim; ++j)
                  vec[j] = vec[j] - vec[i] * basis[i][j];
               for (int j = 0; j < id; ++j)
                  tmp[j] = tmp[j] - vec[i] * augment[i][j];
               vec[i] = pfrac();
            }
         }
      } return tmp;
   }
};

using PRec = vector<opoly>;
using ODE_base = vector<opoly>;

struct ODE {
   ODE_base ode;
   ODE(const int& n = 1, const opoly& val = {0}) : ode(n, val) {}
   ODE(const initializer_list<opoly> &il) : ode(il) {}
   inline int size() const { return (int)ode.size(); }
   inline int degree() const { return size() - 1; }
   inline opoly& operator[](int x) { return ode[x]; }
   inline opoly operator[](int x) const { return ode[x]; }
   inline ODE& resize(int x) { return ode.resize(x), *this; }
   inline ODE& redegree(int x) { return ode.resize(x + 1), *this; }
   inline ODE_base::iterator begin() { return ode.begin(); }
   inline ODE_base::iterator end() { return ode.end(); }
   void shrink() {
      opoly g = ode[0];
      for (opoly &x : ode) x.shrink(), g = gcd(g, x);
      for (opoly &x : ode) if (!x.empty()) x = div(x, g);
   }
   ODE deri() const {
      ODE _ode(*this);
      _ode.resize(_ode.size() + 1);
      for (int i = (int) _ode.size() - 2; i >= 0; --i) {
         _ode[i + 1] = _ode[i + 1] + _ode[i];
         _ode[i] = _ode[i].deri();
      } return _ode;
   }
   ODE theta() const {
      ODE _ode(deri());
      for (int i = 0; i < _ode.size(); ++i) _ode[i] = _ode[i] * opoly{Z(), Z(1)};
      return _ode;
   }

   PRec _rec;
   vector<Z> coef;
   inline PRec getPRec() {
      shrink();
      int tmp = numeric_limits<int>::max();
      int n = degree(), m = numeric_limits<int>::min();
      for (int i = 0; i <= n; ++i) {
         if (ode[i].empty()) continue;
         m = max(m, (int) ode[i].size() - 1 - i);
         int j = 0;
         while (ode[i][j] == 0) ++j;
         tmp = min(tmp, j - i);
      } m -= tmp;
      PRec rec(m + 1, opoly(n + 1)); opoly fall{Z(1)};
      for (int i = 0; i <= n; ++i) {
         opoly coef = fall;
         for (int j = 0; j < (int) ode[i].size() - i - tmp; ++j) {
            if (j + i + tmp >= 0) rec[j] = rec[j] + coef * ode[i][j + i + tmp];
            coef = div(coef * opoly{-Z(i + j), Z(1)}, opoly{-Z(j), Z(1)});
         } fall = fall * opoly{-Z(i), Z(1)};
      } 
      for (int i = 0; i <= m; ++i) rec[i].shrink();
      return _rec = rec;
   }
   int prf(int n) {
      coef.resize(n + 1);
      for (int i = 0; i <= n; ++i)
         coef[i] = _rec[0].eval(i);
      int r = 0;
      for (int i = n; i; --i)
         if (coef[i] == 0) {
            r = i;
            break;
         }
      return r;
   }
   opoly post(opoly init) {
      int m = init.size();
      auto invs = [&](const opoly &vec) {
         opoly prf(vec.size()), ret(vec.size());
         prf[0] = 1;
         for (int i = 1; i < vec.size(); ++i) prf[i] = prf[i - 1] * vec[i - 1];
         Z tot = accumulate(vec.begin(), vec.end(), Z(1), multiplies<Z>()).inv();
         for (int i = (int) vec.size() - 1; i >= 0; --i) ret[i] = tot * prf[i], tot *= vec[i];
         return ret;
      };
      auto nvs = invs(vector<Z>(coef.begin() + m, coef.end()));
      init.resize(coef.size());
      for (int i = m; i < coef.size(); ++i) {
         for (int j = 1; j < min(i + 1, (int) _rec.size()); ++j)
            init[i] += init[i - j] * _rec[j].eval(i);
         init[i] = init[i] * -nvs[i - m];
      } return init;
   }
   poly recur(const int& n, const poly& a0) {
      getPRec();
      prf(n);
      opoly ret(a0.size());
      for (int i = 0; i < a0.size(); ++ i) ret[i].v = a0[i];
      if (n <= ret.degree()) return ret.redegree(n), ret;
      return post(ret);
   }
   vector<poly> trans_poly() {
      vector<poly> ret(_rec.size(), poly());
      for (int i = 0; i < _rec.size(); ++ i) 
         _rec[i].shrink(), ret[i] = _rec[i];
      return ret;
   }

   inline friend ODE operator+(const ODE &op, const ODE &oq) {
      int n = op.degree(), m = oq.degree();
      Q_Basis basis(n + m);
      vector<pfrac> pd(n + 1), qd(m + 1);
      pd[0] = qd[0] = pfrac({Z(1)});
      for (int dim = 0; dim <= n + m; ++dim) {
         vector<pfrac> vec(n + m);
         copy(pd.begin(), pd.begin() + n, vec.begin());
         copy(qd.begin(), qd.begin() + m, vec.begin() + n);
         auto ret = basis.insert(vec);
         if (!ret.empty()) {
            ODE ode(dim + 1); opoly prod = {Z(1)};
            for (int i = 0; i < dim; ++i) prod = prod * ret[i].y;
            ode[dim] = prod;
            for (int i = 0; i < dim; ++i) ode[i] = ret[i].x * div(prod, ret[i].y);
            ode.shrink();
            return ode;
         } pd[n] = qd[m] = pfrac();
         for (int j = n - 1; j >= 0; --j) pd[j + 1] = pd[j + 1] + pd[j], pd[j] = pd[j].deri();
         for (int j = m - 1; j >= 0; --j) qd[j + 1] = qd[j + 1] + qd[j], qd[j] = qd[j].deri();
         for (int j = 0; j < n; ++j) pd[j] = pd[j] - pd[n] * op[j] / op[n], pd[j].shrink();
         for (int j = 0; j < m; ++j) qd[j] = qd[j] - qd[m] * oq[j] / oq[m], qd[j].shrink();
      }
   }

   inline friend ODE operator*(const ODE &op, const ODE &oq) {
      int n = op.size() - 1, m = oq.size() - 1;
      Q_Basis basis(n * m);
      vector<vector<pfrac> > p(n + 1, vector<pfrac>(m + 1));
      p[0][0] = pfrac({Z(1)});
      for (int dim = 0; dim <= n * m; ++dim) {
         vector<pfrac> vec(n * m);
         for (int i = 0; i < n; ++i)
            for (int j = 0; j < m; ++j)
               vec[i * m + j] = p[i][j];
         auto ret = basis.insert(vec);
         if (!ret.empty()) {
            ODE ode(dim + 1); opoly prod = {Z(1)};
            for (int i = 0; i < dim; ++i) prod = prod * ret[i].y;
            ode[dim] = prod;
            for (int i = 0; i < dim; ++i) ode[i] = ret[i].x * div(prod, ret[i].y);
            ode.shrink();
            return ode;
         }
         for (int i = 0; i < n; ++i) p[i][m] = pfrac();
         for (int j = 0; j < m; ++j) p[n][j] = pfrac();
         for (int i = n - 1; i >= 0; --i)
            for (int j = m - 1; j >= 0; --j) {
               p[i + 1][j] = p[i + 1][j] + p[i][j];
               p[i][j + 1] = p[i][j + 1] + p[i][j];
               p[i][j] = p[i][j].deri();
            }
         for (int i = 0; i < n; ++i)
            for (int j = 0; j < m; ++j) 
               p[i][j] = p[i][j] - p[n][j] * op[i] / op[n] - p[i][m] * oq[j] / oq[m], p[i][j].shrink();
      }
   }

   inline ODE composite(pfrac q) const {
      q.shrink();
      int n = ode.size() - 1;
      vector<vector<pfrac>> tri(n + 1, vector<pfrac>(n + 1));
      tri[0][0] = pfrac({Z(1)});
      pfrac d = q.deri();
      d.shrink();
      for (int i = 0; i < n; ++i) {
         for (int j = 0; j <= i; ++j) {
            tri[i + 1][j] = tri[i + 1][j] + tri[i][j].deri();
            tri[i + 1][j + 1] = tri[i][j] * d;
         } for (int j = 0; j <= i + 1; ++j) tri[i + 1][j].shrink();
      } vector<pfrac> vec(n + 1);
      for (int i = 0; i <= n; ++i) 
         for (int j = ode[i].degree(); j >= 0; --j) vec[i] = vec[i] * q + pfrac({ode[i][j]}), vec[i].shrink();
      for (int i = n; i >= 0; --i) {
         vec[i] = vec[i] / tri[i][i], vec[i].shrink();
         for (int j = 0; j < i; ++j)  vec[j] = vec[j] - vec[i] * tri[i][j];
      } opoly prod = opoly{Z(1)};
      for (int i = 0; i <= n; ++i) prod = prod * vec[i].y;
      ODE ret(n + 1);
      for (int i = 0; i <= n; ++i) ret[i] = vec[i].x * div(prod, vec[i].y);
      return ret.shrink(), ret;
   }
};
ODE linear_add(ODE a, ODE b) {
   if (a.size() < b.size()) swap(a, b);
   for (int i = 0; i < b.size(); ++i) a[i] = a[i] + b[i];
   return a;
}
ODE scalar_mul(ODE ode, const Z &z) {
   for (int i = 0; i < ode.size(); ++i)
      ode[i] = ode[i] * z;
   return ode;
}

ostream &operator<<(ostream &out, const Z &q) { return out << q.v; }
ostream &operator<<(ostream &out, const opoly &p) {
   if (p.empty()) return out << 0;
   out << p[0]; for (int i = 1; i < p.size(); ++i)
      out << " + " << p[i] << "x^" << i;
   return out;
}
ostream &operator<<(ostream &out, const pfrac &q) { return out << '(' << q.x << ") / (" << q.y << ')'; }
ostream &operator<<(ostream &out, const ODE &ode) {
   out << "(" << ode[0] << ")F";
   for (int i = 1; i < ode.size(); ++i) {
      out << " + (" << ode[i] << ")F^{(" << i << ")}";
   } return out << " = 0";
}
template<class T> ostream &operator<<(ostream &out, const vector<T> &v) {
   if (!v.empty()) {
      out << v.front();
      for (int i = 1; i < v.size(); ++i) out << ' ' << v[i];
   } return out;
}
ostream &operator>>(ostream &in, const Z &q) { return in >> q.v; }
template<class T> istream &operator>>(istream &is, vector<T> &v) { for (T &x : v) is >> x; return is; }

const ODE ODE_EXP = {{-Z(1)}, {Z(1)} };
const ODE ODE_LN = {{Z()}, {-Z(1)}, {Z(1), -Z(1)} };
const ODE ODE_FACT = {{Z(1)}, {-Z(1), Z(3)}, {Z(), Z(), Z(1)} };
ODE ode_power(const Z &k) { return ODE{{-Z(k)}, {Z(), Z(1)}}; }
ODE ode_bessel(const Z &k) { return ODE({{-Z(1)}, {k + Z(1)}, {Z(), Z(1)}}); }
ODE ode_pfrac(const pfrac &a) { return ode_power(1).composite(a); }
ODE ode_pFq(const vector<Z> &a, const vector<Z> &b) {
   ODE l = {{Z(1)}}, r = l;
   for (int x : a) l = linear_add(l.theta(), scalar_mul(l, x));
   for (int x : b) r = linear_add(r.theta(), scalar_mul(r, x - 1));
   ODE ret = linear_add(r.deri(), scalar_mul(l, - Z(1)));
   return ret.shrink(), ret;
} 

化简某个类可以使用成员函数 shrink()
struct Z:用来维护模 \(\text{mod}\) 意义下一个数的类,使用 struct FastMod 进行底层乘法优化。
struct opoly:继承了类 vector<Z>,用来维护短多项式。由于模数可能不是多项式模数,底层使用 \(O(n^2)\) 多项式乘法。封装了基础多项式运算 \(+-\times /\)、成员函数求导 deri() 和求值 eval(int x),以及(朴素的)多项式欧几里得 gcd(opoly, opoly)
struct pfrac:用来维护有理分式,分子与分母均为 opoly 类,分别命名为 x 和 y。封装了基础运算 \(+-\times /\)、成员函数求导 deri()
struct Q_Basis:用来维护类似线性基的结构,便于 ODE 的加法与乘法的进行。

struct ODE:大的来了!这个是三合一,可以通过调用成员函数实现:ODE 的 \(+\times \circ\),以及由 ODE 求解整式递推系数多项式、还有线性递推出前 \(n\) 项系数。一个个说。
上面提到的 ODE 的加法、乘法可以直接用运算符进行,而复合需要调用成员函数 composite(pfrac),注意这里只能复合有理分式。
由 ODE 求解整式递推,需要调用 getPRec(),得到的整式递推会被存在 ODE 类里的 _rec(这是个 vector<opoly>)中,其存放方式与【模板】整式递推中相同。但如果我们想调用 \(O(\sqrt n \log n)\) 自动机,则需要将其转为 vector<poly>,使用成员函数 trans_poly() 接收返回值即可。与 \(O(\sqrt n \log n)\) 自动机的配合只需要把需要的值求出来即可。
线性递推出前 \(n\) 项系数,使用 recur(const int& n, const poly& a0),其中 a0 存放初值,其返回 \(a_0, \dots, a_n\),并存放在一个 poly 中返回。

一些基础函数:
const ODE ODE_EXP 对应的是 \(\exp x\) 的 ODE;
const ODE ODE_LN 对应的是 \(-\ln (1-x)\) 的 ODE;
ODE ode_power(int k) 对应的是 \(x^k\) 的 ODE;
ODE ode_pFq(const vector<Z> &a, const vector<Z> &b) 对应的是 \(F\left(\left.\begin{matrix} a_1,\dots, a_p \\ b_1, \dots, b_q \end{matrix}\right\rvert x\right)\) 的 ODE;

注意:本自动机没有经过压力测试,可能出现不可预测的错误。

posted @ 2024-07-02 08:51  joke3579  阅读(329)  评论(0编辑  收藏  举报