多项式多点求值和插值

\(orz~fjzzq\)

多项式多点求值

给定一个多项式 \(F(x)\) 和若干个点 \(x_0,x_1,\dots,x_n\)，
对每个点 \(x_i\) 求出 \(F(x_i)\)。
考虑分治

\[L(x)=\prod_{i=0}^{\frac{n}{2}}(x-x_i),R(x)=\prod_{i=\frac{n}{2}+1}^n(x-x_i) \]

那么
对于 \(0 \le i \le \frac{n}{2}\)

\[F(x_i)=(F \bmod L)(x_i) \]

（因为 \(F=Q\cdot L+(F \bmod L)\)，而 \(L(x_i)=0\)）

对于 \(\frac{n}{2}+1 \le i \le n\)

\[F(x_i)=(F \bmod R)(x_i) \]

分治下去
可以像线段树一样把每个区间对应的 \(L(x)\)、\(R(x)\) 储存下来，
分治就是在这棵线段树上 DFS，每到一个结点只要做一次取模就好了。
复杂度为 \(\Theta(n\log^2 n)\)，但常数很大。

多项式插值

给出 \(n+1\) 个点 \((x_0,y_0),\dots,(x_n,y_n)\)（\(x_i\) 互不相同），求出经过这些点且次数不超过 \(n\) 的多项式 \(F(x)\)。
根据拉格朗日插值

\[F(x)=\sum_{i=0}^{n}y_i\frac{\prod_{j=0,j\ne i}^{n}(x-x_j)}{\prod_{j=0,j\ne i}^{n}(x_i-x_j)} \]

先考虑如何对每个 \(x_i\) 求出 \(\prod_{j=0,j\ne i}^{n}(x_i-x_j)\)，设

\[A(x)=\prod_{i=0}^{n}(x-x_i) \]

那么就是要求

\[\lim_{x\rightarrow x_i}\frac{A(x)}{x-x_i} \]

根据洛必达法则

