拉格朗日插值如何插出系数
好久之前在 cmd's blog 看到过,这次做题遇上了,就学了一下,其实挺 easy 的。
众所周知其实是我不会证 \(n\) 个点 \((x_i,y_i)\) 可以唯一确定一个次数为 \(n-1\) 的多项式,拉格朗日插值给出了一种构造:
\[f(z)=\sum_{i=1}^{n} \dfrac{y_i\prod_{j\not=i}(z-x_j)}{\prod_{j\not=i}(x_i-x_j)}
\]
首先提出常数部分:
\[a_i=\dfrac{y_i}{\prod_{j\not=i}(x_i-x_j)}
\]
可以 \(O(n^2)\) 搞出每一个 \(a_i\)。
然后求一个多项式 \(g(z)=\prod_{i=1}^{n} (z-x_i)\)。
可以发现
\[f(z)=\sum_{i=1}^{n}a_i\dfrac{g(z)}{z-x_i}
\]
考虑如何快速搞出后面那个 \(\dfrac{g(z)}{z-x_i}\)。
设 \(h(z)=\dfrac{g(z)}{z-c}\)。
可以得到 \((z-c)h(z)=g(z)\)。两边提取系数得到
\[[z^{i-1}]h-c[z^i]h=[z^i]g\\
[z^i]h=\dfrac{[z^i]g-[z^{i-1}]h}{-c}
\]
递推即可。
给出 模板题 通过代码:
#include<bits/stdc++.h>
using namespace std;
#define fi first
#define se second
#define mkp(x,y) make_pair(x,y)
#define pb(x) push_back(x)
#define sz(v) (int)v.size()
typedef long long LL;
typedef double db;
template<class T>bool ckmax(T&x,T y){return x<y?x=y,1:0;}
template<class T>bool ckmin(T&x,T y){return x>y?x=y,1:0;}
#define rep(i,x,y) for(int i=x,i##end=y;i<=i##end;++i)
#define per(i,x,y) for(int i=x,i##end=y;i>=i##end;--i)
inline int read(){
int x=0,f=1;char ch=getchar();
while(!isdigit(ch)){if(ch=='-')f=0;ch=getchar();}
while(isdigit(ch))x=x*10+ch-'0',ch=getchar();
return f?x:-x;
}
#define mod 998244353
inline int qpow(int n, int k) {
int res = 1;
for(; k; k >>= 1, n = 1ll * n * n % mod)
if(k & 1) res = 1ll * n * res % mod;
return res;
}
vector <int> lagrange(const vector <int> &x, const vector <int> &y) {
assert(x.size() == y.size());
int n = x.size();
vector <int> a(n, 0), b(n + 1, 0), c(n + 1, 0), f(n, 0);
for(int i = 0; i < n; ++i) {
int A = 1;
for(int j = 0; j < n; ++j) if(i != j)
A = 1ll * A * (x[i] - x[j] + mod) % mod;
a[i] = 1ll * qpow(A, mod - 2) * y[i] % mod;
}
b[0] = 1;
for(int i = 0; i < n; ++i) {
for(int j = i + 1; j >= 1; --j)
b[j] = (1ll * b[j] * (mod - x[i]) + b[j - 1]) % mod;
b[0] = 1ll * b[0] * (mod - x[i]) % mod;
}
for(int i = 0; i < n; ++i) {
int iv = qpow(mod - x[i], mod - 2);
if(!iv) {
for(int j = 0; j < n; ++j) c[j] = b[j + 1];
} else {
c[0] = 1ll * b[0] * iv % mod;
for(int j = 1; j <= n; ++j)
c[j] = 1ll * (b[j] + mod - c[j - 1]) * iv % mod;
}
for(int j = 0; j < n; ++j)
f[j] = (f[j] + 1ll * a[i] * c[j] % mod) % mod;
}
return f;
}
inline int calc(const vector <int> &f, int x) {
int res = 0;
for(int i = f.size() - 1; i >= 0; --i) res = (1ll * res * x + f[i]) % mod;
return res;
}
signed main() {
int n = read(), k = read();
vector <int> x(n), y(n);
for(int i = 0; i < n; ++i) x[i] = read(), y[i] = read();
vector <int> f = lagrange(x, y);
cout << calc(f, k) << '\n';
}
路漫漫其修远兮,吾将上下而求索