关于三次多项式复合的一个注记
首先根据熟知的变换, 复合 \(f(ax^3+bx^2+cx+d)\) 的问题的困难内核在于 \(f(x^3+cx)\), 在域上, 只要解决某个 \(c\neq 0\) 的情况, 就解决了一般的情况.
取 \(c = -3\), 我们有
\[x^3 - 3x = (x^3 + x^{-3}) \circ (x + x^{-1})^{\langle-1\rangle}.
\]
于是问题的难点转换成了计算复合 \(f(x+x^{-1})\), 由于
\[f\left(\frac{x+x^{-1}}2\right) = f\left( \frac{x+1}{x-1} \right) \circ x^2
\circ \frac{x+1}{x-1}, \]
整个问题可以在 \(O(n\log n)\) 时间内解决.
之前本来想要出题, 但是做过一些实验发现有几十倍的常数, 很难击败 \(O(n\log^2 n)\) 做法.
由于某些原因, 今天公开一下上面的注记.
本来想要出的交互题的代码附在文末的“代码”一节.
若干注记
Remark 1 下一个值得研究的可能是四次多项式是否有通用方法. 我们现在所有的思路都是基于将多项式分解成一系列 \(x^k, (ax+b)/(cx+d)\) 或者它们的逆的复合来解决的. 上个世纪, Joseph Ritt 证明了 \(\mathbb C[x]\) 上多项式复合分解的一个基本刻画: Ritt 分解定理. 也许可以利用他的技术进一步分析: 能否用一些已有的零件 (未必是多项式) 的复合来分解任意的高次多项式.
Remark 2 对于更高次的多项式, 如果我们的分解过于简单, 是一定不能解决一般的五次及以上多项式的: 因为前面提到的零件给出的复合分解必然给出方程 \(f(x) = 0\) 的一种根式解, 所以太简单的做法会与 Abel–Ruffini 定理矛盾.
代码
// polycomp.h
// Arithmetic layer for the interactive problem: all coefficient arithmetic
// goes through `field`, whose operators bump a global operation counter.
namespace __athekatelan__ {

using u64 = unsigned long long;

// Modulus of the coefficient field.
const int mod = 998244353;

// Counters defined in another translation unit. op_cnt is incremented by
// every field operator below; mul_sum is only declared here.
extern long long op_cnt, mul_sum;

// A residue modulo `mod`, stored as a plain int.
// The constructor does not reduce its argument, so callers are expected to
// pass a value already in [0, mod); the operators keep results in that range
// as long as both operands are.
class field {
  int val;

 public:
  // Implicit on purpose: lets `field * int` style expressions compile.
  field(int v = 0) : val(v) {}

  // Sum modulo `mod`; counts as one operation.
  field operator+(const field& rhs) const {
    ++op_cnt;
    const int sum = val + rhs.val;
    return sum >= mod ? sum - mod : sum;
  }

  // Difference modulo `mod`; counts as one operation.
  field operator-(const field& rhs) const {
    ++op_cnt;
    const int diff = val - rhs.val;
    return diff < 0 ? diff + mod : diff;
  }

  // Product modulo `mod`, computed in 64 bits; counts as one operation.
  field operator*(const field& rhs) const {
    ++op_cnt;
    const u64 product = static_cast<u64>(val) * static_cast<u64>(rhs.val);
    return static_cast<int>(product % mod);
  }

  // Participant should not access this
  int __get() const { return val; }
};

// Declared only; the definition lives elsewhere. Presumably the polynomial
// product of a and b -- confirm against the implementation.
std::vector<field> mul(const std::vector<field>& a, const std::vector<field>& b);

}  // namespace __athekatelan__
// Make the modulus and the operation-counting field type usable without
// namespace qualification by the code that follows.
using __athekatelan__::mod;
using __athekatelan__::field;
// std.cpp
// Modular exponentiation: returns x^k modulo `mod` by binary
// square-and-multiply. Works on plain ints, so it does not touch the
// field operation counter. Expects x >= 0 and k >= 0; k == 0 yields 1.
int mpow(int x, int k) {
  int result = 1;
  int base = x;
  while (k != 0) {
    // Fold in the current power of x whenever the matching bit of k is set.
    if (k & 1) result = result * (u64)base % mod;
    base = base * (u64)base % mod;
    k >>= 1;
  }
  return result;
}
// Modular inverse of a, computed as a^(mod - 2) (Fermat's little theorem).
// inv(0) evaluates to 0 rather than signalling an error.
int inv(int a) {
  const int exponent = mod - 2;
  return mpow(a, exponent);
}
// Factorial tables modulo `mod`: fac[i] = i! and ifac[i] = 1 / i!.
// They grow on demand in prepare(), so there is no fixed capacity to
// overrun (the previous fixed arrays of 1 << 21 entries were written
// without any bounds check).
std::vector<int> fac, ifac;

