【学习笔记】(8) 拉格朗日插值

拉格朗日插值

首先一个定理:

\(n\) 个点(横坐标不同)唯一确定一个最高 \(n-1\) 次的多项式。

那么, \(n\) 个点的点值 \((x_i,y_i)\) 可以唯一确定一个 \(n−1\) 次多项式(为了叙述方便,本文中所有“ \(k\) 次多项式”“ \(k\) 次函数”的最高次项系数可以为 0)。

拉格朗日插值就是用来求这个多项式的。

例如,我们已知四个点值 \((−1,1)(0,2)(0.5,1.375)(1,1)\) ,要求过这四个点的三次函数 \(f\)

当然,你可以直接待定系数用高斯消元解方程,但那是 \(O(n^3)\) 的,拉格朗日插值可以在 \(O(n^2)\) 内求解。

约瑟夫·拉格朗日认为这个函数可以用四个三次函数线性组合出来。

首先构造一个三次函数 \(f_1\)
,在 \(x=−1\) 的取值为 \(1\),但在其他三个点的取值为 \(0\)

类似地,构造 \(f_{2,3,4}\) 依次在每个点取值为 \(1\),在其他三个点取值为 \(0\)


画到一张图里就是这样:

那么这几个函数有啥用呢?

容易发现,\(f(x)=y_1f_1(x)+y_2f_2(x)+y_3f_3(x)+y_4f_4(x) \),把那四个点点值代入进去就可以知道。

现在问题就转化为了怎么求 \(f_{1,2,3,4}\)

推导

我们可以构造函数:

\[\Large f_1(x)=\dfrac{(x−x_2)(x−x_3)(x−x_4)}{(x_1−x_2)(x_1−x_3)(x_1−x_4)} \]

\[\Large f_2(x)=\dfrac{(x−x_1)(x−x_3)(x−x_4)}{(x_2−x_1)(x_2−x_3)(x_2−x_4)} \]

\[\Large f_3(x)=\dfrac{(x−x_1)(x−x_2)(x−x_4)}{(x_3−x_1)(x_3−x_2)(x_3−x_4)} \]

\[\Large f_4(x)=\dfrac{(x−x_1)(x−x_2)(x−x_3)}{(x_4−x_1)(x_4−x_2)(x_4−x_3)} \]

把值回代,显然符合

那对于 \(n\) 个点值求 \(n−1\) 次多项式的问题,我们先有点值 \((x_j,y_j)(1\le j\le n)\),设函数 \(f_i(1\le i\le n)\),它们是 \(n−1\) 次函数,且满足:

\[\Large f_i(x_j) = \left\{ \begin{aligned} & 1 &i=j \\ & 0 &i\neq j \end{aligned} \right. \]

则根据上面的构造函数,我们可以写成:

\[\Large f_i(x)=\prod\limits_{j=1,j\neq i }^n \dfrac{(x-x_j)}{(x_i-x_j)} \]

最终得到:

\[\Large f(x)=\sum\limits_{i = 1}^n y_i f _i(x) \]

P4781 【模板】拉格朗日插值

套上面公式直接算即可。

如果你在计算每个 \(\Large \dfrac{x−x_j}{x_i−x_j}\) 的时候都求一遍逆元,会导致复杂度变为 \(O(n^2logn)\),多带一个逆元的 \(log\)。为了让复杂度瓶颈不在逆元上,我们通常分开算分子分母,在每个函数算完后再进行有理数取模,这样的复杂度为 \(O(n^2)\)

#include<bits/stdc++.h>
#define int long long
using namespace std;
const int mod = 998244353, N = 2e3 + 5;
int read(){
	int x = 0, f = 1; char ch = getchar();
	while(ch < '0' || ch > '9'){if(ch == '-') f = -f; ch = getchar();}
	while(ch >= '0' && ch <= '9'){x = (x << 1) + (x << 3) + (ch ^ 48); ch = getchar();}
	return x * f;
}
int n, k, ans;
int x[N], y[N];
int qsm(int a, int b){
	int res = 1;
	for(; b; b >>= 1, a = a * a % mod) if(b & 1) res = res * a % mod;
	return res;
}
int inv(int x){return qsm(x, mod - 2);}
signed main(){
	n = read(), k = read();
	for(int i = 1; i <= n; ++i) x[i] = read(), y[i] = read();
	for(int i = 1; i <= n; ++i){
		int a = y[i], b = 1;
		for(int j = 1; j <=n; ++j){
			if(i != j){
				a = a * (k - x[j]) % mod;
				b = b * (x[i] - x[j]) % mod;
			}
		}
		ans = (ans + a * inv(b) % mod + mod) % mod;
	}
	printf("%lld\n", ans);
	return 0;
}

连续点值的插值

如果已知的点值是连续点的点值,我们可以做到 \(O(n)\) 的插值。

有时候发现了一个函数是 \(n\) 次多项式,就求 \(n+1\) 个点值进去插值。为了省事,这里我们令 \(x_i=i(1\le i\le n+1)\)。注意这里 \(n\) 与上面意义不一样,是次数而不是点数。

我们有拉插公式:

\[\Large f(x)=\sum\limits_{i=1}^{n+1}y_i\prod\limits_{j=1,j\neq i}^{n +1}\dfrac{x-x_j}{x_i-x_j} \]

代入 \(x_i=i\)

\[\Large f(x)=\sum\limits_{i=1}^{n+1}y_i\prod\limits_{j=1,j\neq i}^{n +1}\dfrac{x-j}{i-j} \]

考虑怎么快速求 \(\Large \prod\limits_{j=1,j\neq i}^{n +1}\dfrac{x-j}{i-j}\)

上述式子的分子是:

\[\Large \dfrac{\prod\limits_{j=1}^{n+1}(x-j)}{x-i} \]

分母的话把 \(i−j\) 累乘拆成两段阶乘,就是:

\[\Large (−1)^{n+1−i}\cdot(i−1)!\cdot(n+1−i)! \]

于是连续点值的插值公式:

\[\Large f(x)=\sum\limits_{i=1}^{n+1}y_i\dfrac{\prod\limits_{j=1}^{n+1}(x-j)}{(x-i)\cdot(−1)^{n+1−i}\cdot(i−1)!\cdot(n+1−i)!} \]

预处理前后缀积、阶乘阶乘逆然后代这个式子的复杂度为 \(O(n)\)

CF622F The Sum of the k-th Powers

可以发现答案是 \(k+1\) 次多项式,因此代 \(k+2\) 个点进去拉插。

证明 : https://www.luogu.com.cn/blog/formkiller/cf622f-the-sum-of-the-k-th-powers-ti-xie#

本题的 \(i^k\) 可以线性筛,因此复杂度可以做到 \(O(n)\)

#include<bits/stdc++.h>
#define int long long
using namespace std;
const int mod = 1e9 + 7, N = 1e6 + 5;
int read(){
	int x = 0, f = 1; char ch = getchar();
	while(ch < '0' || ch > '9'){if(ch == '-') f = -f; ch = getchar();}
	while(ch >= '0' && ch <= '9'){x = (x << 1) + (x << 3) + (ch ^ 48); ch = getchar();}
	return x * f;
}
int n, k, cnt, ans;
int v[N], inv[N], fac[N], infac[N], prime[N];
int pre[N], suf[N], f[N];
int qsm(int a, int b){
	int res = 1;
	for(; b; b >>= 1, a = a * a % mod) if(b & 1) res = res * a % mod;
	return res;
}
void solve(int nn){
	f[1] = 1;
	for(int i = 2; i <= nn; ++i){
		if(!v[i]) v[i] = i, prime[++cnt] = i, f[i] = qsm(i, k);
		for(int j = 1; j <= cnt; ++j){
			if(prime[j] > v[i] || prime[j] > nn / i) break;
			v[i * prime[j]] = prime[j], f[i * prime[j]] = f[i] * f[prime[j]] % mod; 
		}
	}
	for(int i = 2; i <= nn; ++i) (f[i] += f[i - 1]) %= mod;
	return ;
}
signed main(){
	n = read(), k = read();
	solve(k + 2);
	if(n <= k + 2) return printf("%lld\n", f[n]), 0;
	pre[0] = suf[k + 3] = 1;
	for(int i = 1; i <= k + 2; ++i) pre[i] = pre[i - 1] * (n - i) % mod;
	for(int i = k + 2; i; --i) suf[i] = suf[i + 1] * (n - i) % mod;
	infac[0] = infac[1] = inv[0] = fac[0] = inv[1] = fac[1] = 1;
	for(int i = 2; i <= k + 2; ++i){
		fac[i] = fac[i - 1] * i % mod;
		inv[i] = (mod - mod / i) * inv[mod % i] % mod;
	}
	for(int i = 2; i <= k + 2; ++i) infac[i] = infac[i - 1] * inv[i] % mod;
	for(int i = 1; i <= k + 2; ++i){
		int p = pre[i - 1] * suf[i + 1] % mod;
		int q = infac[i - 1] * infac[k + 2 - i] % mod;
		int mul = ((k + 2 - i) & 1) ? -1 : 1;
		ans = (ans + (q * mul + mod) % mod * p % mod * f[i] % mod) % mod;
	}
	printf("%lld\n", ans);
	return 0;
}

P4593 [TJOI2018]教科书般的亵渎

首先,认真读题不难发现若血量的区间 \([1,m_i]\) 连续,则只需要一张亵渎就可以杀死区间 \([1,m_i]\) 内所有怪物,所以 \(k = m+1\)

考虑到这点,我们就可以轻松的写出式子(保证 \(a_i\) 升序):

定义 \(a_0 = 0\),有

\[\Large ans=\sum\limits_{i=1}^{m +1}(\sum\limits_{j=1}^{n-a_{i-1}}j^{m+1} - \sum\limits_{j=i}^{m}(a_j - a_{i-1})^{m+1}) \]

然后就跟上题一样了,发现 \(\Large \sum\limits_{j=1}^{n-a_{i-1}}j^{m+1}\) 是一个 \(m+2\) 次的多项式。

由于取值是连续的,就可以优化到 \(O(m)\),对于 \(\Large \sum\limits_{j=i}^{m}(a_j - a_{i-1})^{m+1}\) 直接暴力求解就行了。

时间复杂度 \(O(m^2logm)\)

#include<bits/stdc++.h>
#define int long long
using namespace std;
const int mod = 1e9 + 7, N = 5e1 + 7;
int read(){
	int x = 0, f = 1; char ch = getchar();
	while(ch < '0' || ch > '9'){if(ch == '-') f = -f; ch = getchar();}
	while(ch >= '0' && ch <= '9'){x = (x << 1) + (x << 3) + (ch ^ 48); ch = getchar();}
	return x * f;
}
int T, n, m, k, cnt, ans;
int v[N], inv[N], fac[N], infac[N], prime[N];
int pre[N], suf[N], f[N], a[N];
int qsm(int a, int b){
	int res = 1;
	for(; b; b >>= 1, a = a * a % mod) if(b & 1) res = res * a % mod;
	return res;
}
void solve(int nn){
	memset(v, 0, sizeof(v));
	f[1] = 1, cnt = 0;
	for(int i = 2; i <= nn; ++i){
		if(!v[i]) v[i] = i, prime[++cnt] = i, f[i] = qsm(i, k);
		for(int j = 1; j <= cnt; ++j){
			if(prime[j] > v[i] || prime[j] > nn / i) break;
			v[i * prime[j]] = prime[j], f[i * prime[j]] = f[i] * f[prime[j]] % mod;
		}
	}
	for(int i = 2; i <= nn; ++i) (f[i] += f[i - 1]) %= mod;
	return ;
}
int lagrange(int x){
	if(x <= m + 3) return f[x];
	int res = 0; suf[m + 4] = pre[0] = 1;
	for(int i = 1; i <= m + 3; ++i) pre[i] = pre[i - 1] * (x - i) % mod;
	for(int i = m + 3; i; --i) suf[i] = suf[i + 1] * (x - i) % mod;
	for(int i = 1; i <= m + 3; ++i)
		res = (res + (((m + 3 - i) & 1) ? -1ll : 1ll) * (pre[i - 1] * suf[i + 1] % mod * infac[i - 1] % mod * infac[m + 3 - i] % mod * f[i] % mod) + mod) % mod;
	return res;
}
signed main(){
	infac[0] = infac[1] = inv[0] = fac[0] = inv[1] = fac[1] = 1;
	for(int i = 2; i <= 54; ++i){
		fac[i] = fac[i - 1] * i % mod;
		inv[i] = (mod - mod / i) * inv[mod % i] % mod;
	}
	for(int i = 2; i <= 54; ++i) infac[i] = infac[i - 1] * inv[i] % mod;
	T = read();
	while(T--){
		ans = 0;
		n = read(), m = read();
		for(int i = 1; i <= m; ++i) a[i] = read();
		sort(a + 1, a + 1 + m); k = m + 1, solve(m + 3);
		for(int i = 1; i <= m + 1; ++i){
			ans = (ans + lagrange(n - a[i - 1])) % mod;
			for(int j = i; j <= m; ++j) ans = (ans - qsm(a[j] - a[i - 1], k) + mod) % mod;
		}
		printf("%lld\n", ans);
	}
	return 0;
}

重心拉格朗日插值

如果每次加入一个数重新求多项式的插值,每次都是 \(O(n^2)\) ,不优。

重心拉格朗日插值法可以在加入一个数后 \(O(n)\) 求出新的多项式的插值

\[\Large f(x)=\sum\limits_{i=1}^{n}y_i\prod\limits_{j=1,j\neq i}^{n}\dfrac{x-x_j}{x_i-x_j} \]

\[\Large f(x)=\sum\limits_{i=1}^{n}y_i\dfrac{\prod\limits_{j=1}^{n}(x-x_j)}{(x-x_i)\cdot\prod\limits_{j=1,j\neq i}^{n}(x_i-x_j)} \]

\(\Large g = \prod\limits_{j=1}^{n}(x-x_j), w(i) = \prod\limits_{j=1,j\neq i}^{n}(x_i-x_j)\),那么有:

\[\Large f(x)=g\cdot \sum\limits_{i=1}^{n}\dfrac{y_i}{(x-x_i)\cdot w(i)} \]

那么对于每一个新增加的插值点 ,我们可以 \(O(n)\) 的更新所有的 \(w(i)\), 求原函数仍然是 \(O(n^2)\)

参考资料: 拉格朗日插值学习笔记 拉格朗日插值学习笔记

posted @ 2023-05-20 18:41  Aurora-JC  阅读(267)  评论(0编辑  收藏  举报