\[\prod_{j=0,j\ne i}^{n}(x_i-x_j)=A'(x_i) \]

所以只要求出 \(A'(x)\) 然后多点求值就好了
所以现在要求

\[\sum_{i=0}^{n}\frac{y_iA(x)}{(x-x_i)A'(x_i)} \]

这部分仍然可以分治得到

\[L(x)=\prod_{i=0}^{\frac{n}{2}}(x-x_i),R(x)=\prod_{i=\frac{n}{2}+1}^n(x-x_i) \]

那么上面的就是

\[L(x)\sum_{i=\frac{n}{2}+1}^n\frac{y_i}{A'(x_i)}\prod_{j=\frac{n}{2}+1,j\ne i}^n(x-x_j)+R(x)\sum_{i=0}^{\frac{n}{2}}\frac{y_i}{A'(x_i)}\prod_{j=0,j\ne i}^{\frac{n}{2}}(x-x_j) \]

分治即可
不过常数实在是太大了……
复杂度为 \(\Theta(n\log^2 n)\)

Source : COGS

把单位根一次性预处理好可能会快一些（NTT 用的是模 998244353 意义下的单位根而不是复数单位根；下面的代码在每次 `FFT::Init` 时都会重新计算它们）

# include <bits/stdc++.h>
using namespace std;
typedef long long ll;

namespace IO {
    // Size of the stdin read buffer and of the stdout write buffer.
    const int maxn(1 << 21 | 1);

    // ibuf with [iS, iT): buffered input not yet consumed.
    // obuf with [obuf, oS): pending output, flushed when oS reaches oT.
    // st: digit stack used by Out.
    char ibuf[maxn], obuf[maxn], *iS, *iT, *oS = obuf, *oT = obuf + maxn - 1, st[65];
    // c: last character read (an int, so EOF is distinguishable from every byte value);
    // f: sign of the number being read by In; tp: top of the digit stack st.
    int c, f, tp;

    // Returns the next input byte as an unsigned char value, or EOF once stdin is exhausted.
    // Refills ibuf with fread whenever it has been fully consumed.
    inline int Getc() {
        if (iS == iT) {
            iT = (iS = ibuf) + fread(ibuf, 1, maxn, stdin);
            if (iS == iT) return EOF;
        }
        return static_cast<unsigned char>(*iS++);
    }

    // Reads the next decimal integer into x, skipping any non-digit characters before it.
    // A '-' directly before the digits makes the value negative.
    // If the input ends before a digit is seen, x becomes 0 instead of looping on EOF forever.
    template <class Int> inline void In(Int &x) {
        for (f = 1, c = Getc(); c != EOF && (c < '0' || c > '9'); c = Getc()) f = c == '-' ? -1 : 1;
        for (x = 0; c >= '0' && c <= '9'; c = Getc()) x = (x << 1) + (x << 3) + (c ^ 48);
        x *= f;
    }

    // Writes the buffered output to stdout and empties the buffer.
    inline void Flush() {
        fwrite(obuf, 1, oS - obuf, stdout);
        oS = obuf;
    }

    // Appends one character to the output buffer, flushing when the buffer is full.
    inline void Putc(char c) {
        *oS++ = c;
        if (oS == oT) Flush();
    }

    // Appends the decimal representation of x to the output buffer.
    template <class Int> inline void Out(Int x) {
        if (!x) {
            Putc('0');
            return;
        }
        if (x < 0) {
            Putc('-');
            // Peel digits off the negative value directly so that negating the most negative
            // Int (which would overflow) is never needed; x % 10 lies in [-9, 0] here.
            while (x) st[++tp] = static_cast<char>('0' - x % 10), x /= 10;
        } else {
            while (x) st[++tp] = static_cast<char>('0' + x % 10), x /= 10;
        }
        while (tp) Putc(st[tp--]);
    }
}

using IO :: In;
using IO :: Out;
using IO :: Putc;
using IO :: Flush;

// Prime modulus 998244353 = 119 * 2^23 + 1; FFT::Init uses 3 as its primitive root.
const int mod(998244353);
// Capacity of the point arrays (indexed 1..n) and of the NTT work buffers. The NTT length
// (a power of two) must not exceed it.
const int maxn(2e5 + 5);

// Returns x^y mod `mod` by binary exponentiation.
// The base is first reduced into [0, mod), so negative or oversized bases give a proper residue.
// y must be non-negative; a negative y yields 1 instead of looping forever.
inline int Pow(long long x, int y) {
    x %= mod;
    if (x < 0) x += mod;
    long long ret = 1;
    for (; y > 0; y >>= 1, x = x * x % mod)
        if (y & 1) ret = ret * x % mod;
    return static_cast<int>(ret);
}

// x = (x + y) mod `mod`, assuming both operands already lie in [0, mod).
inline void Inc(int &x, const int y) {
    if ((x += y) >= mod) x -= mod;
}
 
namespace FFT {
	int a[maxn], b[maxn], len, r[maxn], l, w[2][maxn];

    inline void Init(const int n) {
        register int i, x, y;
        for (l = 0, len = 1; len < n; len <<= 1) ++l;
        for (i = 0; i < len; ++i) r[i] = (r[i >> 1] >> 1) | ((i & 1) << (l - 1));
        for (i = 0; i < len; ++i) a[i] = b[i] = 0;
        w[1][0] = w[0][0] = 1, x = Pow(3, (mod - 1) / len), y = Pow(x, mod - 2);
        for (i = 1; i < len; ++i) w[0][i] = (ll)w[0][i - 1] * x % mod, w[1][i] = (ll)w[1][i - 1] * y % mod;
    }

    inline void NTT(int *p, const int opt) {
        register int i, j, k, wn, t, x, y;
        for (i = 0; i < len; ++i) if (r[i] < i) swap(p[r[i]], p[i]);
        for (i = 1; i < len; i <<= 1)
            for (t = i << 1, j = 0; j < len; j += t)
                for (k = 0; k < i; ++k) {
                    wn = w[opt == -1][len / t * k];
                    x = p[j + k], y = (ll)wn * p[i + j + k] % mod;
                    p[j + k] = x + y >= mod ? x + y - mod : x + y;
                    p[i + j + k] = x - y < 0 ? x - y + mod : x - y;
                }
        if (opt == -1) for (wn = Pow(len, mod - 2), i = 0; i < len; ++i) p[i] = (ll)p[i] * wn % mod;
    }
    
    inline void Calc1() {
        register int i;
        NTT(a, 1), NTT(b, 1);
        for (i = 0; i < len; ++i) a[i] = 1LL * a[i] * b[i] % mod;
        NTT(a, -1);
    }
    
    inline void Calc2() {
        register int i;
        NTT(a, 1), NTT(b, 1);
        for (i = 0; i < len; ++i) a[i] = 1LL * a[i] * b[i] % mod * b[i] % mod;
        NTT(a, -1);
    }
}
 
// Dense polynomial over Z_mod: v[i] is the coefficient of x^i.
// Multiplication goes through the global FFT:: buffers.
struct Poly {
    std::vector<int> v;

    // A single zero coefficient.
    inline Poly() {
        v.resize(1);
    }

    // d zero coefficients.
    inline Poly(const int d) {
        v.resize(d);
    }

    // Number of stored coefficients; this may exceed the true degree + 1.
    inline int Length() const {
        return static_cast<int>(v.size());
    }

    // Coefficient-wise sum; the result has max(Length(), b.Length()) coefficients.
    inline Poly operator +(const Poly &b) const {
        const int l1 = Length(), l2 = b.Length();
        Poly c(std::max(l1, l2));
        for (int i = 0; i < l1; ++i) c.v[i] = v[i];
        for (int i = 0; i < l2; ++i) Inc(c.v[i], b.v[i]);
        return c;
    }

    // Coefficient-wise difference; the result has max(Length(), b.Length()) coefficients.
    inline Poly operator -(const Poly &b) const {
        const int l1 = Length(), l2 = b.Length();
        Poly c(std::max(l1, l2));
        for (int i = 0; i < l1; ++i) c.v[i] = v[i];
        for (int i = 0; i < l2; ++i) Inc(c.v[i], mod - b.v[i]);
        return c;
    }

    // Leaves (*this) * b * b in FFT::a[0, FFT::len).
    // The transform length only covers Length() + b.Length() - 1 coefficients. The cyclic result is
    // therefore exact only when b's high coefficients are zero or the caller reads just the low
    // coefficients that cannot be wrapped into (as in the Newton step of Inv).
    inline void InvMul(const Poly &b) const {
        const int l1 = Length(), l2 = b.Length();
        FFT :: Init(l1 + l2 - 1);
        for (int i = 0; i < l1; ++i) FFT :: a[i] = v[i];
        for (int i = 0; i < l2; ++i) FFT :: b[i] = b.v[i];
        FFT :: Calc2();
    }

    // Full product; the result has Length() + b.Length() - 1 coefficients.
    inline Poly operator *(const Poly &b) const {
        const int l1 = Length(), l2 = b.Length();
        // With an empty operand, l1 + l2 - 1 would be non-positive; the product is empty as well.
        if (l1 == 0 || l2 == 0) return Poly(0);
        const int l3 = l1 + l2 - 1;
        Poly c(l3);
        FFT :: Init(l3);
        for (int i = 0; i < l1; ++i) FFT :: a[i] = v[i];
        for (int i = 0; i < l2; ++i) FFT :: b[i] = b.v[i];
        FFT :: Calc1();
        for (int i = 0; i < l3; ++i) c.v[i] = FFT :: a[i];
        return c;
    }

    // Multiplies every coefficient by the scalar b (mod `mod`).
    inline Poly operator *(int b) const {
        const int l = Length();
        Poly c(l);
        for (int i = 0; i < l; ++i) c.v[i] = static_cast<int>(1LL * v[i] * b % mod);
        return c;
    }

    // Evaluates the polynomial at x (expected in [0, mod)) by Horner's rule.
    // An empty polynomial evaluates to 0.
    inline int Calc(const int x) const {
        long long ret = 0;
        for (int i = Length() - 1; i >= 0; --i) ret = (ret * x + v[i]) % mod;
        return static_cast<int>(ret);
    }
};
 
inline void Inv(Poly p, Poly &q, int len) {
    if (len == 1) {
        q.v[0] = Pow(p.v[0], mod - 2);
        return;
    }
    Inv(p, q, len >> 1);
    register int i;
    p.InvMul(q);
    for (i = 0; i < len; ++i) q.v[i] = (2LL * q.v[i] + mod - FFT :: a[i]) % mod;
}
 
// Returns the power-series inverse of a modulo x^{a.Length()}; a.v[0] must be invertible mod `mod`.
// The Newton iteration runs on the next power of two, and the result is then cut back to a's length.
inline Poly Inverse(Poly a) {
    const int n = a.Length();
    int size = 1;
    while (size < n) size <<= 1;
    Poly res(size);
    Inv(a, res, size);
    res.v.resize(n);
    return res;
}
 
// Remainder of a divided by b. The result has b.Length() - 1 coefficients, or is a itself when a is shorter.
// The quotient is found by reversing both operands, which turns division into a power-series
// product with the inverse of rev(b) modulo x^{n - m + 1}.
inline Poly operator %(const Poly &a, const Poly &b) {
    const int n = a.Length(), m = b.Length();
    if (n < m) return a;
    const int qlen = n - m + 1;
    Poly ra = a, rb = b;
    std::reverse(ra.v.begin(), ra.v.end());
    std::reverse(rb.v.begin(), rb.v.end());
    ra.v.resize(qlen);
    rb.v.resize(qlen);
    Poly quot = ra * Inverse(rb);
    quot.v.resize(qlen);
    std::reverse(quot.v.begin(), quot.v.end());
    // remainder = a - quotient * b, which fits in the low m - 1 coefficients.
    Poly rem = a - quot * b;
    rem.v.resize(m - 1);
    return rem;
}
 
// f[o]: product tree node, prod_{i in [l, r]} (x - x_i) for the node o covering [l, r].
// Build numbers children o << 1 and o << 1 | 1, so ids can reach roughly 4n (e.g. about 262016
// for n = 1e5). The array therefore gets the usual 4x headroom over the point capacity maxn.
// a: holds A'(x) in main; b: the interpolated polynomial written by Solve_poly.
Poly f[maxn << 2], a, b;
// n: number of points; x, y: the points (1-indexed); ans[i] = y_i / A'(x_i).
// NOTE(review): m is not used anywhere in the visible code.
int n, m, x[maxn], y[maxn], ans[maxn];
 
// Builds the product tree over x[l..r]: f[o] = prod_{i=l}^{r} (x - x[i]).
// Children are o << 1 (covering [l, mid]) and o << 1 | 1 (covering [mid + 1, r]).
void Build(int o, int l, int r) {
    if (l == r) {
        // Leaf factor x - x_l, stored as coefficients [-x_l mod p, 1].
        // Reducing keeps the constant term in [0, mod) even when x_l == 0 (where mod - x_l would
        // be mod itself) or x_l lies outside [0, mod).
        f[o].v.resize(2);
        f[o].v[0] = (mod - x[l] % mod) % mod;
        f[o].v[1] = 1;
        return;
    }
    const int mid = (l + r) >> 1;
    Build(o << 1, l, mid);
    Build(o << 1 | 1, mid + 1, r);
    f[o] = f[o << 1] * f[o << 1 | 1];
}

// Sets ans[i] = y[i] / cur(x[i]) for every i in [l, r]. cur is congruent to A'(x) modulo f[o],
// the product of (x - x_i) over this node, so cur(x_i) = A'(x_i).
// Nodes with more than 2000 points pass cur reduced modulo each child's product down to that child.
// Smaller nodes evaluate cur directly at every point.
// If two points coincide, A'(x_i) = 0 and Pow(0, mod - 2) yields 0, so ans[i] becomes 0.
void Solve_val(Poly cur, int o, int l, int r) {
    if (r - l + 1 > 2000) {
        const int mid = (l + r) >> 1;
        const int lc = o << 1, rc = o << 1 | 1;
        Solve_val(cur % f[lc], lc, l, mid);
        Solve_val(cur % f[rc], rc, mid + 1, r);
        return;
    }
    for (int i = l; i <= r; ++i) {
        const int denom = cur.Calc(x[i]);
        ans[i] = 1LL * y[i] * Pow(denom, mod - 2) % mod;
    }
}

// Sets cur = sum_{i=l}^{r} ans[i] * prod_{j in [l, r], j != i} (x - x_j).
// Each half's partial sum is multiplied by the other half's product polynomial from the tree.
// A leaf only writes cur.v[0], so cur must have at least one coefficient there.
void Solve_poly(Poly &cur, int o, int l, int r) {
    if (l != r) {
        const int mid = (l + r) >> 1;
        const int lc = o << 1, rc = o << 1 | 1;
        Poly left(mid - l + 1), right(r - mid);
        Solve_poly(left, lc, l, mid);
        Solve_poly(right, rc, mid + 1, r);
        cur = left * f[rc] + right * f[lc];
        return;
    }
    cur.v[0] = ans[l];
}

// Reads n followed by n pairs (x_i, y_i).
// Prints the n coefficients (lowest degree first) of the interpolating polynomial.
int main() {
    In(n);
    // Build() never terminates on an empty range, so there is nothing to do for n <= 0.
    // NOTE(review): x and y hold maxn entries (indices 1..n); n >= maxn would overflow them.
    if (n <= 0) return Flush(), 0;
    for (int i = 1; i <= n; ++i) In(x[i]), In(y[i]);
    // f[1] = A(x) = prod_i (x - x_i).
    Build(1, 1, n);
    a = f[1];
    // Turn a into A'(x) in place: coefficient i of A' is (i + 1) * A_{i + 1}.
    const int len = a.Length();
    for (int i = 0; i + 1 < len; ++i) a.v[i] = 1LL * a.v[i + 1] * (i + 1) % mod;
    if (a.Length() > 1) a.v.pop_back();
    else a.v[0] = 0;
    // ans[i] = y_i / A'(x_i); then merge the weighted Lagrange terms bottom-up into b.
    b.v.resize(n);
    Solve_val(a, 1, 1, n);
    Solve_poly(b, 1, 1, n);
    for (int i = 0; i < n; ++i) Out(b.v[i]), Putc(' ');
    return Flush(), 0;
}
posted @ 2018-11-29 13:51  Cyhlnj  阅读(868)  评论(2)  编辑  收藏  举报