// Ensures fac[0..n] and ifac[0..n] are filled in.
// - A negative n is treated as 0.
// - If the tables already cover index n the call is a no-op: the entries
//   do not depend on the requested size, so existing values stay valid.
// The inverse factorials are derived from a single modular inverse of n!
// followed by a downward sweep (ifac[i - 1] = ifac[i] * i).
void prepare(int n) {
  if (n < 0) n = 0;
  const std::size_t needed = static_cast<std::size_t>(n) + 1;
  if (fac.size() >= needed) return;

  fac.assign(needed, 1);
  ifac.assign(needed, 1);
  for (int i = 1; i <= n; ++i) fac[i] = fac[i - 1] * (u64)i % mod;
  ifac[n] = inv(fac[n]);
  for (int i = n; i; --i) ifac[i - 1] = ifac[i] * (u64)i % mod;
}
// Taylor shift: returns the coefficients of f(x + a), same length as f.
// Requires prepare() to have filled fac/ifac up to f.size() - 1, and relies
// on mul() returning the polynomial product of its arguments.
vector<field> taylor(vector<field> f, int a) {
  const int deg = f.size() - 1;

  // Coefficients of exp(a * x): shift[i] = a^i / i!.
  vector<field> shift(deg + 1);
  for (int i = 0, apow = 1; i <= deg; ++i) {
    shift[i] = apow * (u64)ifac[i] % mod;
    apow = apow * (u64)a % mod;
  }

  // Weight by i!, reverse, convolve with the series above, keep the low
  // deg + 1 terms, reverse back and divide by i! again.
  for (int i = 0; i <= deg; ++i) f[i] = f[i] * fac[i];
  reverse(f.begin(), f.end());
  f = mul(f, shift);
  f.resize(deg + 1);
  reverse(f.begin(), f.end());
  for (int i = 0; i <= deg; ++i) f[i] = f[i] * ifac[i];
  return f;
}
// Scaling substitution x -> q * x: coefficient i of f is multiplied by q^i.
vector<field> comp_q(vector<field> f, int q) {
  int qpow = 1;  // q^i for the coefficient currently being visited
  for (field& coef : f) {
    coef = coef * qpow;
    qpow = qpow * (u64)q % mod;
  }
  return f;
}
// Power substitution x -> x^k: coefficient i of f moves to position i * k
// in a result of length (f.size() - 1) * k + 1; all other entries are zero.
vector<field> comp_x(vector<field> f, int k) {
  const int deg = f.size() - 1;
  vector<field> spread(deg * k + 1);
  int i = 0;
  while (i <= deg) {
    spread[i * k] = f[i];
    ++i;
  }
  return spread;
}
// Moebius substitution: with n = f.size() - 1 and taylor(p, a) = p(x + a),
// the steps below produce (x - 1)^n * f((x + 1) / (x - 1)).
// Reversing a coefficient list plays the role of x -> 1/x (times x^n).
vector<field> mobius(vector<field> f) {
  const int minus_two = mod - 2;         // -2 modulo mod
  const int minus_half = mod - inv(2);   // -1/2 modulo mod

  reverse(f.begin(), f.end());
  f = taylor(comp_q(f, minus_two), minus_half);
  reverse(f.begin(), f.end());
  return taylor(f, 1);
}
// Composition with x^3 - 3x: for f of degree n = f.size() - 1, returns the
// 3n + 1 coefficients of f(x^3 - 3x), built only from mobius / comp_x /
// comp_q steps plus two scalar normalisations.
vector<field> solve_3(vector<field> f) {
  const int deg = f.size() - 1;

  // Stage 1: x^deg * f(x + 1/x). The two Moebius steps leave a factor
  // 4^deg, removed by multiplying with 4^(-deg) = 4^(mod - 1 - deg).
  vector<field> laurent = mobius(comp_x(mobius(comp_q(f, 2)), 2));
  const int inv_pow4 = mpow(4, mod - 1 - deg);
  for (int i = 0; i <= deg * 2; ++i) laurent[i] = laurent[i] * inv_pow4;

  // Stage 2: substitute x -> x^3, then apply the Moebius map once more.
  const vector<field> symmetric = mobius(comp_x(laurent, 3));

  // Stage 3: keep the even-index coefficients, scaled by 8^(-deg).
  vector<field> halved(deg * 3 + 1);
  const int inv_pow8 = mpow(8, mod - 1 - deg);
  for (int i = 0; i <= deg * 3; ++i) halved[i] = symmetric[i * 2] * inv_pow8;

  // Stage 4: final Moebius step, then undo the initial x -> 2x scaling.
  return comp_q(mobius(halved), inv(2));
}
// Composition with x^3 + c*x: returns the coefficients of f(x^3 + c*x).
// For c == 0 this is just x -> x^3. Otherwise the problem is rescaled to
// x^3 - 3x with lambda^2 = -c/3. Only lambda^2 is used, never lambda
// itself: f is split into even- and odd-index parts so that every scaling
// factor is a power of lambda^2.
vector<field> solve_a(vector<field> f, int c) {
  const int deg = f.size() - 1;
  if (c == 0) return comp_x(f, 3);

  const int lam2 = c * (u64)inv(mod - 3) % mod;          // lambda^2
  const int lam6 = lam2 * (u64)lam2 % mod * lam2 % mod;  // lambda^6

  // Pre-scale: an even index i gets lambda^(3i); an odd index gets
  // lambda^(3i) / lambda, i.e. lam6^((i - 1) / 2) * lam2.
  vector<field> even(deg + 1), odd(deg + 1);
  int scale = 1;  // lam6^(i / 2)
  for (int i = 0; i <= deg; ++i) {
    if (i % 2 == 0) {
      even[i] = f[i] * scale;
    } else {
      odd[i] = f[i] * (scale * (u64)lam2 % mod);
      scale = scale * (u64)lam6 % mod;
    }
  }

  even = solve_3(even);
  odd = solve_3(odd);

  // Post-scale and merge: indices 2j and 2j + 1 both get lam2^(-j). Even
  // positions come from the even part, odd positions from the odd part.
  const int inv_lam2 = inv(lam2);
  int back = 1;  // inv_lam2^(i / 2)
  for (int i = 0; i <= deg * 3; ++i) {
    if (i % 2 == 0) {
      even[i] = even[i] * back;
    } else {
      even[i] = odd[i] * back;
      back = back * (u64)inv_lam2 % mod;
    }
  }
  return even;
}
// Polynomial composition: returns the coefficients of f(g(x)) for a
// polynomial g of degree at most 3. Coefficients of g are plain residues,
// expected in [0, mod).
//
// Handled explicitly (these previously ran into out-of-bounds reads or a
// zero "inverse" of the leading coefficient):
//   - empty f (the zero polynomial) is returned unchanged;
//   - trailing zero coefficients of g are dropped, so the dispatch below
//     sees the true degree and a non-zero leading coefficient;
//   - an empty or constant g yields the single value f(g) via Horner.
//
// NOTE(review): degrees above 3 are still not supported. The last branch
// only reads g[0..2], so its output is meaningless for such input.
vector<field> comp(vector<field> f, vector<int> g) {
  if (f.empty()) return f;
  while (g.size() > 1 && g.back() % mod == 0) g.pop_back();

  const int n = f.size() - 1;
  if (g.size() <= 1) {
    // Constant inner polynomial: evaluate f at that constant.
    const int c = g.empty() ? 0 : (g[0] % mod + mod) % mod;
    const field point = c;
    field value = f[n];
    for (int i = n - 1; i >= 0; --i) value = value * point + f[i];
    return vector<field>(1, value);
  }

  const int m = g.size() - 1;
  prepare(n * 6);
  {
    // Make g monic: f(g) = f(a * (g / a)), so the leading coefficient a is
    // folded into f as the scaling x -> a * x.
    int a = g.back(), ia = inv(a);
    for (int i = 0; i <= m; ++i) g[i] = g[i] * (u64)ia % mod;
    f = comp_q(f, a);
  }
  if (m == 1) return taylor(f, g[0]);
  else if (m == 2) {
    // Complete the square: x^2 + g1*x + g0 = (x + g1/2)^2 + (g0 - g1^2/4).
    // (mod - 1) / 4 is -1/4 and mod - (mod - 1) / 2 is 1/2 modulo mod.
    int u = (g[0] + (mod - 1) / 4ull * g[1] % mod * g[1]) % mod;
    f = taylor(f, u);
    vector<field> db(n * 2 + 1);
    for (int i = 0; i <= n; ++i) db[i * 2] = f[i];
    return taylor(db, (mod - (mod - 1) / 2ull) * g[1] % mod);
  } else {
    // Depress the cubic: shifting x by r = -g2/3 removes the x^2 term, which
    // leaves x^3 + g1'*x + g0'; the shift is undone by the final taylor.
    int u = inv(3) * (u64)g[2] % mod;
    int r = (mod - u) % mod, r2 = r * (u64)r % mod, r3 = r2 * (u64)r % mod;
    g[0] = (g[0] + g[1] * (u64)r + g[2] * (u64)r2 + r3) % mod;
    g[1] = (g[1] + g[2] * 2ull * r + 3ull * r2) % mod;
    g[2] = (g[2] + 3ull * r) % mod;
    return taylor(solve_a(taylor(f, g[0]), g[1]), u);
  }
}